cauto преди 1 година
родител
ревизия
c7475f9400
променени са 8 файла, в които са добавени 202 реда и са изтрити 84 реда
  1. 23 0
      api/v1/Chat.go
  2. 12 0
      api/v1/Session.go
  3. 18 11
      gpt/gpt.go
  4. 96 0
      internal/controller/chat.go
  5. 0 49
      internal/controller/hello.go
  6. 19 0
      internal/controller/session.go
  7. 0 21
      internal/service/middleware.go
  8. 34 3
      router/router.go

+ 23 - 0
api/v1/Chat.go

@@ -0,0 +1,23 @@
+package v1
+
+import "github.com/gogf/gf/v2/frame/g"
+
+type ChatReq struct {
+	g.Meta  `path:"/chat" tags:"聊天请求" method:"post" summary:"AI聊天"`
+	Role    string `json:"role" v:"required"`
+	Content string `json:"content" v:"required"`
+}
+
+type ChatRes struct {
+	g.Meta `mime:"application/json" example:"json"`
+	Text   string `json:"text" `
+}
+
+type ChatStreamReq struct {
+	g.Meta  `path:"/chatStream" tags:"聊天请求流传输" method:"post" summary:"AI聊天流传输"`
+	Role    string `json:"role" v:"required"`
+	Content string `json:"content" v:"required"`
+}
+
+type ChatStreamRes struct {
+}

+ 12 - 0
api/v1/Session.go

@@ -0,0 +1,12 @@
+package v1
+
+import "github.com/gogf/gf/v2/frame/g"
+
+type SessionReq struct {
+	g.Meta `path:"/session" tags:"会话" method:"get" summary:"会话获取并返回token"`
+}
+
+type SessionRes struct {
+	g.Meta `mime:"application/json" example:"json"`
+	Token  string `json:"token" `
+}

+ 18 - 11
gpt/gpt.go

@@ -125,20 +125,27 @@ type client struct {
 }
 
 // NewClient returns a new OpenAI GPT-3 API client. An APIKey is required to use the client
