gpt.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462
  1. // Package gpt provides a client for the OpenAI GPT-3 API
  2. package gpt
  3. import (
  4. "bufio"
  5. "bytes"
  6. "context"
  7. "encoding/json"
  8. "fmt"
  9. "io"
  10. "net/http"
  11. "time"
  12. )
// GPT-3 engine identifiers. Each constant holds the exact engine ID string.
const (
	// Versioned text engines.
	TextAda001Engine     = "text-ada-001"
	TextBabbage001Engine = "text-babbage-001"
	TextCurie001Engine   = "text-curie-001"
	TextDavinci001Engine = "text-davinci-001"
	TextDavinci002Engine = "text-davinci-002"
	TextDavinci003Engine = "text-davinci-003"

	// Unversioned engine names.
	AdaEngine     = "ada"
	BabbageEngine = "babbage"
	CurieEngine   = "curie"
	DavinciEngine = "davinci"

	// DefaultEngine is the engine NewClient stores as the client's default
	// engine; it is an alias for DavinciEngine.
	DefaultEngine = DavinciEngine
)
  27. const (
  28. GPT4 = "gpt4" // GPT4 GPT-4
  29. GPT3Dot5Turbo = "gpt-3.5-turbo" // GPT3Dot5Turbo GPT-3.5 Turbo
  30. GPT3Dot5Turbo0301 = "gpt-3.5-turbo-0301" // GPT3Dot5Turbo0301 GPT-3.5 Turbo 0301
  31. TextSimilarityAda001 = "text-similarity-ada-001" // TextSimilarityAda001 Text Similarity Ada 001
  32. TextSimilarityBabbage001 = "text-similarity-babbage-001" // TextSimilarityBabbage001 Text Similarity Babbage 001
  33. TextSimilarityCurie001 = "text-similarity-curie-001" // TextSimilarityCurie001 Text Similarity Curie 001
  34. TextSimilarityDavinci001 = "text-similarity-davinci-001" // TextSimilarityDavinci001 Text Similarity Davinci 001
  35. TextSearchAdaDoc001 = "text-search-ada-doc-001" // TextSearchAdaDoc001 Text Search Ada Doc 001
  36. TextSearchAdaQuery001 = "text-search-ada-query-001" // TextSearchAdaQuery001 Text Search Ada Query 001
  37. TextSearchBabbageDoc001 = "text-search-babbage-doc-001" // TextSearchBabbageDoc001 Text Search Babbage Doc 001
  38. TextSearchBabbageQuery001 = "text-search-babbage-query-001" // TextSearchBabbageQuery001 Text Search Babbage Query 001
  39. TextSearchCurieDoc001 = "text-search-curie-doc-001" // TextSearchCurieDoc001 Text Search Curie Doc 001
  40. TextSearchCurieQuery001 = "text-search-curie-query-001" // TextSearchCurieQuery001 Text Search Curie Query 001
  41. TextSearchDavinciDoc001 = "text-search-davinci-doc-001" // TextSearchDavinciDoc001 Text Search Davinci Doc 001
  42. TextSearchDavinciQuery001 = "text-search-davinci-query-001" // TextSearchDavinciQuery001 Text Search Davinci Query 001
  43. CodeSearchAdaCode001 = "code-search-ada-code-001" // CodeSearchAdaCode001 Code Search Ada Code 001
  44. CodeSearchAdaText001 = "code-search-ada-text-001" // CodeSearchAdaText001 Code Search Ada Text 001
  45. CodeSearchBabbageCode001 = "code-search-babbage-code-001" // CodeSearchBabbageCode001 Code Search Babbage Code 001
  46. CodeSearchBabbageText001 = "code-search-babbage-text-001" // CodeSearchBabbageText001 Code Search Babbage Text 001
  47. TextEmbeddingAda002 = "text-embedding-ada-002" // TextEmbeddingAda002 Text Embedding Ada 002
  48. )
