gpt.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473
  1. // Package gpt provides a client for the OpenAI GPT-3 API
  2. package gpt
  3. import (
  4. "bufio"
  5. "bytes"
  6. "context"
  7. "encoding/json"
  8. "fmt"
  9. "golang.org/x/net/proxy"
  10. "io"
  11. "net/http"
  12. "time"
  13. )
  14. // Define GPT-3 Engine Types
  15. const (
  16. TextAda001Engine = "text-ada-001" // TextAda001Engine Text Ada 001
  17. TextBabbage001Engine = "text-babbage-001" // TextBabbage001Engine Text Babbage 001
  18. TextCurie001Engine = "text-curie-001" // TextCurie001Engine Text Curie 001
  19. TextDavinci001Engine = "text-davinci-001" // TextDavinci001Engine Text Davinci 001
  20. TextDavinci002Engine = "text-davinci-002" // TextDavinci002Engine Text Davinci 002
  21. TextDavinci003Engine = "text-davinci-003" // TextDavinci003Engine Text Davinci 003
  22. AdaEngine = "ada" // AdaEngine Ada
  23. BabbageEngine = "babbage" // BabbageEngine Babbage
  24. CurieEngine = "curie" // CurieEngine Curie
  25. DavinciEngine = "davinci" // DavinciEngine Davinci
  26. DefaultEngine = DavinciEngine // DefaultEngine Default Engine
  27. )
  28. const (
  29. GPT4 = "gpt4" // GPT4 GPT-4
  30. GPT3Dot5Turbo = "gpt-3.5-turbo" // GPT3Dot5Turbo GPT-3.5 Turbo
  31. GPT3Dot5Turbo0301 = "gpt-3.5-turbo-0301" // GPT3Dot5Turbo0301 GPT-3.5 Turbo 0301
  32. TextSimilarityAda001 = "text-similarity-ada-001" // TextSimilarityAda001 Text Similarity Ada 001
  33. TextSimilarityBabbage001 = "text-similarity-babbage-001" // TextSimilarityBabbage001 Text Similarity Babbage 001
  34. TextSimilarityCurie001 = "text-similarity-curie-001" // TextSimilarityCurie001 Text Similarity Curie 001
  35. TextSimilarityDavinci001 = "text-similarity-davinci-001" // TextSimilarityDavinci001 Text Similarity Davinci 001
  36. TextSearchAdaDoc001 = "text-search-ada-doc-001" // TextSearchAdaDoc001 Text Search Ada Doc 001
  37. TextSearchAdaQuery001 = "text-search-ada-query-001" // TextSearchAdaQuery001 Text Search Ada Query 001
  38. TextSearchBabbageDoc001 = "text-search-babbage-doc-001" // TextSearchBabbageDoc001 Text Search Babbage Doc 001
  39. TextSearchBabbageQuery001 = "text-search-babbage-query-001" // TextSearchBabbageQuery001 Text Search Babbage Query 001
  40. TextSearchCurieDoc001 = "text-search-curie-doc-001" // TextSearchCurieDoc001 Text Search Curie Doc 001
  41. TextSearchCurieQuery001 = "text-search-curie-query-001" // TextSearchCurieQuery001 Text Search Curie Query 001
  42. TextSearchDavinciDoc001 = "text-search-davinci-doc-001" // TextSearchDavinciDoc001 Text Search Davinci Doc 001
  43. TextSearchDavinciQuery001 = "text-search-davinci-query-001" // TextSearchDavinciQuery001 Text Search Davinci Query 001
  44. CodeSearchAdaCode001 = "code-search-ada-code-001" // CodeSearchAdaCode001 Code Search Ada Code 001
  45. CodeSearchAdaText001 = "code-search-ada-text-001" // CodeSearchAdaText001 Code Search Ada Text 001
  46. CodeSearchBabbageCode001 = "code-search-babbage-code-001" // CodeSearchBabbageCode001 Code Search Babbage Code 001
  47. CodeSearchBabbageText001 = "code-search-babbage-text-001" // CodeSearchBabbageText001 Code Search Babbage Text 001
  48. TextEmbeddingAda002 = "text-embedding-ada-002" // TextEmbeddingAda002 Text Embedding Ada 002
  49. )
  50. const (
  51. defaultBaseURL = "https://api.openai.com/v1"
  52. defaultUserAgent = "go-gpt3"
  53. defaultTimeoutSeconds = 30
  54. )
  55. // Image sizes defined by the OpenAI API.
  56. const (
  57. CreateImageSize256x256 = "256x256" // CreateImageSize256x256 256x256
  58. CreateImageSize512x512 = "512x512" // CreateImageSize512x512 512x512
  59. CreateImageSize1024x1024 = "1024x1024" // CreateImageSize1024x1024 1024x1024
  60. CreateImageResponseFormatURL = "url" // CreateImageResponseFormatURL URL
  61. CreateImageResponseFormatB64JSON = "b64_json" // CreateImageResponseFormatB64JSON B64 JSON
  62. )
  63. // Client is an API client to communicate with the OpenAI gpt-3 APIs
  64. type Client interface {
  65. // Engines lists the currently available engines, and provides basic information about each
  66. // option such as the owner and availability.
  67. Engines(ctx context.Context) (*EnginesResponse, error)
  68. // Engine retrieves an engine instance, providing basic information about the engine such
  69. // as the owner and availability.
  70. Engine(ctx context.Context, engine string) (*EngineObject, error)
  71. // ChatCompletion creates a completion with the Chat completion endpoint which
  72. // is what powers the ChatGPT experience.
  73. ChatCompletion(ctx context.Context, request *ChatCompletionRequest) (*ChatCompletionResponse, error)
  74. // ChatCompletionStream creates a completion with the Chat completion endpoint which
  75. // is what powers the ChatGPT experience.
  76. ChatCompletionStream(ctx context.Context, request *ChatCompletionRequest, onData func(*ChatCompletionStreamResponse)) error
  77. // Completion creates a completion with the default engine. This is the main endpoint of the API
  78. // which auto-completes based on the given prompt.
  79. Completion(ctx context.Context, request *CompletionRequest) (*CompletionResponse, error)
  80. // CompletionStream creates a completion with the default engine and streams the results through
  81. // multiple calls to onData.
  82. CompletionStream(ctx context.Context, request *CompletionRequest, onData func(*CompletionResponse)) error
  83. // CompletionWithEngine is the same as Completion except allows overriding the default engine on the client
  84. CompletionWithEngine(ctx context.Context, request *CompletionRequest) (*CompletionResponse, error)
  85. // CompletionStreamWithEngine is the same as CompletionStream allows overriding the default engine on the client
  86. CompletionStreamWithEngine(ctx context.Context, request *CompletionRequest, onData func(*CompletionResponse)) error
  87. // Edits is given a prompt and an instruction, the model will return an edited version of the prompt.
  88. Edits(ctx context.Context, request *EditsRequest) (*EditsResponse, error)
  89. // Search performs a semantic search over a list of documents with the default engine.
  90. Search(ctx context.Context, request *SearchRequest) (*SearchResponse, error)
  91. // SearchWithEngine performs a semantic search over a list of documents with the specified engine.
  92. SearchWithEngine(ctx context.Context, engine string, request *SearchRequest) (*SearchResponse, error)
  93. // Embeddings Returns an embedding using the provided request.
  94. Embeddings(ctx context.Context, request *EmbeddingsRequest) (*EmbeddingsResponse, error)
  95. // Image returns an image using the provided request.
  96. Image(ctx context.Context, request *ImageRequest) (*ImageResponse, error)
  97. }
  98. type client struct {
  99. baseURL string
  100. apiKey string
  101. userAgent string
  102. httpClient *http.Client
  103. defaultEngine string
  104. idOrg string
  105. }
  106. // NewClient returns a new OpenAI GPT-3 API client. An APIKey is required to use the client
  107. func NewClient(apiKey string, options ...ClientOption) Client {
  108. // Configure the SOCKS5 proxy
  109. dialer, err := proxy.SOCKS5("tcp", "127.0.0.1:6153", nil, proxy.Direct)
  110. if err != nil {
  111. return nil
  112. }
  113. httpClient := &http.Client{
  114. Timeout: defaultTimeoutSeconds * time.Second,
  115. Transport: &http.Transport{
  116. Dial: dialer.Dial, // Use Dial instead of DialContext
  117. },
  118. }
  119. cli := &client{
  120. userAgent: defaultUserAgent,
  121. apiKey: apiKey,
  122. baseURL: defaultBaseURL,
  123. httpClient: httpClient,
  124. defaultEngine: DefaultEngine,
  125. idOrg: "",
  126. }
  127. for _, opt := range options {
  128. cli = opt.apply(cli)
  129. }
  130. return cli
  131. }
  132. // Engines lists the currently available engines, and provides basic information about each
  133. // option such as the owner and availability.
  134. func (c *client) Engines(ctx context.Context) (*EnginesResponse, error) {
  135. req, err := c.newRequest(ctx, "GET", "/engines", nil)
  136. if err != nil {
  137. return nil, err
  138. }
  139. rsp, err := c.performRequest(req)
  140. if err != nil {
  141. return nil, err
  142. }
  143. output := new(EnginesResponse)
  144. if err := getResponseObject(rsp, output); err != nil {
  145. return nil, err
  146. }
  147. return output, nil
  148. }
  149. // Engine retrieves an engine instance, providing basic information about the engine such
  150. // as the owner and availability.
  151. func (c *client) Engine(ctx context.Context, engine string) (*EngineObject, error) {
  152. req, err := c.newRequest(ctx, "GET", fmt.Sprintf("/engines/%s", engine), nil)
  153. if err != nil {
  154. return nil, err
  155. }
  156. rsp, err := c.performRequest(req)
  157. if err != nil {
  158. return nil, err
  159. }
  160. output := new(EngineObject)
  161. if err := getResponseObject(rsp, output); err != nil {
  162. return nil, err
  163. }
  164. return output, nil
  165. }
  166. // ChatCompletion creates a completion with the Chat completion endpoint which
  167. // is what powers the ChatGPT experience.
  168. func (c *client) ChatCompletion(ctx context.Context, request *ChatCompletionRequest) (*ChatCompletionResponse, error) {
  169. if request.Model == "" {
  170. request.Model = GPT3Dot5Turbo
  171. }
  172. request.Stream = false
  173. req, err := c.newRequest(ctx, "POST", "/chat/completions", &request)
  174. if err != nil {
  175. return nil, err
  176. }
  177. rsp, err := c.performRequest(req)
  178. if err != nil {
  179. return nil, err
  180. }
  181. output := new(ChatCompletionResponse)
  182. if err := getResponseObject(rsp, output); err != nil {
  183. return nil, err
  184. }
  185. return output, nil
  186. }
  187. // ChatCompletionStream creates a completion with the Chat completion endpoint which
  188. // is what powers the ChatGPT experience.
  189. func (c *client) ChatCompletionStream(ctx context.Context, request *ChatCompletionRequest, onData func(*ChatCompletionStreamResponse)) error {
  190. if request.Model == "" {
  191. request.Model = GPT3Dot5Turbo
  192. }
  193. request.Stream = true
  194. req, err := c.newRequest(ctx, "POST", "/chat/completions", request)
  195. if err != nil {
  196. return err
  197. }
  198. rsp, err := c.performRequest(req)
  199. if err != nil {
  200. return err
  201. }
  202. reader := bufio.NewReader(rsp.Body)
  203. defer rsp.Body.Close()
  204. for {
  205. line, err := reader.ReadBytes('\n')
  206. if err != nil {
  207. return err
  208. }
  209. // make sure there isn't any extra whitespace before or after
  210. line = bytes.TrimSpace(line)
  211. // the completion API only returns data events
  212. if !bytes.HasPrefix(line, dataPrefix) {
  213. continue
  214. }
  215. line = bytes.TrimPrefix(line, dataPrefix)
  216. // the stream is completed when terminated by [DONE]
  217. if bytes.HasPrefix(line, doneSequence) {
  218. break
  219. }
  220. output := new(ChatCompletionStreamResponse)
  221. if err := json.Unmarshal(line, output); err != nil {
  222. return fmt.Errorf("invalid json stream data: %v", err)
  223. }
  224. onData(output)
  225. }
  226. return nil
  227. }
  228. // Completion creates a completion with the default engine.
  229. func (c *client) Completion(ctx context.Context, request *CompletionRequest) (*CompletionResponse, error) {
  230. return c.CompletionWithEngine(ctx, request)
  231. }
  232. // CompletionWithEngine creates a completion with the specified engine.
  233. func (c *client) CompletionWithEngine(ctx context.Context, request *CompletionRequest) (*CompletionResponse, error) {
  234. request.Stream = false
  235. req, err := c.newRequest(ctx, "POST", "/completions", &request)
  236. if err != nil {
  237. return nil, err
  238. }
  239. rsp, err := c.performRequest(req)
  240. if err != nil {
  241. return nil, err
  242. }
  243. output := new(CompletionResponse)
  244. if err := getResponseObject(rsp, output); err != nil {
  245. return nil, err
  246. }
  247. return output, nil
  248. }
  249. // CompletionStream creates a completion with the default engine.
  250. func (c *client) CompletionStream(ctx context.Context, request *CompletionRequest,
  251. onData func(*CompletionResponse)) error {
  252. return c.CompletionStreamWithEngine(ctx, request, onData)
  253. }
  254. var (
  255. dataPrefix = []byte("data: ")
  256. doneSequence = []byte("[DONE]")
  257. )
  258. // CompletionStreamWithEngine creates a completion with the specified engine.
  259. func (c *client) CompletionStreamWithEngine(ctx context.Context, request *CompletionRequest,
  260. onData func(*CompletionResponse)) error {
  261. request.Stream = true
  262. req, err := c.newRequest(ctx, "POST", "/completions", &request)
  263. if err != nil {
  264. return err
  265. }
  266. rsp, err := c.performRequest(req)
  267. if err != nil {
  268. return err
  269. }
  270. reader := bufio.NewReader(rsp.Body)
  271. defer rsp.Body.Close()
  272. for {
  273. line, err := reader.ReadBytes('\n')
  274. if err != nil {
  275. return err
  276. }
  277. // make sure there isn't any extra whitespace before or after
  278. line = bytes.TrimSpace(line)
  279. // the completion API only returns data events
  280. if !bytes.HasPrefix(line, dataPrefix) {
  281. continue
  282. }
  283. line = bytes.TrimPrefix(line, dataPrefix)
  284. // the stream is completed when terminated by [DONE]
  285. if bytes.HasPrefix(line, doneSequence) {
  286. break
  287. }
  288. output := new(CompletionResponse)
  289. if err := json.Unmarshal(line, output); err != nil {
  290. return fmt.Errorf("invalid json stream data: %v", err)
  291. }
  292. onData(output)
  293. }
  294. return nil
  295. }
  296. // Edits is given a prompt and an instruction, the model will return an edited version of the prompt.
  297. func (c *client) Edits(ctx context.Context, request *EditsRequest) (*EditsResponse, error) {
  298. req, err := c.newRequest(ctx, "POST", "/edits", &request)
  299. if err != nil {
  300. return nil, err
  301. }
  302. rsp, err := c.performRequest(req)
  303. if err != nil {
  304. return nil, err
  305. }
  306. output := new(EditsResponse)
  307. if err := getResponseObject(rsp, output); err != nil {
  308. return nil, err
  309. }
  310. return output, nil
  311. }
  312. // Search creates a search with the default engine.
  313. func (c *client) Search(ctx context.Context, request *SearchRequest) (*SearchResponse, error) {
  314. return c.SearchWithEngine(ctx, c.defaultEngine, request)
  315. }
  316. // SearchWithEngine performs a semantic search over a list of documents with the specified engine.
  317. func (c *client) SearchWithEngine(ctx context.Context, engine string, request *SearchRequest) (*SearchResponse, error) {
  318. req, err := c.newRequest(ctx, "POST", fmt.Sprintf("/engines/%s/search", engine), &request)
  319. if err != nil {
  320. return nil, err
  321. }
  322. rsp, err := c.performRequest(req)
  323. if err != nil {
  324. return nil, err
  325. }
  326. output := new(SearchResponse)
  327. if err := getResponseObject(rsp, output); err != nil {
  328. return nil, err
  329. }
  330. return output, nil
  331. }
  332. // Embeddings creates text embeddings for a supplied slice of inputs with a provided model.
  333. // See: https://beta.openai.com/docs/api-reference/embeddings
  334. func (c *client) Embeddings(ctx context.Context, request *EmbeddingsRequest) (*EmbeddingsResponse, error) {
  335. req, err := c.newRequest(ctx, "POST", "/embeddings", &request)
  336. if err != nil {
  337. return nil, err
  338. }
  339. rsp, err := c.performRequest(req)
  340. if err != nil {
  341. return nil, err
  342. }
  343. output := EmbeddingsResponse{}
  344. if err := getResponseObject(rsp, &output); err != nil {
  345. return nil, err
  346. }
  347. return &output, nil
  348. }
  349. // Image creates an image
  350. func (c *client) Image(ctx context.Context, request *ImageRequest) (*ImageResponse, error) {
  351. req, err := c.newRequest(ctx, "POST", "/images/generations", &request)
  352. if err != nil {
  353. return nil, err
  354. }
  355. rsp, err := c.performRequest(req)
  356. if err != nil {
  357. return nil, err
  358. }
  359. output := ImageResponse{}
  360. if err := getResponseObject(rsp, &output); err != nil {
  361. return nil, err
  362. }
  363. return &output, nil
  364. }
  365. func (c *client) performRequest(req *http.Request) (*http.Response, error) {
  366. rsp, err := c.httpClient.Do(req)
  367. if err != nil {
  368. return nil, err
  369. }
  370. if err := checkForSuccess(rsp); err != nil {
  371. return nil, err
  372. }
  373. return rsp, nil
  374. }
  375. // checkForSuccess returns an error if this response includes an error.
  376. func checkForSuccess(rsp *http.Response) error {
  377. if rsp.StatusCode >= 200 && rsp.StatusCode < 300 {
  378. return nil
  379. }
  380. defer rsp.Body.Close()
  381. data, err := io.ReadAll(rsp.Body)
  382. if err != nil {
  383. return fmt.Errorf("failed to read from body: %w", err)
  384. }
  385. var result APIErrorResponse
  386. if err := json.Unmarshal(data, &result); err != nil {
  387. // if we can't decode the json error then create an unexpected error
  388. apiError := APIError{
  389. StatusCode: rsp.StatusCode,
  390. Type: "Unexpected",
  391. Message: string(data),
  392. }
  393. return apiError
  394. }
  395. result.Error.StatusCode = rsp.StatusCode
  396. return result.Error
  397. }
  398. func getResponseObject(rsp *http.Response, v interface{}) error {
  399. defer rsp.Body.Close()
  400. if err := json.NewDecoder(rsp.Body).Decode(v); err != nil {
  401. return fmt.Errorf("invalid json response: %w", err)
  402. }
  403. return nil
  404. }
  405. func jsonBodyReader(body interface{}) (io.Reader, error) {
  406. if body == nil {
  407. return bytes.NewBuffer(nil), nil
  408. }
  409. raw, err := json.Marshal(body)
  410. if err != nil {
  411. return nil, fmt.Errorf("failed encoding json: %w", err)
  412. }
  413. return bytes.NewBuffer(raw), nil
  414. }
  415. func (c *client) newRequest(ctx context.Context, method, path string, payload interface{}) (*http.Request, error) {
  416. bodyReader, err := jsonBodyReader(payload)
  417. if err != nil {
  418. return nil, err
  419. }
  420. url := c.baseURL + path
  421. req, err := http.NewRequestWithContext(ctx, method, url, bodyReader)
  422. if err != nil {
  423. return nil, err
  424. }
  425. if len(c.idOrg) > 0 {
  426. req.Header.Set("OpenAI-Organization", c.idOrg)
  427. }
  428. req.Header.Set("Content-type", "application/json")
  429. req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey))
  430. return req, nil
  431. }