chat.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. package controller
import (
	"context"
	"encoding/json"
	"fmt"
	"os"

	"github.com/gogf/gf/v2/frame/g"
	"github.com/gogf/gf/v2/net/ghttp"
	"github.com/gogf/gf/v2/os/glog"
	"github.com/gogf/gf/v2/util/guid"
	"github.com/google/uuid"

	v1 "go-gpt/api/v1"
	"go-gpt/gpt"
)
  14. type Message struct {
  15. Text string `json:"text"`
  16. ConversationID string `json:"conversation_id"`
  17. MessageId string `json:"message_id,omitempty"`
  18. }
  19. var (
  20. Chat = sChat{}
  21. key = "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"
  22. proxyUrl = "127.0.0.1:6153"
  23. )
  24. type sChat struct {
  25. }
  26. // Chat 这里不使用流数据
  27. func (c *sChat) Chat(ctx context.Context, req *v1.ChatReq) (res *v1.ChatRes, err error) {
  28. var chatRequst gpt.ChatCompletionRequest
  29. var chatMessage []gpt.ChatCompletionRequestMessage
  30. re := g.RequestFromCtx(ctx)
  31. ConversationID := req.ConversationID
  32. MessageId := req.MessageId
  33. chatRequst.Model = gpt.GPT3Dot5Turbo
  34. chatRequst.MaxTokens = 1024
  35. chatRequst.Temperature = 0
  36. if req.ConversationID == "" || req.MessageId == "" {
  37. ConversationID = CreateSessionId(re)
  38. MessageId = uuid.NewString()
  39. }
  40. chatMessage = append(chatMessage, gpt.ChatCompletionRequestMessage{
  41. Role: req.Role,
  42. Content: req.Content,
  43. })
  44. chatRequst.Messages = chatMessage
  45. client := gpt.NewClient(key, proxyUrl)
  46. completion, err := client.ChatCompletion(ctx, &chatRequst)
  47. if err != nil {
  48. return nil, err
  49. }
  50. res = new(v1.ChatRes)
  51. if len(completion.Choices) > 0 {
  52. res.ConversationID = ConversationID
  53. res.MessageId = MessageId
  54. res.Text = completion.Choices[0].Message.Content
  55. }
  56. return
  57. }
  58. // ChatStream 这里使用流数据
  59. func (c *sChat) ChatStream(ctx context.Context, req *v1.ChatStreamReq) (res *v1.ChatStreamRes, err error) {
  60. re := g.RequestFromCtx(ctx)
  61. re.Response.Header().Set("Transfer-Encoding", "chunked")
  62. re.Response.Header().Set("Content-Type", "application/json")
  63. w := re.Response.Writer
  64. encoder := json.NewEncoder(w)
  65. // 初始化client
  66. client := gpt.NewClient(key, proxyUrl)
  67. err = client.ChatCompletionStream(ctx, &gpt.ChatCompletionRequest{
  68. Model: gpt.GPT3Dot5Turbo,
  69. Messages: []gpt.ChatCompletionRequestMessage{
  70. {
  71. Role: req.Role,
  72. Content: req.Content,
  73. },
  74. },
  75. MaxTokens: 150,
  76. Temperature: 0.8,
  77. }, func(response *gpt.ChatCompletionStreamResponse) {
  78. // 检查Choices字段是否存在
  79. if len(response.Choices) > 0 {
  80. // 提取Content字段
  81. content := response.Choices[0].Delta.Content
  82. // 将Content字段作为一个JSON对象发送给客户端
  83. message := Message{Text: content}
  84. if err := encoder.Encode(message); err != nil {
  85. return
  86. }
  87. re.Response.WriteHeader(200)
  88. w.Flush()
  89. } else {
  90. re.Exit()
  91. }
  92. })
  93. if err != nil {
  94. glog.Debug(ctx, err)
  95. return
  96. }
  97. return
  98. }
  99. func CreateSessionId(r *ghttp.Request) string {
  100. var (
  101. address = r.RemoteAddr
  102. header = fmt.Sprintf("%v", r.Header)
  103. )
  104. return guid.S([]byte(address), []byte(header))
  105. }