openai.go 24 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993
  1. // openai package provides middleware for partial compatibility with the OpenAI REST API
  2. package openai
  3. import (
  4. "bytes"
  5. "encoding/base64"
  6. "encoding/json"
  7. "errors"
  8. "fmt"
  9. "io"
  10. "log/slog"
  11. "math/rand"
  12. "net/http"
  13. "strings"
  14. "time"
  15. "github.com/gin-gonic/gin"
  16. "github.com/ollama/ollama/api"
  17. "github.com/ollama/ollama/types/model"
  18. )
// finishReasonToolCalls is the OpenAI finish_reason reported when a turn
// ends because the model emitted tool calls. Declared as a package-level
// var (not const) so streaming code can take its address.
var finishReasonToolCalls = "tool_calls"
// Error is the OpenAI-style error payload carried inside ErrorResponse.
type Error struct {
	Message string      `json:"message"`
	Type    string      `json:"type"`
	Param   interface{} `json:"param"`
	Code    *string     `json:"code"`
}

// ErrorResponse is the top-level error envelope returned to clients.
type ErrorResponse struct {
	Error Error `json:"error"`
}

// Message is a chat message in OpenAI wire format. Content is `any`
// because OpenAI allows either a plain string or a list of typed parts.
type Message struct {
	Role      string     `json:"role"`
	Content   any        `json:"content"`
	ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}

// Choice is a single non-streaming chat completion choice.
type Choice struct {
	Index        int     `json:"index"`
	Message      Message `json:"message"`
	FinishReason *string `json:"finish_reason"`
}

// ChunkChoice is a single streaming chat completion choice (delta form).
type ChunkChoice struct {
	Index        int     `json:"index"`
	Delta        Message `json:"delta"`
	FinishReason *string `json:"finish_reason"`
}

// CompleteChunkChoice is a single text-completion choice, shared by the
// streaming and non-streaming completion responses.
type CompleteChunkChoice struct {
	Text         string  `json:"text"`
	Index        int     `json:"index"`
	FinishReason *string `json:"finish_reason"`
}

// Usage reports token accounting for a completion.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}
// ResponseFormat selects the requested output format: "json_object"
// (legacy JSON mode) or "json_schema" with an attached schema.
type ResponseFormat struct {
	Type       string      `json:"type"`
	JsonSchema *JsonSchema `json:"json_schema,omitempty"`
}

// JsonSchema carries the raw schema used with the "json_schema" format.
type JsonSchema struct {
	Schema json.RawMessage `json:"schema"`
}

// EmbedRequest is an OpenAI embeddings request. Input may be a string
// or a list of strings.
type EmbedRequest struct {
	Input any    `json:"input"`
	Model string `json:"model"`
}

// StreamOptions controls streaming extras such as the final usage chunk.
type StreamOptions struct {
	IncludeUsage bool `json:"include_usage"`
}

// ChatCompletionRequest mirrors the OpenAI /v1/chat/completions body.
type ChatCompletionRequest struct {
	Model            string          `json:"model"`
	Messages         []Message       `json:"messages"`
	Stream           bool            `json:"stream"`
	StreamOptions    *StreamOptions  `json:"stream_options"`
	MaxTokens        *int            `json:"max_tokens"`
	Seed             *int            `json:"seed"`
	Stop             any             `json:"stop"` // string or []string on the wire
	Temperature      *float64        `json:"temperature"`
	FrequencyPenalty *float64        `json:"frequency_penalty"`
	PresencePenalty  *float64        `json:"presence_penalty"`
	TopP             *float64        `json:"top_p"`
	ResponseFormat   *ResponseFormat `json:"response_format"`
	Tools            []api.Tool      `json:"tools"`
}

// ChatCompletion is the non-streaming chat completion response.
type ChatCompletion struct {
	Id                string   `json:"id"`
	Object            string   `json:"object"`
	Created           int64    `json:"created"`
	Model             string   `json:"model"`
	SystemFingerprint string   `json:"system_fingerprint"`
	Choices           []Choice `json:"choices"`
	Usage             Usage    `json:"usage,omitempty"`
}