// Defaults that NewClient applies before any ClientOption is run.
const (
	// defaultBaseURL is the initial base URL; request paths are appended to it.
	defaultBaseURL = "https://api.openai.com/v1"
	// defaultUserAgent is the initial value of the client's userAgent field.
	defaultUserAgent = "go-gpt3"
	// defaultTimeoutSeconds is the HTTP client timeout, in seconds.
	defaultTimeoutSeconds = 30
)
// Image sizes and response formats defined by the OpenAI API.
// NOTE(review): presumably these are the values for the size and
// response-format fields of ImageRequest — confirm against its declaration.
const (
	// CreateImageSize256x256 is the "256x256" image size.
	CreateImageSize256x256 = "256x256"
	// CreateImageSize512x512 is the "512x512" image size.
	CreateImageSize512x512 = "512x512"
	// CreateImageSize1024x1024 is the "1024x1024" image size.
	CreateImageSize1024x1024 = "1024x1024"
	// CreateImageResponseFormatURL is the "url" response format.
	CreateImageResponseFormatURL = "url"
	// CreateImageResponseFormatB64JSON is the "b64_json" response format.
	CreateImageResponseFormatB64JSON = "b64_json"
)
// Client is an API client to communicate with the OpenAI GPT-3 APIs.
// NewClient returns this package's implementation.
type Client interface {
	// Engines lists the currently available engines, and provides basic information about each
	// option such as the owner and availability.
	Engines(ctx context.Context) (*EnginesResponse, error)
	// Engine retrieves an engine instance, providing basic information about the engine such
	// as the owner and availability.
	Engine(ctx context.Context, engine string) (*EngineObject, error)
	// ChatCompletion creates a completion with the chat completion endpoint, which
	// is what powers the ChatGPT experience, and returns the whole response at once.
	ChatCompletion(ctx context.Context, request *ChatCompletionRequest) (*ChatCompletionResponse, error)
	// ChatCompletionStream creates a completion with the chat completion endpoint and
	// streams the results through multiple calls to onData.
	ChatCompletionStream(ctx context.Context, request *ChatCompletionRequest, onData func(*ChatCompletionStreamResponse)) error
	// Completion creates a completion. This is the main endpoint of the API,
	// which auto-completes based on the given prompt.
	Completion(ctx context.Context, request *CompletionRequest) (*CompletionResponse, error)
	// CompletionStream creates a completion and streams the results through
	// multiple calls to onData.
	CompletionStream(ctx context.Context, request *CompletionRequest, onData func(*CompletionResponse)) error
	// CompletionWithEngine has the same signature as Completion; it takes no engine argument.
	// NOTE(review): the engine/model is presumably selected through the request itself — confirm.
	CompletionWithEngine(ctx context.Context, request *CompletionRequest) (*CompletionResponse, error)
	// CompletionStreamWithEngine has the same signature as CompletionStream; it takes no engine argument.
	// NOTE(review): the engine/model is presumably selected through the request itself — confirm.
	CompletionStreamWithEngine(ctx context.Context, request *CompletionRequest, onData func(*CompletionResponse)) error
	// Edits is given a prompt and an instruction; the model returns an edited version of the prompt.
	Edits(ctx context.Context, request *EditsRequest) (*EditsResponse, error)
	// Search performs a semantic search over a list of documents with the default engine.
	Search(ctx context.Context, request *SearchRequest) (*SearchResponse, error)
	// SearchWithEngine performs a semantic search over a list of documents with the specified engine.
	SearchWithEngine(ctx context.Context, engine string, request *SearchRequest) (*SearchResponse, error)
	// Embeddings returns an embedding using the provided request.
	Embeddings(ctx context.Context, request *EmbeddingsRequest) (*EmbeddingsResponse, error)
	// Image returns an image using the provided request.
	Image(ctx context.Context, request *ImageRequest) (*ImageResponse, error)
}
// client is the Client implementation returned by NewClient.
type client struct {
	// baseURL is prepended to every request path.
	baseURL string
	// apiKey is sent as the bearer token in the Authorization header.
	apiKey string
	// userAgent is initialised to defaultUserAgent by NewClient.
	userAgent string
	// httpClient executes the HTTP requests.
	httpClient *http.Client
	// defaultEngine is the engine Search uses.
	defaultEngine string
	// idOrg, when non-empty, is sent in the OpenAI-Organization header.
	idOrg string
}
  105. // NewClient returns a new OpenAI GPT-3 API client. An APIKey is required to use the client
  106. func NewClient(apiKey string, options ...ClientOption) Client {
  107. httpClient := &http.Client{
  108. Timeout: defaultTimeoutSeconds * time.Second,
  109. }
  110. cli := &client{
  111. userAgent: defaultUserAgent,
  112. apiKey: apiKey,
  113. baseURL: defaultBaseURL,
  114. httpClient: httpClient,
  115. defaultEngine: DefaultEngine,
  116. idOrg: "",
  117. }
  118. for _, opt := range options {
  119. cli = opt.apply(cli)
  120. }
  121. return cli
  122. }
  123. // Engines lists the currently available engines, and provides basic information about each
  124. // option such as the owner and availability.
  125. func (c *client) Engines(ctx context.Context) (*EnginesResponse, error) {
  126. req, err := c.newRequest(ctx, "GET", "/engines", nil)
  127. if err != nil {
  128. return nil, err
  129. }
  130. rsp, err := c.performRequest(req)
  131. if err != nil {
  132. return nil, err
  133. }
  134. output := new(EnginesResponse)
  135. if err := getResponseObject(rsp, output); err != nil {
  136. return nil, err
  137. }
  138. return output, nil
  139. }
  140. // Engine retrieves an engine instance, providing basic information about the engine such
  141. // as the owner and availability.
  142. func (c *client) Engine(ctx context.Context, engine string) (*EngineObject, error) {
  143. req, err := c.newRequest(ctx, "GET", fmt.Sprintf("/engines/%s", engine), nil)
  144. if err != nil {
  145. return nil, err
  146. }
  147. rsp, err := c.performRequest(req)
  148. if err != nil {
  149. return nil, err
  150. }
  151. output := new(EngineObject)
  152. if err := getResponseObject(rsp, output); err != nil {
  153. return nil, err
  154. }
  155. return output, nil
  156. }
  157. // ChatCompletion creates a completion with the Chat completion endpoint which
  158. // is what powers the ChatGPT experience.
  159. func (c *client) ChatCompletion(ctx context.Context, request *ChatCompletionRequest) (*ChatCompletionResponse, error) {
  160. if request.Model == "" {
  161. request.Model = GPT3Dot5Turbo
  162. }
  163. request.Stream = false
  164. req, err := c.newRequest(ctx, "POST", "/chat/completions", &request)
  165. if err != nil {
  166. return nil, err
  167. }
  168. rsp, err := c.performRequest(req)
  169. if err != nil {
  170. return nil, err
  171. }
  172. output := new(ChatCompletionResponse)
  173. if err := getResponseObject(rsp, output); err != nil {
  174. return nil, err
  175. }
  176. return output, nil
  177. }
  178. // ChatCompletionStream creates a completion with the Chat completion endpoint which
  179. // is what powers the ChatGPT experience.
  180. func (c *client) ChatCompletionStream(ctx context.Context, request *ChatCompletionRequest, onData func(*ChatCompletionStreamResponse)) error {
  181. if request.Model == "" {
  182. request.Model = GPT3Dot5Turbo
  183. }
  184. request.Stream = true
  185. req, err := c.newRequest(ctx, "POST", "/chat/completions", request)
  186. if err != nil {
  187. return err
  188. }
  189. rsp, err := c.performRequest(req)
  190. if err != nil {
  191. return err
  192. }
  193. reader := bufio.NewReader(rsp.Body)
  194. defer rsp.Body.Close()
  195. for {
  196. line, err := reader.ReadBytes('\n')
  197. if err != nil {
  198. return err
  199. }
  200. // make sure there isn't any extra whitespace before or after
  201. line = bytes.TrimSpace(line)
  202. // the completion API only returns data events
  203. if !bytes.HasPrefix(line, dataPrefix) {
  204. continue
  205. }
  206. line = bytes.TrimPrefix(line, dataPrefix)
  207. // the stream is completed when terminated by [DONE]
  208. if bytes.HasPrefix(line, doneSequence) {
  209. break
  210. }
  211. output := new(ChatCompletionStreamResponse)
  212. if err := json.Unmarshal(line, output); err != nil {
  213. return fmt.Errorf("invalid json stream data: %v", err)
  214. }
  215. onData(output)
  216. }
  217. return nil
  218. }