-func NewClient(apiKey string, options ...ClientOption) Client {
+func NewClient(apiKey string, proxyUrl string, options ...ClientOption) Client {
+	var httpClient *http.Client
+	if proxyUrl != "" {
+		// Configure the SOCKS5 proxy
+		dialer, err := proxy.SOCKS5("tcp", proxyUrl, nil, proxy.Direct)
+		if err != nil {
+			return nil
+		}
 
-	// Configure the SOCKS5 proxy
-	dialer, err := proxy.SOCKS5("tcp", "127.0.0.1:6153", nil, proxy.Direct)
-	if err != nil {
-		return nil
+		httpClient = &http.Client{
+			Timeout: defaultTimeoutSeconds * time.Second,
+			Transport: &http.Transport{
+				Dial: dialer.Dial, // Use Dial instead of DialContext
+			},
+		}
+	} else {
+		httpClient = &http.Client{
+			Timeout: defaultTimeoutSeconds * time.Second,
+		}
 	}
 
-	httpClient := &http.Client{
-		Timeout: defaultTimeoutSeconds * time.Second,
-		Transport: &http.Transport{
-			Dial: dialer.Dial, // Use Dial instead of DialContext
-		},
-	}
 	cli := &client{
 		userAgent:     defaultUserAgent,
 		apiKey:        apiKey,

+ 96 - 0
internal/controller/chat.go

@@ -0,0 +1,96 @@
+package controller
+
+import (
+	"context"
+	"encoding/json"
+	"github.com/gogf/gf/v2/frame/g"
+	"github.com/gogf/gf/v2/os/glog"
+	v1 "go-gpt/api/v1"
+	"go-gpt/gpt"
+)
+
+type Message struct {
+	Text string `json:"text"`
+}
+
+var (
+	Chat     = sChat{}
+	key      = "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"
+	proxyUrl = "127.0.0.1:6153"
+)
+
+type sChat struct {
+}
+
+// Chat 这里不使用流数据
+func (c *sChat) Chat(ctx context.Context, req *v1.ChatReq) (res *v1.ChatRes, err error) {
+
+	client := gpt.NewClient(key, proxyUrl)
+	completion, err := client.ChatCompletion(ctx, &gpt.ChatCompletionRequest{
+		Model: gpt.GPT3Dot5Turbo,
+		Messages: []gpt.ChatCompletionRequestMessage{
+			{
+				Role:    req.Role,
+				Content: req.Content,
+			},
+		},
+		MaxTokens:   2048,
+		Temperature: 0.8,
+	})
+	if err != nil {
+		return nil, err
+	}
+	res = new(v1.ChatRes)
+	if len(completion.Choices) > 0 {
+		res.Text = completion.Choices[0].Message.Content
+	}
+
+	return
+}
+
+// ChatStream 这里使用流数据
+func (c *sChat) ChatStream(ctx context.Context, req *v1.ChatStreamReq) (res *v1.ChatStreamRes, err error) {
+
+	re := g.RequestFromCtx(ctx)
+	re.Response.Header().Set("Transfer-Encoding", "chunked")
+	re.Response.Header().Set("Content-Type", "application/json")
+	w := re.Response.Writer
+	encoder := json.NewEncoder(w)
+
+	// 初始化client
+	client := gpt.NewClient(key, proxyUrl)
+	err = client.ChatCompletionStream(ctx, &gpt.ChatCompletionRequest{
+		Model: gpt.GPT3Dot5Turbo,
+		Messages: []gpt.ChatCompletionRequestMessage{
+			{
+				Role:    req.Role,
+				Content: req.Content,
+			},
+		},
+		MaxTokens:   150,
+		Temperature: 0.8,
+	}, func(response *gpt.ChatCompletionStreamResponse) {
+
+		// 检查Choices字段是否存在
+		if len(response.Choices) > 0 {
+			// 提取Content字段
+			content := response.Choices[0].Delta.Content
+			// 将Content字段作为一个JSON对象发送给客户端
+			message := Message{Text: content}
+			if err := encoder.Encode(message); err != nil {
+				return
+			}
+			re.Response.WriteHeader(200)
+			w.Flush()
+		} else {
+			re.Exit()
+		}
+	})
+
+	if err != nil {
+		glog.Debug(ctx, err)
+		return
+	}
+
+	return
+}

+ 0 - 49
internal/controller/hello.go

@@ -2,12 +2,6 @@ package controller
 
 import (
 	"context"
-	"encoding/json"
-
-	"github.com/gogf/gf/v2/frame/g"
-	"github.com/gogf/gf/v2/os/glog"
-	"go-gpt/gpt"
-
 	"go-gpt/api/v1"
 )
 
@@ -15,52 +9,9 @@ var (
 	Hello = cHello{}
 )
 
-type Message struct {
-	Content string `json:"content"`
-}
 type cHello struct{}
 
 func (c *cHello) Hello(ctx context.Context, req *v1.HelloReq) (res *v1.HelloRes, err error) {
-	re := g.RequestFromCtx(ctx)
-	re.Response.Header().Set("Transfer-Encoding", "chunked")
-	re.Response.Header().Set("Content-Type", "application/json")
-	w := re.Response.Writer
-	encoder := json.NewEncoder(w)
-
-	// 初始化client
-	client := gpt.NewClient("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")
-	err = client.ChatCompletionStream(ctx, &gpt.ChatCompletionRequest{
-		Model: gpt.GPT3Dot5Turbo,
-		Messages: []gpt.ChatCompletionRequestMessage{
-			{
-				Role:    "user",
-				Content: "编写一个c++冒泡算法",
-			},
-		},
-		MaxTokens:   50,
-		Temperature: 0,
-	}, func(response *gpt.ChatCompletionStreamResponse) {
-
-		// 检查Choices字段是否存在
-		if len(response.Choices) > 0 {
-			// 提取Content字段
-			content := response.Choices[0].Delta.Content
-			// 将Content字段作为一个JSON对象发送给客户端
-			message := Message{Content: content}
-			if err := encoder.Encode(message); err != nil {
-				return
-			}
-			re.Response.WriteHeader(200)
-			w.Flush()
-		} else {
-			re.Exit()
-		}
-	})
-
-	if err != nil {
-		glog.Debug(ctx, err)
-		return
-	}
 
 	return
 }

+ 19 - 0
internal/controller/session.go

@@ -0,0 +1,19 @@
+package controller
+
+import (
+	"context"
+	v1 "go-gpt/api/v1"
+)
+
+var (
+	Session = cSession{}
+)
+
+type cSession struct {
+}
+
+// Session 生成token 并返回
+func (c *cSession) Session(ctx context.Context, req *v1.SessionReq) (res *v1.SessionRes, err error) {
+
+	return
+}

+ 0 - 21
internal/service/middleware.go

@@ -1,21 +0,0 @@
-package service
-
-import "github.com/gogf/gf/v2/net/ghttp"
-
-type middlewareService struct{}
-
-var middleware = middlewareService{}
-
-func Middleware() *middlewareService {
-	return &middleware
-}
-
-func (s *middlewareService) CORS(r *ghttp.Request) {
-	r.Response.CORSDefault()
-	r.Middleware.Next()
-}
-
-func (s *middlewareService) Auth(r *ghttp.Request) {
-	Auth().MiddlewareFunc()(r)
-	r.Middleware.Next()
-}

+ 34 - 3
router/router.go

@@ -6,13 +6,14 @@ import (
 	"github.com/gogf/gf/v2/net/ghttp"
 	"github.com/gogf/gf/v2/os/glog"
 	"go-gpt/internal/controller"
+	"go-gpt/internal/service"
 	"net/http"
 	"sync"
 	"time"
 )
 
 type DefaultHandlerResponse struct {
-	Code    int         `json:"code"`
+	Code    int         `json:"status"`
 	Message string      `json:"message"`
 	Data    interface{} `json:"data"`
 }
@@ -62,6 +63,17 @@ func MiddlewareHandlerResponse(r *ghttp.Request) {
 	})
 }
 
+type middlewareService struct{}
+
+func NewTokenMiddleware() *middlewareService {
+	return &middlewareService{}
+}
+
+func (s *middlewareService) Auth(r *ghttp.Request) {
+	service.Auth().MiddlewareFunc()(r)
+	r.Middleware.Next()
+}
+
 type RateLimitMiddleware struct {
 	sync.Mutex
 	Counter map[string]int64
@@ -107,13 +119,32 @@ func BindController(group *ghttp.RouterGroup) {
 	DomeRouter(group)
 	group.Group("/api/v1", func(group *ghttp.RouterGroup) {
 		vv := NewRateLimitMiddleware()
-		group.Middleware(ghttp.MiddlewareHandlerResponse, MiddlewareCORS, vv.Middleware)
-
+		group.Middleware(MiddlewareCORS, vv.Middleware)
+		SessionRouter(group)
+		ChatRouter(group)
 	})
 }
 
 func DomeRouter(group *ghttp.RouterGroup) {
+	//tokenMi := NewTokenMiddleware()
+	//group.Middleware(tokenMi.Auth)
 	group.Bind(
 		controller.Hello,
 	)
 }
+
+func ChatRouter(group *ghttp.RouterGroup) {
+	group.Middleware(MiddlewareHandlerResponse)
+	group.Bind(
+		controller.Chat,
+	)
+}
+
+func SessionRouter(group *ghttp.RouterGroup) {
+	group.Middleware(MiddlewareHandlerResponse)
+	group.Group("/session", func(group *ghttp.RouterGroup) {
+		group.Bind(
+			controller.Session,
+		)
+	})
+}