// ChatCompletionChunk is one SSE chunk of a streaming chat completion.
// Usage is a pointer so it is omitted except on the final usage-only
// chunk emitted when stream_options.include_usage is set.
type ChatCompletionChunk struct {
	Id                string        `json:"id"`
	Object            string        `json:"object"`
	Created           int64         `json:"created"`
	Model             string        `json:"model"`
	SystemFingerprint string        `json:"system_fingerprint"`
	Choices           []ChunkChoice `json:"choices"`
	Usage             *Usage        `json:"usage,omitempty"`
}
// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
// CompletionRequest mirrors the OpenAI legacy /v1/completions body.
// Prompt currently supports only a single string (see TODO above).
type CompletionRequest struct {
	Model            string         `json:"model"`
	Prompt           string         `json:"prompt"`
	FrequencyPenalty float32        `json:"frequency_penalty"`
	MaxTokens        *int           `json:"max_tokens"`
	PresencePenalty  float32        `json:"presence_penalty"`
	Seed             *int           `json:"seed"`
	Stop             any            `json:"stop"` // string or []string on the wire
	Stream           bool           `json:"stream"`
	StreamOptions    *StreamOptions `json:"stream_options"`
	Temperature      *float32       `json:"temperature"`
	TopP             float32        `json:"top_p"`
	Suffix           string         `json:"suffix"`
}

// Completion is the non-streaming text completion response.
type Completion struct {
	Id                string                `json:"id"`
	Object            string                `json:"object"`
	Created           int64                 `json:"created"`
	Model             string                `json:"model"`
	SystemFingerprint string                `json:"system_fingerprint"`
	Choices           []CompleteChunkChoice `json:"choices"`
	Usage             Usage                 `json:"usage,omitempty"`
}

// CompletionChunk is one SSE chunk of a streaming text completion.
type CompletionChunk struct {
	Id                string                `json:"id"`
	Object            string                `json:"object"`
	Created           int64                 `json:"created"`
	Choices           []CompleteChunkChoice `json:"choices"`
	Model             string                `json:"model"`
	SystemFingerprint string                `json:"system_fingerprint"`
	Usage             *Usage                `json:"usage,omitempty"`
}

// ToolCall is a tool invocation in OpenAI wire format; Arguments is a
// JSON-encoded string, not a structured object.
type ToolCall struct {
	ID       string `json:"id"`
	Index    int    `json:"index"`
	Type     string `json:"type"`
	Function struct {
		Name      string `json:"name"`
		Arguments string `json:"arguments"`
	} `json:"function"`
}

// Model describes one model entry in /v1/models responses.
type Model struct {
	Id      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	OwnedBy string `json:"owned_by"`
}

// Embedding is a single embedding vector inside an EmbeddingList.
type Embedding struct {
	Object    string    `json:"object"`
	Embedding []float32 `json:"embedding"`
	Index     int       `json:"index"`
}

// ListCompletion is the /v1/models listing envelope.
type ListCompletion struct {
	Object string  `json:"object"`
	Data   []Model `json:"data"`
}

// EmbeddingList is the /v1/embeddings response envelope.
type EmbeddingList struct {
	Object string         `json:"object"`
	Data   []Embedding    `json:"data"`
	Model  string         `json:"model"`
	Usage  EmbeddingUsage `json:"usage,omitempty"`
}