// Completion creates a completion. It forwards ctx and request unchanged to
// CompletionWithEngine and returns that call's result.
func (c *client) Completion(ctx context.Context, request *CompletionRequest) (*CompletionResponse, error) {
	return c.CompletionWithEngine(ctx, request)
}
  223. // CompletionWithEngine creates a completion with the specified engine.
  224. func (c *client) CompletionWithEngine(ctx context.Context, request *CompletionRequest) (*CompletionResponse, error) {
  225. request.Stream = false
  226. req, err := c.newRequest(ctx, "POST", "/completions", &request)
  227. if err != nil {
  228. return nil, err
  229. }
  230. rsp, err := c.performRequest(req)
  231. if err != nil {
  232. return nil, err
  233. }
  234. output := new(CompletionResponse)
  235. if err := getResponseObject(rsp, output); err != nil {
  236. return nil, err
  237. }
  238. return output, nil
  239. }
// CompletionStream creates a streaming completion. It forwards ctx, request
// and onData unchanged to CompletionStreamWithEngine and returns that call's
// error.
func (c *client) CompletionStream(ctx context.Context, request *CompletionRequest,
	onData func(*CompletionResponse)) error {
	return c.CompletionStreamWithEngine(ctx, request, onData)
}
// Markers the streaming methods look for when reading a response line by line.
var (
	// dataPrefix starts every line that carries an event payload; lines
	// without it are skipped.
	dataPrefix = []byte("data: ")
	// doneSequence is the payload that marks the end of the stream.
	doneSequence = []byte("[DONE]")
)
  249. // CompletionStreamWithEngine creates a completion with the specified engine.
  250. func (c *client) CompletionStreamWithEngine(ctx context.Context, request *CompletionRequest,
  251. onData func(*CompletionResponse)) error {
  252. request.Stream = true
  253. req, err := c.newRequest(ctx, "POST", "/completions", &request)
  254. if err != nil {
  255. return err
  256. }
  257. rsp, err := c.performRequest(req)
  258. if err != nil {
  259. return err
  260. }
  261. reader := bufio.NewReader(rsp.Body)
  262. defer rsp.Body.Close()
  263. for {
  264. line, err := reader.ReadBytes('\n')
  265. if err != nil {
  266. return err
  267. }
  268. // make sure there isn't any extra whitespace before or after
  269. line = bytes.TrimSpace(line)
  270. // the completion API only returns data events
  271. if !bytes.HasPrefix(line, dataPrefix) {
  272. continue
  273. }
  274. line = bytes.TrimPrefix(line, dataPrefix)
  275. // the stream is completed when terminated by [DONE]
  276. if bytes.HasPrefix(line, doneSequence) {
  277. break
  278. }
  279. output := new(CompletionResponse)
  280. if err := json.Unmarshal(line, output); err != nil {
  281. return fmt.Errorf("invalid json stream data: %v", err)
  282. }
  283. onData(output)
  284. }
  285. return nil
  286. }
  287. // Edits is given a prompt and an instruction, the model will return an edited version of the prompt.
  288. func (c *client) Edits(ctx context.Context, request *EditsRequest) (*EditsResponse, error) {
  289. req, err := c.newRequest(ctx, "POST", "/edits", &request)
  290. if err != nil {
  291. return nil, err
  292. }
  293. rsp, err := c.performRequest(req)
  294. if err != nil {
  295. return nil, err
  296. }
  297. output := new(EditsResponse)
  298. if err := getResponseObject(rsp, output); err != nil {
  299. return nil, err
  300. }
  301. return output, nil
  302. }
