gpt.go 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489
  1. // Package gpt provides a client for the OpenAI GPT-3 API
  2. package gpt
  3. import (
  4. "bufio"
  5. "bytes"
  6. "context"
  7. "encoding/json"
  8. "fmt"
  9. "golang.org/x/net/proxy"
  10. "io"
  11. "net/http"
  12. "time"
  13. )
  14. // Define GPT-3 Engine Types
  15. const (
  16. TextAda001Engine = "text-ada-001" // TextAda001Engine Text Ada 001
  17. TextBabbage001Engine = "text-babbage-001" // TextBabbage001Engine Text Babbage 001
  18. TextCurie001Engine = "text-curie-001" // TextCurie001Engine Text Curie 001
  19. TextDavinci001Engine = "text-davinci-001" // TextDavinci001Engine Text Davinci 001
  20. TextDavinci002Engine = "text-davinci-002" // TextDavinci002Engine Text Davinci 002
  21. TextDavinci003Engine = "text-davinci-003" // TextDavinci003Engine Text Davinci 003
  22. AdaEngine = "ada" // AdaEngine Ada
  23. BabbageEngine = "babbage" // BabbageEngine Babbage
  24. CurieEngine = "curie" // CurieEngine Curie
  25. DavinciEngine = "davinci" // DavinciEngine Davinci
  26. DefaultEngine = DavinciEngine // DefaultEngine Default Engine
  27. )
  28. const (
  29. GPT4 = "gpt4" // GPT4 GPT-4
  30. GPT3Dot5Turbo = "gpt-3.5-turbo" // GPT3Dot5Turbo GPT-3.5 Turbo
  31. GPT3Dot5Turbo0301 = "gpt-3.5-turbo-0301" // GPT3Dot5Turbo0301 GPT-3.5 Turbo 0301
  32. TextSimilarityAda001 = "text-similarity-ada-001" // TextSimilarityAda001 Text Similarity Ada 001
  33. TextSimilarityBabbage001 = "text-similarity-babbage-001" // TextSimilarityBabbage001 Text Similarity Babbage 001
  34. TextSimilarityCurie001 = "text-similarity-curie-001" // TextSimilarityCurie001 Text Similarity Curie 001
  35. TextSimilarityDavinci001 = "text-similarity-davinci-001" // TextSimilarityDavinci001 Text Similarity Davinci 001
  36. TextSearchAdaDoc001 = "text-search-ada-doc-001" // TextSearchAdaDoc001 Text Search Ada Doc 001
  37. TextSearchAdaQuery001 = "text-search-ada-query-001" // TextSearchAdaQuery001 Text Search Ada Query 001
  38. TextSearchBabbageDoc001 = "text-search-babbage-doc-001" // TextSearchBabbageDoc001 Text Search Babbage Doc 001
  39. TextSearchBabbageQuery001 = "text-search-babbage-query-001" // TextSearchBabbageQuery001 Text Search Babbage Query 001
  40. TextSearchCurieDoc001 = "text-search-curie-doc-001" // TextSearchCurieDoc001 Text Search Curie Doc 001
  41. TextSearchCurieQuery001 = "text-search-curie-query-001" // TextSearchCurieQuery001 Text Search Curie Query 001
  42. TextSearchDavinciDoc001 = "text-search-davinci-doc-001" // TextSearchDavinciDoc001 Text Search Davinci Doc 001
  43. TextSearchDavinciQuery001 = "text-search-davinci-query-001" // TextSearchDavinciQuery001 Text Search Davinci Query 001
  44. CodeSearchAdaCode001 = "code-search-ada-code-001" // CodeSearchAdaCode001 Code Search Ada Code 001
  45. CodeSearchAdaText001 = "code-search-ada-text-001" // CodeSearchAdaText001 Code Search Ada Text 001
  46. CodeSearchBabbageCode001 = "code-search-babbage-code-001" // CodeSearchBabbageCode001 Code Search Babbage Code 001
  47. CodeSearchBabbageText001 = "code-search-babbage-text-001" // CodeSearchBabbageText001 Code Search Babbage Text 001
  48. TextEmbeddingAda002 = "text-embedding-ada-002" // TextEmbeddingAda002 Text Embedding Ada 002
  49. )
  50. const (
  51. defaultBaseURL = "https://api.openai.com/v1"
  52. defaultUserAgent = "go-gpt3"
  53. defaultTimeoutSeconds = 30
  54. )
  55. // Image sizes defined by the OpenAI API.
  56. const (
  57. CreateImageSize256x256 = "256x256" // CreateImageSize256x256 256x256
  58. CreateImageSize512x512 = "512x512" // CreateImageSize512x512 512x512
  59. CreateImageSize1024x1024 = "1024x1024" // CreateImageSize1024x1024 1024x1024
  60. CreateImageResponseFormatURL = "url" // CreateImageResponseFormatURL URL
  61. CreateImageResponseFormatB64JSON = "b64_json" // CreateImageResponseFormatB64JSON B64 JSON
  62. )
  63. // Client is an API client to communicate with the OpenAI gpt-3 APIs
  64. type Client interface {
  65. // Engines lists the currently available engines, and provides basic information about each
  66. // option such as the owner and availability.
  67. Engines(ctx context.Context) (*EnginesResponse, error)
  68. // Engine retrieves an engine instance, providing basic information about the engine such
  69. // as the owner and availability.
  70. Engine(ctx context.Context, engine string) (*EngineObject, error)
  71. // ChatCompletion creates a completion with the Chat completion endpoint which
  72. // is what powers the ChatGPT experience.
  73. ChatCompletion(ctx context.Context, request *ChatCompletionRequest) (*ChatCompletionResponse, error)
  74. // ChatCompletionStream creates a completion with the Chat completion endpoint which
  75. // is what powers the ChatGPT experience.
  76. ChatCompletionStream(ctx context.Context, request *ChatCompletionRequest, onData func(*ChatCompletionStreamResponse)) error
  77. // Completion creates a completion with the default engine. This is the main endpoint of the API
  78. // which auto-completes based on the given prompt.
  79. Completion(ctx context.Context, request *CompletionRequest) (*CompletionResponse, error)
  80. // CompletionStream creates a completion with the default engine and streams the results through
  81. // multiple calls to onData.
  82. CompletionStream(ctx context.Context, request *CompletionRequest, onData func(*CompletionResponse)) error
  83. // CompletionWithEngine is the same as Completion except allows overriding the default engine on the client
  84. CompletionWithEngine(ctx context.Context, request *CompletionRequest) (*CompletionResponse, error)
  85. // CompletionStreamWithEngine is the same as CompletionStream allows overriding the default engine on the client
  86. CompletionStreamWithEngine(ctx context.Context, request *CompletionRequest, onData func(*CompletionResponse)) error
  87. // Edits is given a prompt and an instruction, the model will return an edited version of the prompt.
  88. Edits(ctx context.Context, request *EditsRequest) (*EditsResponse, error)
  89. // Search performs a semantic search over a list of documents with the default engine.
  90. Search(ctx context.Context, request *SearchRequest) (*SearchResponse, error)
  91. // SearchWithEngine performs a semantic search over a list of documents with the specified engine.
  92. SearchWithEngine(ctx context.Context, engine string, request *SearchRequest) (*SearchResponse, error)
  93. // Embeddings Returns an embedding using the provided request.
  94. Embeddings(ctx context.Context, request *EmbeddingsRequest) (*EmbeddingsResponse, error)
  95. // Image returns an image using the provided request.
  96. Image(ctx context.Context, request *ImageRequest) (*ImageResponse, error)
  97. }
  98. type client struct {
  99. baseURL string
  100. apiKey string
  101. userAgent string
  102. httpClient *http.Client
  103. defaultEngine string
  104. idOrg string
  105. }
  106. // NewClient returns a new OpenAI GPT-3 API client. An APIKey is required to use the client
  107. func NewClient(apiKey string, proxyUrl string, options ...ClientOption) Client {
  108. var httpClient *http.Client
  109. if proxyUrl != "" {
  110. // Configure the SOCKS5 proxy
  111. dialer, err := proxy.SOCKS5("tcp", proxyUrl, nil, proxy.Direct)
  112. if err != nil {
  113. return nil
  114. }
  115. httpClient = &http.Client{
  116. Timeout: defaultTimeoutSeconds * time.Second,
  117. Transport: &http.Transport{
  118. Dial: dialer.Dial, // Use Dial instead of DialContext
  119. },
  120. }
  121. } else {
  122. httpClient = &http.Client{
  123. Timeout: defaultTimeoutSeconds * time.Second,
  124. }
  125. }
  126. cli := &client{
  127. userAgent: defaultUserAgent,
  128. apiKey: apiKey,
  129. baseURL: defaultBaseURL,
  130. httpClient: httpClient,
  131. defaultEngine: DefaultEngine,
  132. idOrg: "",
  133. }
  134. for _, opt := range options {
  135. cli = opt.apply(cli)
  136. }
  137. return cli
  138. }
  139. // Engines lists the currently available engines, and provides basic information about each
  140. // option such as the owner and availability.
  141. func (c *client) Engines(ctx context.Context) (*EnginesResponse, error) {
  142. req, err := c.newRequest(ctx, "GET", "/engines", nil)
  143. if err != nil {
  144. return nil, err
  145. }
  146. rsp, err := c.performRequest(req)
  147. if err != nil {
  148. return nil, err
  149. }
  150. output := new(EnginesResponse)
  151. if err := getResponseObject(rsp, output); err != nil {
  152. return nil, err
  153. }
  154. return output, nil
  155. }
  156. // Engine retrieves an engine instance, providing basic information about the engine such
  157. // as the owner and availability.
  158. func (c *client) Engine(ctx context.Context, engine string) (*EngineObject, error) {
  159. req, err := c.newRequest(ctx, "GET", fmt.Sprintf("/engines/%s", engine), nil)
  160. if err != nil {
  161. return nil, err
  162. }
  163. rsp, err := c.performRequest(req)
  164. if err != nil {
  165. return nil, err
  166. }
  167. output := new(EngineObject)
  168. if err := getResponseObject(rsp, output); err != nil {
  169. return nil, err
  170. }
  171. return output, nil
  172. }
  173. // ChatCompletion creates a completion with the Chat completion endpoint which
  174. // is what powers the ChatGPT experience.
  175. func (c *client) ChatCompletion(ctx context.Context, request *ChatCompletionRequest) (*ChatCompletionResponse, error) {
  176. if request.Model == "" {
  177. request.Model = GPT3Dot5Turbo
  178. }
  179. request.Stream = false
  180. req, err := c.newRequest(ctx, "POST", "/chat/completions", &request)
  181. if err != nil {
  182. return nil, err
  183. }
  184. rsp, err := c.performRequest(req)
  185. if err != nil {
  186. return nil, err
  187. }
  188. output := new(ChatCompletionResponse)
  189. if err := getResponseObject(rsp, output); err != nil {
  190. return nil, err
  191. }
  192. return output, nil
  193. }
  194. // ChatCompletionStream creates a completion with the Chat completion endpoint which
  195. // is what powers the ChatGPT experience.
  196. func (c *client) ChatCompletionStream(ctx context.Context, request *ChatCompletionRequest, onData func(*ChatCompletionStreamResponse)) error {
  197. if request.Model == "" {
  198. request.Model = GPT3Dot5Turbo
  199. }
  200. request.Stream = true
  201. req, err := c.newRequest(ctx, "POST", "/chat/completions", request)
  202. if err != nil {
  203. return err
  204. }
  205. rsp, err := c.performRequest(req)
  206. if err != nil {
  207. return err
  208. }
  209. reader := bufio.NewReader(rsp.Body)
  210. defer rsp.Body.Close()
  211. for {
  212. line, err := reader.ReadBytes('\n')
  213. if err != nil {
  214. return err
  215. }
  216. // make sure there isn't any extra whitespace before or after
  217. line = bytes.TrimSpace(line)
  218. // the completion API only returns data events
  219. if !bytes.HasPrefix(line, dataPrefix) {
  220. continue
  221. }
  222. line = bytes.TrimPrefix(line, dataPrefix)
  223. // the stream is completed when terminated by [DONE]
  224. if bytes.HasPrefix(line, doneSequence) {
  225. break
  226. }
  227. output := new(ChatCompletionStreamResponse)
  228. if err := json.Unmarshal(line, output); err != nil {
  229. return fmt.Errorf("invalid json stream data: %v", err)
  230. }
  231. onData(output)
  232. }
  233. return nil
  234. }
  235. // Completion creates a completion with the default engine.
  236. func (c *client) Completion(ctx context.Context, request *CompletionRequest) (*CompletionResponse, error) {
  237. return c.CompletionWithEngine(ctx, request)
  238. }
  239. // CompletionWithEngine creates a completion with the specified engine.
  240. func (c *client) CompletionWithEngine(ctx context.Context, request *CompletionRequest) (*CompletionResponse, error) {
  241. request.Stream = false
  242. req, err := c.newRequest(ctx, "POST", "/completions", &request)
  243. if err != nil {
  244. return nil, err
  245. }
  246. rsp, err := c.performRequest(req)
  247. if err != nil {
  248. return nil, err
  249. }
  250. output := new(CompletionResponse)
  251. if err := getResponseObject(rsp, output); err != nil {
  252. return nil, err
  253. }
  254. return output, nil
  255. }
  256. // CompletionStream creates a completion with the default engine.
  257. func (c *client) CompletionStream(ctx context.Context, request *CompletionRequest,
  258. onData func(*CompletionResponse)) error {
  259. return c.CompletionStreamWithEngine(ctx, request, onData)
  260. }
  261. var (
  262. dataPrefix = []byte("data: ")
  263. doneSequence = []byte("[DONE]")
  264. )
  265. // CompletionStreamWithEngine creates a completion with the specified engine.
  266. func (c *client) CompletionStreamWithEngine(ctx context.Context, request *CompletionRequest,
  267. onData func(*CompletionResponse)) error {
  268. request.Stream = true
  269. req, err := c.newRequest(ctx, "POST", "/completions", &request)
  270. if err != nil {
  271. return err
  272. }
  273. rsp, err := c.performRequest(req)
  274. if err != nil {
  275. return err
  276. }
  277. reader := bufio.NewReader(rsp.Body)
  278. defer rsp.Body.Close()
  279. for {
  280. line, err := reader.ReadBytes('\n')
  281. if err != nil {
  282. return err
  283. }
  284. // make sure there isn't any extra whitespace before or after
  285. line = bytes.TrimSpace(line)
  286. // the completion API only returns data events
  287. if !bytes.HasPrefix(line, dataPrefix) {
  288. continue
  289. }
  290. line = bytes.TrimPrefix(line, dataPrefix)
  291. // the stream is completed when terminated by [DONE]
  292. if bytes.HasPrefix(line, doneSequence) {
  293. break
  294. }
  295. output := new(CompletionResponse)
  296. if err := json.Unmarshal(line, output); err != nil {
  297. return fmt.Errorf("invalid json stream data: %v", err)
  298. }
  299. onData(output)
  300. }
  301. return nil
  302. }
  303. // Edits is given a prompt and an instruction, the model will return an edited version of the prompt.
  304. func (c *client) Edits(ctx context.Context, request *EditsRequest) (*EditsResponse, error) {
  305. req, err := c.newRequest(ctx, "POST", "/edits", &request)
  306. if err != nil {
  307. return nil, err
  308. }
  309. rsp, err := c.performRequest(req)
  310. if err != nil {
  311. return nil, err
  312. }
  313. output := new(EditsResponse)
  314. if err := getResponseObject(rsp, output); err != nil {
  315. return nil, err
  316. }
  317. return output, nil
  318. }
  319. // Search creates a search with the default engine.
  320. func (c *client) Search(ctx context.Context, request *SearchRequest) (*SearchResponse, error) {
  321. return c.SearchWithEngine(ctx, c.defaultEngine, request)
  322. }
  323. // SearchWithEngine performs a semantic search over a list of documents with the specified engine.
  324. func (c *client) SearchWithEngine(ctx context.Context, engine string, request *SearchRequest) (*SearchResponse, error) {
  325. req, err := c.newRequest(ctx, "POST", fmt.Sprintf("/engines/%s/search", engine), &request)
  326. if err != nil {
  327. return nil, err
  328. }
  329. rsp, err := c.performRequest(req)
  330. if err != nil {
  331. return nil, err
  332. }
  333. output := new(SearchResponse)
  334. if err := getResponseObject(rsp, output); err != nil {
  335. return nil, err
  336. }
  337. return output, nil
  338. }
  339. // Embeddings creates text embeddings for a supplied slice of inputs with a provided model.
  340. // See: https://beta.openai.com/docs/api-reference/embeddings
  341. func (c *client) Embeddings(ctx context.Context, request *EmbeddingsRequest) (*EmbeddingsResponse, error) {
  342. req, err := c.newRequest(ctx, "POST", "/embeddings", &request)
  343. if err != nil {
  344. return nil, err
  345. }
  346. rsp, err := c.performRequest(req)
  347. if err != nil {
  348. return nil, err
  349. }
  350. output := EmbeddingsResponse{}
  351. if err := getResponseObject(rsp, &output); err != nil {
  352. return nil, err
  353. }
  354. return &output, nil
  355. }
  356. // Image creates an image
  357. func (c *client) Image(ctx context.Context, request *ImageRequest) (*ImageResponse, error) {
  358. req, err := c.newRequest(ctx, "POST", "/images/generations", &request)
  359. if err != nil {
  360. return nil, err
  361. }
  362. rsp, err := c.performRequest(req)
  363. if err != nil {
  364. return nil, err
  365. }
  366. output := ImageResponse{}
  367. if err := getResponseObject(rsp, &output); err != nil {
  368. return nil, err
  369. }
  370. return &output, nil
  371. }
  372. func (c *client) performRequest(req *http.Request) (*http.Response, error) {
  373. rsp, err := c.httpClient.Do(req)
  374. if err != nil {
  375. return nil, err
  376. }
  377. if err := checkForSuccess(rsp); err != nil {
  378. return nil, err
  379. }
  380. return rsp, nil
  381. }
  382. // checkForSuccess returns an error if this response includes an error.
  383. func checkForSuccess(rsp *http.Response) error {
  384. if rsp.StatusCode >= 200 && rsp.StatusCode < 300 {
  385. return nil
  386. }
  387. defer rsp.Body.Close()
  388. data, err := io.ReadAll(rsp.Body)
  389. if err != nil {
  390. return fmt.Errorf("failed to read from body: %w", err)
  391. }
  392. var result APIErrorResponse
  393. if err := json.Unmarshal(data, &result); err != nil {
  394. // if we can't decode the json error then create an unexpected error
  395. apiError := APIError{
  396. StatusCode: rsp.StatusCode,
  397. Type: "Unexpected",
  398. Message: string(data),
  399. }
  400. return apiError
  401. }
  402. result.Error.StatusCode = rsp.StatusCode
  403. return result.Error
  404. }
  405. func getResponseObject(rsp *http.Response, v interface{}) error {
  406. defer rsp.Body.Close()
  407. body, err := io.ReadAll(rsp.Body)
  408. if err != nil {
  409. fmt.Println("Error:", err)
  410. return err
  411. }
  412. fmt.Println("Response:")
  413. fmt.Println(string(body))
  414. if err := json.NewDecoder(bytes.NewReader(body)).Decode(v); err != nil {
  415. return fmt.Errorf("invalid json response: %w", err)
  416. }
  417. return nil
  418. }
  419. func jsonBodyReader(body interface{}) (io.Reader, error) {
  420. if body == nil {
  421. return bytes.NewBuffer(nil), nil
  422. }
  423. raw, err := json.Marshal(body)
  424. if err != nil {
  425. return nil, fmt.Errorf("failed encoding json: %w", err)
  426. }
  427. return bytes.NewBuffer(raw), nil
  428. }
  429. func (c *client) newRequest(ctx context.Context, method, path string, payload interface{}) (*http.Request, error) {
  430. bodyReader, err := jsonBodyReader(payload)
  431. if err != nil {
  432. return nil, err
  433. }
  434. url := c.baseURL + path
  435. req, err := http.NewRequestWithContext(ctx, method, url, bodyReader)
  436. if err != nil {
  437. return nil, err
  438. }
  439. if len(c.idOrg) > 0 {
  440. req.Header.Set("OpenAI-Organization", c.idOrg)
  441. }
  442. req.Header.Set("Content-type", "application/json")
  443. req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey))
  444. return req, nil
  445. }