// EmbeddingUsage reports token accounting for embeddings; completion
// tokens do not apply here.
type EmbeddingUsage struct {
	PromptTokens int `json:"prompt_tokens"`
	TotalTokens  int `json:"total_tokens"`
}
  168. func NewError(code int, message string) ErrorResponse {
  169. var etype string
  170. switch code {
  171. case http.StatusBadRequest:
  172. etype = "invalid_request_error"
  173. case http.StatusNotFound:
  174. etype = "not_found_error"
  175. default:
  176. etype = "api_error"
  177. }
  178. return ErrorResponse{Error{Type: etype, Message: message}}
  179. }
  180. func toUsage(r api.ChatResponse) Usage {
  181. return Usage{
  182. PromptTokens: r.PromptEvalCount,
  183. CompletionTokens: r.EvalCount,
  184. TotalTokens: r.PromptEvalCount + r.EvalCount,
  185. }
  186. }
  187. func toolCallId() string {
  188. const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
  189. b := make([]byte, 8)
  190. for i := range b {
  191. b[i] = letterBytes[rand.Intn(len(letterBytes))]
  192. }
  193. return "call_" + strings.ToLower(string(b))
  194. }
  195. func toToolCalls(tc []api.ToolCall) []ToolCall {
  196. toolCalls := make([]ToolCall, len(tc))
  197. for i, tc := range tc {
  198. toolCalls[i].ID = toolCallId()
  199. toolCalls[i].Type = "function"
  200. toolCalls[i].Function.Name = tc.Function.Name
  201. toolCalls[i].Index = tc.Function.Index
  202. args, err := json.Marshal(tc.Function.Arguments)
  203. if err != nil {
  204. slog.Error("could not marshall function arguments to json", "error", err)
  205. continue
  206. }
  207. toolCalls[i].Function.Arguments = string(args)
  208. }
  209. return toolCalls
  210. }
  211. func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
  212. toolCalls := toToolCalls(r.Message.ToolCalls)
  213. return ChatCompletion{
  214. Id: id,
  215. Object: "chat.completion",
  216. Created: r.CreatedAt.Unix(),
  217. Model: r.Model,
  218. SystemFingerprint: "fp_ollama",
  219. Choices: []Choice{{
  220. Index: 0,
  221. Message: Message{Role: r.Message.Role, Content: r.Message.Content, ToolCalls: toolCalls},
  222. FinishReason: func(reason string) *string {
  223. if len(toolCalls) > 0 {
  224. reason = "tool_calls"
  225. }
  226. if len(reason) > 0 {
  227. return &reason
  228. }
  229. return nil
  230. }(r.DoneReason),
  231. }},
  232. Usage: toUsage(r),
  233. }
  234. }
  235. func toChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChunk {
  236. toolCalls := toToolCalls(r.Message.ToolCalls)
  237. return ChatCompletionChunk{
  238. Id: id,
  239. Object: "chat.completion.chunk",
  240. Created: time.Now().Unix(),
  241. Model: r.Model,
  242. SystemFingerprint: "fp_ollama",
  243. Choices: []ChunkChoice{{
  244. Index: 0,
  245. Delta: Message{Role: "assistant", Content: r.Message.Content, ToolCalls: toolCalls},
  246. FinishReason: func(reason string) *string {
  247. if len(reason) > 0 {
  248. if toolCallSent {
  249. return &finishReasonToolCalls
  250. }
  251. return &reason
  252. }
  253. return nil
  254. }(r.DoneReason),
  255. }},
  256. }
  257. }
  258. func toUsageGenerate(r api.GenerateResponse) Usage {
  259. return Usage{
  260. PromptTokens: r.PromptEvalCount,
  261. CompletionTokens: r.EvalCount,
  262. TotalTokens: r.PromptEvalCount + r.EvalCount,
  263. }
  264. }
  265. func toCompletion(id string, r api.GenerateResponse) Completion {
  266. return Completion{
  267. Id: id,
  268. Object: "text_completion",
  269. Created: r.CreatedAt.Unix(),
  270. Model: r.Model,
  271. SystemFingerprint: "fp_ollama",
  272. Choices: []CompleteChunkChoice{{
  273. Text: r.Response,
  274. Index: 0,
  275. FinishReason: func(reason string) *string {
  276. if len(reason) > 0 {
  277. return &reason
  278. }
  279. return nil
  280. }(r.DoneReason),
  281. }},
  282. Usage: toUsageGenerate(r),
  283. }
  284. }
  285. func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
  286. return CompletionChunk{
  287. Id: id,
  288. Object: "text_completion",
  289. Created: time.Now().Unix(),
  290. Model: r.Model,
  291. SystemFingerprint: "fp_ollama",
  292. Choices: []CompleteChunkChoice{{
  293. Text: r.Response,
  294. Index: 0,
  295. FinishReason: func(reason string) *string {
  296. if len(reason) > 0 {
  297. return &reason
  298. }
  299. return nil
  300. }(r.DoneReason),
  301. }},
  302. }
  303. }
  304. func toListCompletion(r api.ListResponse) ListCompletion {
  305. var data []Model
  306. for _, m := range r.Models {
  307. data = append(data, Model{
  308. Id: m.Name,
  309. Object: "model",
  310. Created: m.ModifiedAt.Unix(),
  311. OwnedBy: model.ParseName(m.Name).Namespace,
  312. })
  313. }
  314. return ListCompletion{
  315. Object: "list",
  316. Data: data,
  317. }
  318. }
  319. func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
  320. if r.Embeddings != nil {
  321. var data []Embedding
  322. for i, e := range r.Embeddings {
  323. data = append(data, Embedding{
  324. Object: "embedding",
  325. Embedding: e,
  326. Index: i,
  327. })
  328. }
  329. return EmbeddingList{
  330. Object: "list",
  331. Data: data,
  332. Model: model,
  333. Usage: EmbeddingUsage{
  334. PromptTokens: r.PromptEvalCount,
  335. TotalTokens: r.PromptEvalCount,
  336. },
  337. }
  338. }
  339. return EmbeddingList{}
  340. }
  341. func toModel(r api.ShowResponse, m string) Model {
  342. return Model{
  343. Id: m,
  344. Object: "model",
  345. Created: r.ModifiedAt.Unix(),
  346. OwnedBy: model.ParseName(m).Namespace,
  347. }
  348. }