// Search performs a search with the client's default engine. It forwards ctx
// and request to SearchWithEngine, passing c.defaultEngine as the engine.
func (c *client) Search(ctx context.Context, request *SearchRequest) (*SearchResponse, error) {
	return c.SearchWithEngine(ctx, c.defaultEngine, request)
}
  307. // SearchWithEngine performs a semantic search over a list of documents with the specified engine.
  308. func (c *client) SearchWithEngine(ctx context.Context, engine string, request *SearchRequest) (*SearchResponse, error) {
  309. req, err := c.newRequest(ctx, "POST", fmt.Sprintf("/engines/%s/search", engine), &request)
  310. if err != nil {
  311. return nil, err
  312. }
  313. rsp, err := c.performRequest(req)
  314. if err != nil {
  315. return nil, err
  316. }
  317. output := new(SearchResponse)
  318. if err := getResponseObject(rsp, output); err != nil {
  319. return nil, err
  320. }
  321. return output, nil
  322. }
  323. // Embeddings creates text embeddings for a supplied slice of inputs with a provided model.
  324. // See: https://beta.openai.com/docs/api-reference/embeddings
  325. func (c *client) Embeddings(ctx context.Context, request *EmbeddingsRequest) (*EmbeddingsResponse, error) {
  326. req, err := c.newRequest(ctx, "POST", "/embeddings", &request)
  327. if err != nil {
  328. return nil, err
  329. }
  330. rsp, err := c.performRequest(req)
  331. if err != nil {
  332. return nil, err
  333. }
  334. output := EmbeddingsResponse{}
  335. if err := getResponseObject(rsp, &output); err != nil {
  336. return nil, err
  337. }
  338. return &output, nil
  339. }
  340. // Image creates an image
  341. func (c *client) Image(ctx context.Context, request *ImageRequest) (*ImageResponse, error) {
  342. req, err := c.newRequest(ctx, "POST", "/images/generations", &request)
  343. if err != nil {
  344. return nil, err
  345. }
  346. rsp, err := c.performRequest(req)
  347. if err != nil {
  348. return nil, err
  349. }
  350. output := ImageResponse{}
  351. if err := getResponseObject(rsp, &output); err != nil {
  352. return nil, err
  353. }
  354. return &output, nil
  355. }
  356. func (c *client) performRequest(req *http.Request) (*http.Response, error) {
  357. rsp, err := c.httpClient.Do(req)
  358. if err != nil {
  359. return nil, err
  360. }
  361. if err := checkForSuccess(rsp); err != nil {
  362. return nil, err
  363. }
  364. return rsp, nil
  365. }
  366. // checkForSuccess returns an error if this response includes an error.
  367. func checkForSuccess(rsp *http.Response) error {
  368. if rsp.StatusCode >= 200 && rsp.StatusCode < 300 {
  369. return nil
  370. }
  371. defer rsp.Body.Close()
  372. data, err := io.ReadAll(rsp.Body)
  373. if err != nil {
  374. return fmt.Errorf("failed to read from body: %w", err)
  375. }
  376. var result APIErrorResponse
  377. if err := json.Unmarshal(data, &result); err != nil {
  378. // if we can't decode the json error then create an unexpected error
  379. apiError := APIError{
  380. StatusCode: rsp.StatusCode,
  381. Type: "Unexpected",
  382. Message: string(data),
  383. }
  384. return apiError
  385. }
  386. result.Error.StatusCode = rsp.StatusCode
  387. return result.Error
  388. }
  389. func getResponseObject(rsp *http.Response, v interface{}) error {
  390. defer rsp.Body.Close()
  391. if err := json.NewDecoder(rsp.Body).Decode(v); err != nil {
  392. return fmt.Errorf("invalid json response: %w", err)
  393. }
  394. return nil
  395. }
  396. func jsonBodyReader(body interface{}) (io.Reader, error) {
  397. if body == nil {
  398. return bytes.NewBuffer(nil), nil
  399. }
  400. raw, err := json.Marshal(body)
  401. if err != nil {
  402. return nil, fmt.Errorf("failed encoding json: %w", err)
  403. }
  404. return bytes.NewBuffer(raw), nil
  405. }
  406. func (c *client) newRequest(ctx context.Context, method, path string, payload interface{}) (*http.Request, error) {
  407. bodyReader, err := jsonBodyReader(payload)
  408. if err != nil {
  409. return nil, err
  410. }
  411. url := c.baseURL + path
  412. req, err := http.NewRequestWithContext(ctx, method, url, bodyReader)
  413. if err != nil {
  414. return nil, err
  415. }
  416. if len(c.idOrg) > 0 {
  417. req.Header.Set("OpenAI-Organization", c.idOrg)
  418. }
  419. req.Header.Set("Content-type", "application/json")
  420. req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey))
  421. return req, nil
  422. }