123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120 |
- package controller
import (
	"context"
	"encoding/json"
	"fmt"
	"os"

	"github.com/gogf/gf/v2/frame/g"
	"github.com/gogf/gf/v2/net/ghttp"
	"github.com/gogf/gf/v2/os/glog"
	"github.com/gogf/gf/v2/util/guid"
	"github.com/google/uuid"

	v1 "go-gpt/api/v1"
	"go-gpt/gpt"
)
// Message is the JSON payload ChatStream writes to the client for every
// streamed chunk of model output.
//
// ConversationID is always present in the encoded object (possibly as ""),
// whereas MessageId is left out entirely when it is empty.
type Message struct {
	// Text carries the piece of generated text for this chunk.
	Text string `json:"text"`

	// ConversationID identifies the conversation the text belongs to.
	ConversationID string `json:"conversation_id"`

	// MessageId identifies a single message; omitted from JSON when empty.
	MessageId string `json:"message_id,omitempty"`
}
- var (
- Chat = sChat{}
- key = "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"
- proxyUrl = "127.0.0.1:6153"
- )
// sChat groups the chat handlers (Chat and ChatStream). It holds no state of
// its own; the API key and proxy address come from package-level variables.
type sChat struct{}
- // Chat 这里不使用流数据
- func (c *sChat) Chat(ctx context.Context, req *v1.ChatReq) (res *v1.ChatRes, err error) {
- var chatRequst gpt.ChatCompletionRequest
- var chatMessage []gpt.ChatCompletionRequestMessage
- re := g.RequestFromCtx(ctx)
- ConversationID := req.ConversationID
- MessageId := req.MessageId
- chatRequst.Model = gpt.GPT3Dot5Turbo
- chatRequst.MaxTokens = 1024
- chatRequst.Temperature = 0
- if req.ConversationID == "" || req.MessageId == "" {
- ConversationID = CreateSessionId(re)
- MessageId = uuid.NewString()
- }
- chatMessage = append(chatMessage, gpt.ChatCompletionRequestMessage{
- Role: req.Role,
- Content: req.Content,
- })
- chatRequst.Messages = chatMessage
- client := gpt.NewClient(key, proxyUrl)
- completion, err := client.ChatCompletion(ctx, &chatRequst)
- if err != nil {
- return nil, err
- }
- res = new(v1.ChatRes)
- if len(completion.Choices) > 0 {
- res.ConversationID = ConversationID
- res.MessageId = MessageId
- res.Text = completion.Choices[0].Message.Content
- }
- return
- }
- // ChatStream 这里使用流数据
- func (c *sChat) ChatStream(ctx context.Context, req *v1.ChatStreamReq) (res *v1.ChatStreamRes, err error) {
- re := g.RequestFromCtx(ctx)
- re.Response.Header().Set("Transfer-Encoding", "chunked")
- re.Response.Header().Set("Content-Type", "application/json")
- w := re.Response.Writer
- encoder := json.NewEncoder(w)
- // 初始化client
- client := gpt.NewClient(key, proxyUrl)
- err = client.ChatCompletionStream(ctx, &gpt.ChatCompletionRequest{
- Model: gpt.GPT3Dot5Turbo,
- Messages: []gpt.ChatCompletionRequestMessage{
- {
- Role: req.Role,
- Content: req.Content,
- },
- },
- MaxTokens: 150,
- Temperature: 0.8,
- }, func(response *gpt.ChatCompletionStreamResponse) {
- // 检查Choices字段是否存在
- if len(response.Choices) > 0 {
- // 提取Content字段
- content := response.Choices[0].Delta.Content
- // 将Content字段作为一个JSON对象发送给客户端
- message := Message{Text: content}
- if err := encoder.Encode(message); err != nil {
- return
- }
- re.Response.WriteHeader(200)
- w.Flush()
- } else {
- re.Exit()
- }
- })
- if err != nil {
- glog.Debug(ctx, err)
- return
- }
- return
- }
- func CreateSessionId(r *ghttp.Request) string {
- var (
- address = r.RemoteAddr
- header = fmt.Sprintf("%v", r.Header)
- )
- return guid.S([]byte(address), []byte(header))
- }
|