// fromChatRequest translates an OpenAI chat completion request into the
// native *api.ChatRequest. Structured message content (text and inline
// base64 image parts) is flattened into separate api.Messages, tool-call
// messages are decoded, and OpenAI sampling parameters are mapped onto
// the Options map. Returns an error for any malformed message content.
func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
	var messages []api.Message
	for _, msg := range r.Messages {
		switch content := msg.Content.(type) {
		case string:
			// Plain string content maps one-to-one.
			messages = append(messages, api.Message{Role: msg.Role, Content: content})
		case []any:
			// Structured content: a list of {"type": ...} parts. Each
			// part becomes its own api.Message with the same role.
			for _, c := range content {
				data, ok := c.(map[string]any)
				if !ok {
					return nil, errors.New("invalid message format")
				}
				switch data["type"] {
				case "text":
					text, ok := data["text"].(string)
					if !ok {
						return nil, errors.New("invalid message format")
					}
					messages = append(messages, api.Message{Role: msg.Role, Content: text})
				case "image_url":
					// "image_url" may be an object {"url": ...} or a
					// bare string.
					var url string
					if urlMap, ok := data["image_url"].(map[string]any); ok {
						if url, ok = urlMap["url"].(string); !ok {
							return nil, errors.New("invalid message format")
						}
					} else {
						if url, ok = data["image_url"].(string); !ok {
							return nil, errors.New("invalid message format")
						}
					}
					// Only inline base64 data URIs for jpeg/jpg/png are
					// accepted; anything else (e.g. remote URLs) is
					// rejected.
					types := []string{"jpeg", "jpg", "png"}
					valid := false
					for _, t := range types {
						prefix := "data:image/" + t + ";base64,"
						if strings.HasPrefix(url, prefix) {
							url = strings.TrimPrefix(url, prefix)
							valid = true
							break
						}
					}
					if !valid {
						return nil, errors.New("invalid image input")
					}
					img, err := base64.StdEncoding.DecodeString(url)
					if err != nil {
						return nil, errors.New("invalid message format")
					}
					messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
				default:
					return nil, errors.New("invalid message format")
				}
			}
		default:
			// Content is neither string nor list. This is only legal
			// for messages that carry tool calls (OpenAI sends
			// content: null on assistant tool-call turns).
			if msg.ToolCalls == nil {
				return nil, fmt.Errorf("invalid message content type: %T", content)
			}
			toolCalls := make([]api.ToolCall, len(msg.ToolCalls))
			for i, tc := range msg.ToolCalls {
				toolCalls[i].Function.Name = tc.Function.Name
				// OpenAI encodes arguments as a JSON string; decode
				// into the native structured form.
				err := json.Unmarshal([]byte(tc.Function.Arguments), &toolCalls[i].Function.Arguments)
				if err != nil {
					return nil, errors.New("invalid tool call arguments")
				}
			}
			messages = append(messages, api.Message{Role: msg.Role, ToolCalls: toolCalls})
		}
	}

	options := make(map[string]interface{})

	// "stop" may be a single string or a list; non-string entries in a
	// list are silently dropped here (unlike fromCompleteRequest, which
	// rejects them).
	switch stop := r.Stop.(type) {
	case string:
		options["stop"] = []string{stop}
	case []any:
		var stops []string
		for _, s := range stop {
			if str, ok := s.(string); ok {
				stops = append(stops, str)
			}
		}
		options["stop"] = stops
	}

	if r.MaxTokens != nil {
		options["num_predict"] = *r.MaxTokens
	}

	// OpenAI defaults when unset: temperature 1.0, top_p 1.0.
	if r.Temperature != nil {
		options["temperature"] = *r.Temperature
	} else {
		options["temperature"] = 1.0
	}

	if r.Seed != nil {
		options["seed"] = *r.Seed
	}

	if r.FrequencyPenalty != nil {
		options["frequency_penalty"] = *r.FrequencyPenalty
	}

	if r.PresencePenalty != nil {
		options["presence_penalty"] = *r.PresencePenalty
	}

	if r.TopP != nil {
		options["top_p"] = *r.TopP
	} else {
		options["top_p"] = 1.0
	}

	var format json.RawMessage
	if r.ResponseFormat != nil {
		switch strings.ToLower(strings.TrimSpace(r.ResponseFormat.Type)) {
		// Support the old "json_object" type for OpenAI compatibility
		case "json_object":
			format = json.RawMessage(`"json"`)
		case "json_schema":
			if r.ResponseFormat.JsonSchema != nil {
				format = r.ResponseFormat.JsonSchema.Schema
			}
		}
	}

	return &api.ChatRequest{
		Model:    r.Model,
		Messages: messages,
		Format:   format,
		Options:  options,
		Stream:   &r.Stream,
		Tools:    r.Tools,
	}, nil
}
  472. func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
  473. options := make(map[string]any)
  474. switch stop := r.Stop.(type) {
  475. case string:
  476. options["stop"] = []string{stop}
  477. case []any:
  478. var stops []string
  479. for _, s := range stop {
  480. if str, ok := s.(string); ok {
  481. stops = append(stops, str)
  482. } else {
  483. return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", s)
  484. }
  485. }
  486. options["stop"] = stops
  487. }
  488. if r.MaxTokens != nil {
  489. options["num_predict"] = *r.MaxTokens
  490. }
  491. if r.Temperature != nil {
  492. options["temperature"] = *r.Temperature
  493. } else {
  494. options["temperature"] = 1.0
  495. }
  496. if r.Seed != nil {
  497. options["seed"] = *r.Seed
  498. }
  499. options["frequency_penalty"] = r.FrequencyPenalty
  500. options["presence_penalty"] = r.PresencePenalty
  501. if r.TopP != 0.0 {
  502. options["top_p"] = r.TopP
  503. } else {
  504. options["top_p"] = 1.0
  505. }
  506. return api.GenerateRequest{
  507. Model: r.Model,
  508. Prompt: r.Prompt,
  509. Options: options,
  510. Stream: &r.Stream,
  511. Suffix: r.Suffix,
  512. }, nil
  513. }
// BaseWriter wraps gin's ResponseWriter and provides the shared error
// translation used by the typed writers below.
type BaseWriter struct {
	gin.ResponseWriter
}

// ChatWriter rewrites native chat responses into OpenAI chat completion
// (or streaming chunk) payloads as the handler writes them.
type ChatWriter struct {
	stream        bool
	streamOptions *StreamOptions
	id            string
	// toolCallSent records whether a tool-call delta has already been
	// streamed, so the terminal finish_reason becomes "tool_calls".
	toolCallSent bool
	BaseWriter
}

// CompleteWriter rewrites native generate responses into OpenAI text
// completion payloads.
type CompleteWriter struct {
	stream        bool
	streamOptions *StreamOptions
	id            string
	BaseWriter
}

// ListWriter rewrites native list responses into /v1/models form.
type ListWriter struct {
	BaseWriter
}

// RetrieveWriter rewrites native show responses into a single Model.
type RetrieveWriter struct {
	BaseWriter
	model string
}

// EmbedWriter rewrites native embed responses into /v1/embeddings form.
type EmbedWriter struct {
	BaseWriter
	model string
}
  541. func (w *BaseWriter) writeError(data []byte) (int, error) {
  542. var serr api.StatusError
  543. err := json.Unmarshal(data, &serr)
  544. if err != nil {
  545. return 0, err
  546. }
  547. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  548. err = json.NewEncoder(w.ResponseWriter).Encode(NewError(http.StatusInternalServerError, serr.Error()))
  549. if err != nil {
  550. return 0, err
  551. }
  552. return len(data), nil
  553. }
// writeResponse decodes a native chat response and re-encodes it in
// OpenAI format: SSE "data:" chunks when streaming (plus an optional
// trailing usage-only chunk and a "[DONE]" sentinel), or a single JSON
// chat completion otherwise. The returned count is len(data) so the
// caller treats the incoming payload as fully written.
func (w *ChatWriter) writeResponse(data []byte) (int, error) {
	var chatResponse api.ChatResponse
	err := json.Unmarshal(data, &chatResponse)
	if err != nil {
		return 0, err
	}

	// chat chunk
	if w.stream {
		c := toChunk(w.id, chatResponse, w.toolCallSent)
		d, err := json.Marshal(c)
		if err != nil {
			return 0, err
		}
		// Flip toolCallSent only AFTER building the current chunk, so
		// this chunk keeps its own finish reason and only subsequent
		// chunks report "tool_calls".
		if !w.toolCallSent && len(c.Choices) > 0 && len(c.Choices[0].Delta.ToolCalls) > 0 {
			w.toolCallSent = true
		}
		w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
		_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
		if err != nil {
			return 0, err
		}

		if chatResponse.Done {
			// Per stream_options.include_usage, emit a final chunk
			// with usage populated and an empty choices array.
			if w.streamOptions != nil && w.streamOptions.IncludeUsage {
				u := toUsage(chatResponse)
				c.Usage = &u
				c.Choices = []ChunkChoice{}
				d, err := json.Marshal(c)
				if err != nil {
					return 0, err
				}
				_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
				if err != nil {
					return 0, err
				}
			}
			_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
			if err != nil {
				return 0, err
			}
		}

		return len(data), nil
	}

	// chat completion
	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w.ResponseWriter).Encode(toChatCompletion(w.id, chatResponse))
	if err != nil {
		return 0, err
	}

	return len(data), nil
}
  604. func (w *ChatWriter) Write(data []byte) (int, error) {
  605. code := w.ResponseWriter.Status()
  606. if code != http.StatusOK {
  607. return w.writeError(data)
  608. }
  609. return w.writeResponse(data)
  610. }
// writeResponse decodes a native generate response and re-encodes it in
// OpenAI format: SSE "data:" chunks when streaming (with a final
// usage-carrying chunk and "[DONE]" sentinel), or a single JSON text
// completion otherwise.
func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
	var generateResponse api.GenerateResponse
	err := json.Unmarshal(data, &generateResponse)
	if err != nil {
		return 0, err
	}

	// completion chunk
	if w.stream {
		c := toCompleteChunk(w.id, generateResponse)
		// With include_usage, intermediate chunks carry a zero-valued
		// usage placeholder; the final chunk overwrites it below.
		if w.streamOptions != nil && w.streamOptions.IncludeUsage {
			c.Usage = &Usage{}
		}
		d, err := json.Marshal(c)
		if err != nil {
			return 0, err
		}
		w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
		_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
		if err != nil {
			return 0, err
		}

		if generateResponse.Done {
			// Final usage-only chunk, per stream_options.include_usage.
			if w.streamOptions != nil && w.streamOptions.IncludeUsage {
				u := toUsageGenerate(generateResponse)
				c.Usage = &u
				c.Choices = []CompleteChunkChoice{}
				d, err := json.Marshal(c)
				if err != nil {
					return 0, err
				}
				_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
				if err != nil {
					return 0, err
				}
			}
			_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
			if err != nil {
				return 0, err
			}
		}

		return len(data), nil
	}

	// completion
	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w.ResponseWriter).Encode(toCompletion(w.id, generateResponse))
	if err != nil {
		return 0, err
	}

	return len(data), nil
}
  661. func (w *CompleteWriter) Write(data []byte) (int, error) {
  662. code := w.ResponseWriter.Status()
  663. if code != http.StatusOK {
  664. return w.writeError(data)
  665. }
  666. return w.writeResponse(data)
  667. }
  668. func (w *ListWriter) writeResponse(data []byte) (int, error) {
  669. var listResponse api.ListResponse
  670. err := json.Unmarshal(data, &listResponse)
  671. if err != nil {
  672. return 0, err
  673. }
  674. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  675. err = json.NewEncoder(w.ResponseWriter).Encode(toListCompletion(listResponse))
  676. if err != nil {
  677. return 0, err
  678. }
  679. return len(data), nil
  680. }
  681. func (w *ListWriter) Write(data []byte) (int, error) {
  682. code := w.ResponseWriter.Status()
  683. if code != http.StatusOK {
  684. return w.writeError(data)
  685. }
  686. return w.writeResponse(data)
  687. }
  688. func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
  689. var showResponse api.ShowResponse
  690. err := json.Unmarshal(data, &showResponse)
  691. if err != nil {
  692. return 0, err
  693. }
  694. // retrieve completion
  695. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  696. err = json.NewEncoder(w.ResponseWriter).Encode(toModel(showResponse, w.model))
  697. if err != nil {
  698. return 0, err
  699. }
  700. return len(data), nil
  701. }
  702. func (w *RetrieveWriter) Write(data []byte) (int, error) {
  703. code := w.ResponseWriter.Status()
  704. if code != http.StatusOK {
  705. return w.writeError(data)
  706. }
  707. return w.writeResponse(data)
  708. }
  709. func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
  710. var embedResponse api.EmbedResponse
  711. err := json.Unmarshal(data, &embedResponse)
  712. if err != nil {
  713. return 0, err
  714. }
  715. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  716. err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse))
  717. if err != nil {
  718. return 0, err
  719. }
  720. return len(data), nil
  721. }
  722. func (w *EmbedWriter) Write(data []byte) (int, error) {
  723. code := w.ResponseWriter.Status()
  724. if code != http.StatusOK {
  725. return w.writeError(data)
  726. }
  727. return w.writeResponse(data)
  728. }
  729. func ListMiddleware() gin.HandlerFunc {
  730. return func(c *gin.Context) {
  731. w := &ListWriter{
  732. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  733. }
  734. c.Writer = w
  735. c.Next()
  736. }
  737. }
  738. func RetrieveMiddleware() gin.HandlerFunc {
  739. return func(c *gin.Context) {
  740. var b bytes.Buffer
  741. if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
  742. c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
  743. return
  744. }
  745. c.Request.Body = io.NopCloser(&b)
  746. // response writer
  747. w := &RetrieveWriter{
  748. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  749. model: c.Param("model"),
  750. }
  751. c.Writer = w
  752. c.Next()
  753. }
  754. }
  755. func CompletionsMiddleware() gin.HandlerFunc {
  756. return func(c *gin.Context) {
  757. var req CompletionRequest
  758. err := c.ShouldBindJSON(&req)
  759. if err != nil {
  760. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
  761. return
  762. }
  763. var b bytes.Buffer
  764. genReq, err := fromCompleteRequest(req)
  765. if err != nil {
  766. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
  767. return
  768. }
  769. if err := json.NewEncoder(&b).Encode(genReq); err != nil {
  770. c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
  771. return
  772. }
  773. c.Request.Body = io.NopCloser(&b)
  774. w := &CompleteWriter{
  775. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  776. stream: req.Stream,
  777. id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
  778. streamOptions: req.StreamOptions,
  779. }
  780. c.Writer = w
  781. c.Next()
  782. }
  783. }
  784. func EmbeddingsMiddleware() gin.HandlerFunc {
  785. return func(c *gin.Context) {
  786. var req EmbedRequest
  787. err := c.ShouldBindJSON(&req)
  788. if err != nil {
  789. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
  790. return
  791. }
  792. if req.Input == "" {
  793. req.Input = []string{""}
  794. }
  795. if req.Input == nil {
  796. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
  797. return
  798. }
  799. if v, ok := req.Input.([]any); ok && len(v) == 0 {
  800. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
  801. return
  802. }
  803. var b bytes.Buffer
  804. if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input}); err != nil {
  805. c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
  806. return
  807. }
  808. c.Request.Body = io.NopCloser(&b)
  809. w := &EmbedWriter{
  810. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  811. model: req.Model,
  812. }
  813. c.Writer = w
  814. c.Next()
  815. }
  816. }
  817. func ChatMiddleware() gin.HandlerFunc {
  818. return func(c *gin.Context) {
  819. var req ChatCompletionRequest
  820. err := c.ShouldBindJSON(&req)
  821. if err != nil {
  822. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
  823. return
  824. }
  825. if len(req.Messages) == 0 {
  826. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "[] is too short - 'messages'"))
  827. return
  828. }
  829. var b bytes.Buffer
  830. chatReq, err := fromChatRequest(req)
  831. if err != nil {
  832. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
  833. return
  834. }
  835. if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
  836. c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
  837. return
  838. }
  839. c.Request.Body = io.NopCloser(&b)
  840. w := &ChatWriter{
  841. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  842. stream: req.Stream,
  843. id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
  844. streamOptions: req.StreamOptions,
  845. }
  846. c.Writer = w
  847. c.Next()
  848. }
  849. }