openai.go 24 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002
  1. // openai package provides middleware for partial compatibility with the OpenAI REST API
  2. package openai
  3. import (
  4. "bytes"
  5. "encoding/base64"
  6. "encoding/json"
  7. "errors"
  8. "fmt"
  9. "io"
  10. "log/slog"
  11. "math/rand"
  12. "net/http"
  13. "strings"
  14. "time"
  15. "github.com/gin-gonic/gin"
  16. "github.com/ollama/ollama/api"
  17. "github.com/ollama/ollama/types/model"
  18. )
// Error is the OpenAI-compatible error payload carried inside an
// ErrorResponse body.
type Error struct {
	Message string      `json:"message"`
	Type    string      `json:"type"`
	Param   interface{} `json:"param"`
	Code    *string     `json:"code"`
}

// ErrorResponse wraps an Error to match the OpenAI error envelope.
type ErrorResponse struct {
	Error Error `json:"error"`
}

// Message is a single chat message. Content is `any` because the OpenAI
// API accepts either a plain string or a list of typed content parts.
type Message struct {
	Role      string     `json:"role"`
	Content   any        `json:"content"`
	ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}
// Choice is a single non-streaming chat completion choice.
type Choice struct {
	Index        int     `json:"index"`
	Message      Message `json:"message"`
	FinishReason *string `json:"finish_reason"`
}

// ChunkChoice is a single streaming chat completion delta.
type ChunkChoice struct {
	Index        int     `json:"index"`
	Delta        Message `json:"delta"`
	FinishReason *string `json:"finish_reason"`
}

// CompleteChunkChoice is a single legacy text completion choice.
type CompleteChunkChoice struct {
	Text         string  `json:"text"`
	Index        int     `json:"index"`
	FinishReason *string `json:"finish_reason"`
}

// Usage reports prompt and completion token counts for a request.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens     	int `json:"total_tokens"`
}
// ChunkUsage is an alias for Usage with the ability to marshal a marker
// value as null. This is to allow omitting the field in chunks when usage
// isn't requested, and otherwise return null on non-final chunks when it
// is requested to follow OpenAI's behavior.
type ChunkUsage = Usage

// var nullChunkUsage = ChunkUsage{}

// func (u *ChunkUsage) MarshalJSON() ([]byte, error) {
// if u == &nullChunkUsage {
// return []byte("null"), nil
// }
// return json.Marshal(*u)
// }

// ResponseFormat selects the output format: "json_object" for JSON mode,
// or "json_schema" with an attached schema for structured output.
type ResponseFormat struct {
	Type       string      `json:"type"`
	JsonSchema *JsonSchema `json:"json_schema,omitempty"`
}

// JsonSchema carries the raw JSON schema used for structured output.
type JsonSchema struct {
	Schema map[string]any `json:"schema"`
}

// EmbedRequest is an OpenAI-compatible embeddings request. Input may be a
// string or a list of inputs (validated in EmbeddingsMiddleware).
type EmbedRequest struct {
	Input any    `json:"input"`
	Model string `json:"model"`
}

// StreamOptions controls streaming behavior; IncludeUsage requests a final
// usage-only chunk before the stream terminates.
type StreamOptions struct {
	IncludeUsage bool `json:"include_usage"`
}
// ChatCompletionRequest is the request body for the OpenAI-compatible
// chat completions endpoint.
type ChatCompletionRequest struct {
	Model            string          `json:"model"`
	Messages         []Message       `json:"messages"`
	Stream           bool            `json:"stream"`
	StreamOptions    *StreamOptions  `json:"stream_options"`
	MaxTokens        *int            `json:"max_tokens"`
	Seed             *int            `json:"seed"`
	Stop             any             `json:"stop"` // string or list of strings
	Temperature      *float64        `json:"temperature"`
	FrequencyPenalty *float64        `json:"frequency_penalty"`
	PresencePenalty  *float64        `json:"presence_penalty"`
	TopP             *float64        `json:"top_p"`
	ResponseFormat   *ResponseFormat `json:"response_format"`
	Tools            []api.Tool      `json:"tools"`
}

// ChatCompletion is a complete (non-streaming) chat completion response.
type ChatCompletion struct {
	Id                string   `json:"id"`
	Object            string   `json:"object"`
	Created           int64    `json:"created"`
	Model             string   `json:"model"`
	SystemFingerprint string   `json:"system_fingerprint"`
	Choices           []Choice `json:"choices"`
	Usage             Usage    `json:"usage,omitempty"`
}

// ChatCompletionChunk is a single SSE chunk of a streaming chat completion.
type ChatCompletionChunk struct {
	Id                string        `json:"id"`
	Object            string        `json:"object"`
	Created           int64         `json:"created"`
	Model             string        `json:"model"`
	SystemFingerprint string        `json:"system_fingerprint"`
	Choices           []ChunkChoice `json:"choices"`
	Usage             *ChunkUsage   `json:"usage,omitempty"`
}
// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int

// CompletionRequest is the request body for the OpenAI-compatible legacy
// completions endpoint.
type CompletionRequest struct {
	Model            string         `json:"model"`
	Prompt           string         `json:"prompt"`
	FrequencyPenalty float32        `json:"frequency_penalty"`
	MaxTokens        *int           `json:"max_tokens"`
	PresencePenalty  float32        `json:"presence_penalty"`
	Seed             *int           `json:"seed"`
	Stop             any            `json:"stop"` // string or list of strings
	Stream           bool           `json:"stream"`
	StreamOptions    *StreamOptions `json:"stream_options"`
	Temperature      *float32       `json:"temperature"`
	TopP             float32        `json:"top_p"`
	Suffix           string         `json:"suffix"`
}

// Completion is a complete (non-streaming) text completion response.
type Completion struct {
	Id                string                `json:"id"`
	Object            string                `json:"object"`
	Created           int64                 `json:"created"`
	Model             string                `json:"model"`
	SystemFingerprint string                `json:"system_fingerprint"`
	Choices           []CompleteChunkChoice `json:"choices"`
	Usage             Usage                 `json:"usage,omitempty"`
}

// CompletionChunk is a single SSE chunk of a streaming text completion.
type CompletionChunk struct {
	Id                string                `json:"id"`
	Object            string                `json:"object"`
	Created           int64                 `json:"created"`
	Choices           []CompleteChunkChoice `json:"choices"`
	Model             string                `json:"model"`
	SystemFingerprint string                `json:"system_fingerprint"`
	Usage             *ChunkUsage           `json:"usage,omitempty"`
}
// ToolCall is a single function tool invocation in the OpenAI wire format;
// Function.Arguments is a JSON-encoded string.
type ToolCall struct {
	ID       string `json:"id"`
	Index    int    `json:"index"`
	Type     string `json:"type"`
	Function struct {
		Name      string `json:"name"`
		Arguments string `json:"arguments"`
	} `json:"function"`
}

// Model describes one model in the OpenAI model listing.
type Model struct {
	Id      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	OwnedBy string `json:"owned_by"`
}

// Embedding is one embedding vector with its position in the input batch.
type Embedding struct {
	Object    string    `json:"object"`
	Embedding []float32 `json:"embedding"`
	Index     int       `json:"index"`
}

// ListCompletion is the OpenAI-format model list response.
type ListCompletion struct {
	Object string  `json:"object"`
	Data   []Model `json:"data"`
}

// EmbeddingList is the OpenAI-format embeddings response.
type EmbeddingList struct {
	Object string         `json:"object"`
	Data   []Embedding    `json:"data"`
	Model  string         `json:"model"`
	Usage  EmbeddingUsage `json:"usage,omitempty"`
}

// EmbeddingUsage reports token accounting for an embeddings request.
type EmbeddingUsage struct {
	PromptTokens int `json:"prompt_tokens"`
	TotalTokens  int `json:"total_tokens"`
}
  179. func NewError(code int, message string) ErrorResponse {
  180. var etype string
  181. switch code {
  182. case http.StatusBadRequest:
  183. etype = "invalid_request_error"
  184. case http.StatusNotFound:
  185. etype = "not_found_error"
  186. default:
  187. etype = "api_error"
  188. }
  189. return ErrorResponse{Error{Type: etype, Message: message}}
  190. }
  191. func toUsage(r api.ChatResponse) Usage {
  192. return Usage{
  193. PromptTokens: r.PromptEvalCount,
  194. CompletionTokens: r.EvalCount,
  195. TotalTokens: r.PromptEvalCount + r.EvalCount,
  196. }
  197. }
  198. func toolCallId() string {
  199. const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
  200. b := make([]byte, 8)
  201. for i := range b {
  202. b[i] = letterBytes[rand.Intn(len(letterBytes))]
  203. }
  204. return "call_" + strings.ToLower(string(b))
  205. }
  206. func toToolCalls(tc []api.ToolCall) []ToolCall {
  207. toolCalls := make([]ToolCall, len(tc))
  208. for i, tc := range tc {
  209. toolCalls[i].ID = toolCallId()
  210. toolCalls[i].Type = "function"
  211. toolCalls[i].Function.Name = tc.Function.Name
  212. toolCalls[i].Index = tc.Function.Index
  213. args, err := json.Marshal(tc.Function.Arguments)
  214. if err != nil {
  215. slog.Error("could not marshall function arguments to json", "error", err)
  216. continue
  217. }
  218. toolCalls[i].Function.Arguments = string(args)
  219. }
  220. return toolCalls
  221. }
  222. func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
  223. toolCalls := toToolCalls(r.Message.ToolCalls)
  224. return ChatCompletion{
  225. Id: id,
  226. Object: "chat.completion",
  227. Created: r.CreatedAt.Unix(),
  228. Model: r.Model,
  229. SystemFingerprint: "fp_ollama",
  230. Choices: []Choice{{
  231. Index: 0,
  232. Message: Message{Role: r.Message.Role, Content: r.Message.Content, ToolCalls: toolCalls},
  233. FinishReason: func(reason string) *string {
  234. if len(toolCalls) > 0 {
  235. reason = "tool_calls"
  236. }
  237. if len(reason) > 0 {
  238. return &reason
  239. }
  240. return nil
  241. }(r.DoneReason),
  242. }},
  243. Usage: toUsage(r),
  244. }
  245. }
  246. func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
  247. toolCalls := toToolCalls(r.Message.ToolCalls)
  248. return ChatCompletionChunk{
  249. Id: id,
  250. Object: "chat.completion.chunk",
  251. Created: time.Now().Unix(),
  252. Model: r.Model,
  253. SystemFingerprint: "fp_ollama",
  254. Choices: []ChunkChoice{{
  255. Index: 0,
  256. Delta: Message{Role: "assistant", Content: r.Message.Content, ToolCalls: toolCalls},
  257. FinishReason: func(reason string) *string {
  258. if len(reason) > 0 {
  259. return &reason
  260. }
  261. return nil
  262. }(r.DoneReason),
  263. }},
  264. }
  265. }
  266. func toUsageGenerate(r api.GenerateResponse) Usage {
  267. return Usage{
  268. PromptTokens: r.PromptEvalCount,
  269. CompletionTokens: r.EvalCount,
  270. TotalTokens: r.PromptEvalCount + r.EvalCount,
  271. }
  272. }
  273. func toCompletion(id string, r api.GenerateResponse) Completion {
  274. return Completion{
  275. Id: id,
  276. Object: "text_completion",
  277. Created: r.CreatedAt.Unix(),
  278. Model: r.Model,
  279. SystemFingerprint: "fp_ollama",
  280. Choices: []CompleteChunkChoice{{
  281. Text: r.Response,
  282. Index: 0,
  283. FinishReason: func(reason string) *string {
  284. if len(reason) > 0 {
  285. return &reason
  286. }
  287. return nil
  288. }(r.DoneReason),
  289. }},
  290. Usage: toUsageGenerate(r),
  291. }
  292. }
  293. func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
  294. return CompletionChunk{
  295. Id: id,
  296. Object: "text_completion",
  297. Created: time.Now().Unix(),
  298. Model: r.Model,
  299. SystemFingerprint: "fp_ollama",
  300. Choices: []CompleteChunkChoice{{
  301. Text: r.Response,
  302. Index: 0,
  303. FinishReason: func(reason string) *string {
  304. if len(reason) > 0 {
  305. return &reason
  306. }
  307. return nil
  308. }(r.DoneReason),
  309. }},
  310. }
  311. }
  312. func toListCompletion(r api.ListResponse) ListCompletion {
  313. var data []Model
  314. for _, m := range r.Models {
  315. data = append(data, Model{
  316. Id: m.Name,
  317. Object: "model",
  318. Created: m.ModifiedAt.Unix(),
  319. OwnedBy: model.ParseName(m.Name).Namespace,
  320. })
  321. }
  322. return ListCompletion{
  323. Object: "list",
  324. Data: data,
  325. }
  326. }
  327. func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
  328. if r.Embeddings != nil {
  329. var data []Embedding
  330. for i, e := range r.Embeddings {
  331. data = append(data, Embedding{
  332. Object: "embedding",
  333. Embedding: e,
  334. Index: i,
  335. })
  336. }
  337. return EmbeddingList{
  338. Object: "list",
  339. Data: data,
  340. Model: model,
  341. Usage: EmbeddingUsage{
  342. PromptTokens: r.PromptEvalCount,
  343. TotalTokens: r.PromptEvalCount,
  344. },
  345. }
  346. }
  347. return EmbeddingList{}
  348. }
  349. func toModel(r api.ShowResponse, m string) Model {
  350. return Model{
  351. Id: m,
  352. Object: "model",
  353. Created: r.ModifiedAt.Unix(),
  354. OwnedBy: model.ParseName(m).Namespace,
  355. }
  356. }
  357. func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
  358. var messages []api.Message
  359. for _, msg := range r.Messages {
  360. switch content := msg.Content.(type) {
  361. case string:
  362. messages = append(messages, api.Message{Role: msg.Role, Content: content})
  363. case []any:
  364. for _, c := range content {
  365. data, ok := c.(map[string]any)
  366. if !ok {
  367. return nil, errors.New("invalid message format")
  368. }
  369. switch data["type"] {
  370. case "text":
  371. text, ok := data["text"].(string)
  372. if !ok {
  373. return nil, errors.New("invalid message format")
  374. }
  375. messages = append(messages, api.Message{Role: msg.Role, Content: text})
  376. case "image_url":
  377. var url string
  378. if urlMap, ok := data["image_url"].(map[string]any); ok {
  379. if url, ok = urlMap["url"].(string); !ok {
  380. return nil, errors.New("invalid message format")
  381. }
  382. } else {
  383. if url, ok = data["image_url"].(string); !ok {
  384. return nil, errors.New("invalid message format")
  385. }
  386. }
  387. types := []string{"jpeg", "jpg", "png"}
  388. valid := false
  389. for _, t := range types {
  390. prefix := "data:image/" + t + ";base64,"
  391. if strings.HasPrefix(url, prefix) {
  392. url = strings.TrimPrefix(url, prefix)
  393. valid = true
  394. break
  395. }
  396. }
  397. if !valid {
  398. return nil, errors.New("invalid image input")
  399. }
  400. img, err := base64.StdEncoding.DecodeString(url)
  401. if err != nil {
  402. return nil, errors.New("invalid message format")
  403. }
  404. messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
  405. default:
  406. return nil, errors.New("invalid message format")
  407. }
  408. }
  409. default:
  410. if msg.ToolCalls == nil {
  411. return nil, fmt.Errorf("invalid message content type: %T", content)
  412. }
  413. toolCalls := make([]api.ToolCall, len(msg.ToolCalls))
  414. for i, tc := range msg.ToolCalls {
  415. toolCalls[i].Function.Name = tc.Function.Name
  416. err := json.Unmarshal([]byte(tc.Function.Arguments), &toolCalls[i].Function.Arguments)
  417. if err != nil {
  418. return nil, errors.New("invalid tool call arguments")
  419. }
  420. }
  421. messages = append(messages, api.Message{Role: msg.Role, ToolCalls: toolCalls})
  422. }
  423. }
  424. options := make(map[string]interface{})
  425. switch stop := r.Stop.(type) {
  426. case string:
  427. options["stop"] = []string{stop}
  428. case []any:
  429. var stops []string
  430. for _, s := range stop {
  431. if str, ok := s.(string); ok {
  432. stops = append(stops, str)
  433. }
  434. }
  435. options["stop"] = stops
  436. }
  437. if r.MaxTokens != nil {
  438. options["num_predict"] = *r.MaxTokens
  439. }
  440. if r.Temperature != nil {
  441. options["temperature"] = *r.Temperature
  442. } else {
  443. options["temperature"] = 1.0
  444. }
  445. if r.Seed != nil {
  446. options["seed"] = *r.Seed
  447. }
  448. if r.FrequencyPenalty != nil {
  449. options["frequency_penalty"] = *r.FrequencyPenalty
  450. }
  451. if r.PresencePenalty != nil {
  452. options["presence_penalty"] = *r.PresencePenalty
  453. }
  454. if r.TopP != nil {
  455. options["top_p"] = *r.TopP
  456. } else {
  457. options["top_p"] = 1.0
  458. }
  459. var format json.RawMessage
  460. if r.ResponseFormat != nil {
  461. switch strings.ToLower(strings.TrimSpace(r.ResponseFormat.Type)) {
  462. // Support the old "json_object" type for OpenAI compatibility
  463. case "json_object":
  464. format = json.RawMessage(`"json"`)
  465. case "json_schema":
  466. if r.ResponseFormat.JsonSchema != nil {
  467. schema, err := json.Marshal(r.ResponseFormat.JsonSchema.Schema)
  468. if err != nil {
  469. return nil, fmt.Errorf("failed to marshal json schema: %w", err)
  470. }
  471. format = schema
  472. }
  473. }
  474. }
  475. return &api.ChatRequest{
  476. Model: r.Model,
  477. Messages: messages,
  478. Format: format,
  479. Options: options,
  480. Stream: &r.Stream,
  481. Tools: r.Tools,
  482. }, nil
  483. }
  484. func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
  485. options := make(map[string]any)
  486. switch stop := r.Stop.(type) {
  487. case string:
  488. options["stop"] = []string{stop}
  489. case []any:
  490. var stops []string
  491. for _, s := range stop {
  492. if str, ok := s.(string); ok {
  493. stops = append(stops, str)
  494. } else {
  495. return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", s)
  496. }
  497. }
  498. options["stop"] = stops
  499. }
  500. if r.MaxTokens != nil {
  501. options["num_predict"] = *r.MaxTokens
  502. }
  503. if r.Temperature != nil {
  504. options["temperature"] = *r.Temperature
  505. } else {
  506. options["temperature"] = 1.0
  507. }
  508. if r.Seed != nil {
  509. options["seed"] = *r.Seed
  510. }
  511. options["frequency_penalty"] = r.FrequencyPenalty
  512. options["presence_penalty"] = r.PresencePenalty
  513. if r.TopP != 0.0 {
  514. options["top_p"] = r.TopP
  515. } else {
  516. options["top_p"] = 1.0
  517. }
  518. return api.GenerateRequest{
  519. Model: r.Model,
  520. Prompt: r.Prompt,
  521. Options: options,
  522. Stream: &r.Stream,
  523. Suffix: r.Suffix,
  524. }, nil
  525. }
// BaseWriter wraps gin's ResponseWriter and provides the shared error
// rendering used by the OpenAI-compatible writers below.
type BaseWriter struct {
	gin.ResponseWriter
}

// ChatWriter rewrites chat responses into the OpenAI chat completion
// format, emitting SSE chunks when stream is set.
type ChatWriter struct {
	stream        bool
	streamOptions *StreamOptions
	id            string // completion id reported to the client
	BaseWriter
}

// CompleteWriter rewrites generate responses into the OpenAI completions
// format, emitting SSE chunks when stream is set.
type CompleteWriter struct {
	stream        bool
	streamOptions *StreamOptions
	id            string // completion id reported to the client
	BaseWriter
}

// ListWriter rewrites model list responses into the OpenAI list format.
type ListWriter struct {
	BaseWriter
}

// RetrieveWriter rewrites show-model responses into the OpenAI model format.
type RetrieveWriter struct {
	BaseWriter
	model string
}

// EmbedWriter rewrites embedding responses into the OpenAI embedding format.
type EmbedWriter struct {
	BaseWriter
	model string
}
  552. func (w *BaseWriter) writeError(data []byte) (int, error) {
  553. var serr api.StatusError
  554. err := json.Unmarshal(data, &serr)
  555. if err != nil {
  556. return 0, err
  557. }
  558. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  559. err = json.NewEncoder(w.ResponseWriter).Encode(NewError(http.StatusInternalServerError, serr.Error()))
  560. if err != nil {
  561. return 0, err
  562. }
  563. return len(data), nil
  564. }
  565. func (w *ChatWriter) writeResponse(data []byte) (int, error) {
  566. var chatResponse api.ChatResponse
  567. err := json.Unmarshal(data, &chatResponse)
  568. if err != nil {
  569. return 0, err
  570. }
  571. // chat chunk
  572. if w.stream {
  573. c := toChunk(w.id, chatResponse)
  574. if w.streamOptions != nil && w.streamOptions.IncludeUsage {
  575. c.Usage = &ChunkUsage{}
  576. }
  577. d, err := json.Marshal(c)
  578. if err != nil {
  579. return 0, err
  580. }
  581. w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
  582. _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
  583. if err != nil {
  584. return 0, err
  585. }
  586. if chatResponse.Done {
  587. if w.streamOptions != nil && w.streamOptions.IncludeUsage {
  588. u := toUsage(chatResponse)
  589. d, err := json.Marshal(ChatCompletionChunk{Choices: []ChunkChoice{}, Usage: &u})
  590. if err != nil {
  591. return 0, err
  592. }
  593. _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
  594. if err != nil {
  595. return 0, err
  596. }
  597. }
  598. _, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
  599. if err != nil {
  600. return 0, err
  601. }
  602. }
  603. return len(data), nil
  604. }
  605. // chat completion
  606. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  607. err = json.NewEncoder(w.ResponseWriter).Encode(toChatCompletion(w.id, chatResponse))
  608. if err != nil {
  609. return 0, err
  610. }
  611. return len(data), nil
  612. }
  613. func (w *ChatWriter) Write(data []byte) (int, error) {
  614. code := w.ResponseWriter.Status()
  615. if code != http.StatusOK {
  616. return w.writeError(data)
  617. }
  618. return w.writeResponse(data)
  619. }
  620. func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
  621. var generateResponse api.GenerateResponse
  622. err := json.Unmarshal(data, &generateResponse)
  623. if err != nil {
  624. return 0, err
  625. }
  626. // completion chunk
  627. if w.stream {
  628. c := toCompleteChunk(w.id, generateResponse)
  629. if w.streamOptions != nil && w.streamOptions.IncludeUsage {
  630. c.Usage = &ChunkUsage{}
  631. }
  632. d, err := json.Marshal(c)
  633. if err != nil {
  634. return 0, err
  635. }
  636. w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
  637. _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
  638. if err != nil {
  639. return 0, err
  640. }
  641. if generateResponse.Done {
  642. if w.streamOptions != nil && w.streamOptions.IncludeUsage {
  643. u := toUsageGenerate(generateResponse)
  644. d, err := json.Marshal(CompletionChunk{Choices: []CompleteChunkChoice{}, Usage: &u})
  645. if err != nil {
  646. return 0, err
  647. }
  648. _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
  649. if err != nil {
  650. return 0, err
  651. }
  652. }
  653. _, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
  654. if err != nil {
  655. return 0, err
  656. }
  657. }
  658. return len(data), nil
  659. }
  660. // completion
  661. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  662. err = json.NewEncoder(w.ResponseWriter).Encode(toCompletion(w.id, generateResponse))
  663. if err != nil {
  664. return 0, err
  665. }
  666. return len(data), nil
  667. }
  668. func (w *CompleteWriter) Write(data []byte) (int, error) {
  669. code := w.ResponseWriter.Status()
  670. if code != http.StatusOK {
  671. return w.writeError(data)
  672. }
  673. return w.writeResponse(data)
  674. }
  675. func (w *ListWriter) writeResponse(data []byte) (int, error) {
  676. var listResponse api.ListResponse
  677. err := json.Unmarshal(data, &listResponse)
  678. if err != nil {
  679. return 0, err
  680. }
  681. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  682. err = json.NewEncoder(w.ResponseWriter).Encode(toListCompletion(listResponse))
  683. if err != nil {
  684. return 0, err
  685. }
  686. return len(data), nil
  687. }
  688. func (w *ListWriter) Write(data []byte) (int, error) {
  689. code := w.ResponseWriter.Status()
  690. if code != http.StatusOK {
  691. return w.writeError(data)
  692. }
  693. return w.writeResponse(data)
  694. }
  695. func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
  696. var showResponse api.ShowResponse
  697. err := json.Unmarshal(data, &showResponse)
  698. if err != nil {
  699. return 0, err
  700. }
  701. // retrieve completion
  702. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  703. err = json.NewEncoder(w.ResponseWriter).Encode(toModel(showResponse, w.model))
  704. if err != nil {
  705. return 0, err
  706. }
  707. return len(data), nil
  708. }
  709. func (w *RetrieveWriter) Write(data []byte) (int, error) {
  710. code := w.ResponseWriter.Status()
  711. if code != http.StatusOK {
  712. return w.writeError(data)
  713. }
  714. return w.writeResponse(data)
  715. }
  716. func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
  717. var embedResponse api.EmbedResponse
  718. err := json.Unmarshal(data, &embedResponse)
  719. if err != nil {
  720. return 0, err
  721. }
  722. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  723. err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse))
  724. if err != nil {
  725. return 0, err
  726. }
  727. return len(data), nil
  728. }
  729. func (w *EmbedWriter) Write(data []byte) (int, error) {
  730. code := w.ResponseWriter.Status()
  731. if code != http.StatusOK {
  732. return w.writeError(data)
  733. }
  734. return w.writeResponse(data)
  735. }
  736. func ListMiddleware() gin.HandlerFunc {
  737. return func(c *gin.Context) {
  738. w := &ListWriter{
  739. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  740. }
  741. c.Writer = w
  742. c.Next()
  743. }
  744. }
  745. func RetrieveMiddleware() gin.HandlerFunc {
  746. return func(c *gin.Context) {
  747. var b bytes.Buffer
  748. if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
  749. c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
  750. return
  751. }
  752. c.Request.Body = io.NopCloser(&b)
  753. // response writer
  754. w := &RetrieveWriter{
  755. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  756. model: c.Param("model"),
  757. }
  758. c.Writer = w
  759. c.Next()
  760. }
  761. }
  762. func CompletionsMiddleware() gin.HandlerFunc {
  763. return func(c *gin.Context) {
  764. var req CompletionRequest
  765. err := c.ShouldBindJSON(&req)
  766. if err != nil {
  767. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
  768. return
  769. }
  770. var b bytes.Buffer
  771. genReq, err := fromCompleteRequest(req)
  772. if err != nil {
  773. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
  774. return
  775. }
  776. if err := json.NewEncoder(&b).Encode(genReq); err != nil {
  777. c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
  778. return
  779. }
  780. c.Request.Body = io.NopCloser(&b)
  781. w := &CompleteWriter{
  782. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  783. stream: req.Stream,
  784. id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
  785. streamOptions: req.StreamOptions,
  786. }
  787. c.Writer = w
  788. c.Next()
  789. }
  790. }
  791. func EmbeddingsMiddleware() gin.HandlerFunc {
  792. return func(c *gin.Context) {
  793. var req EmbedRequest
  794. err := c.ShouldBindJSON(&req)
  795. if err != nil {
  796. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
  797. return
  798. }
  799. if req.Input == "" {
  800. req.Input = []string{""}
  801. }
  802. if req.Input == nil {
  803. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
  804. return
  805. }
  806. if v, ok := req.Input.([]any); ok && len(v) == 0 {
  807. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
  808. return
  809. }
  810. var b bytes.Buffer
  811. if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input}); err != nil {
  812. c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
  813. return
  814. }
  815. c.Request.Body = io.NopCloser(&b)
  816. w := &EmbedWriter{
  817. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  818. model: req.Model,
  819. }
  820. c.Writer = w
  821. c.Next()
  822. }
  823. }
  824. func ChatMiddleware() gin.HandlerFunc {
  825. return func(c *gin.Context) {
  826. var req ChatCompletionRequest
  827. err := c.ShouldBindJSON(&req)
  828. if err != nil {
  829. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
  830. return
  831. }
  832. if len(req.Messages) == 0 {
  833. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "[] is too short - 'messages'"))
  834. return
  835. }
  836. var b bytes.Buffer
  837. chatReq, err := fromChatRequest(req)
  838. if err != nil {
  839. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
  840. return
  841. }
  842. if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
  843. c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
  844. return
  845. }
  846. c.Request.Body = io.NopCloser(&b)
  847. w := &ChatWriter{
  848. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  849. stream: req.Stream,
  850. id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
  851. streamOptions: req.StreamOptions,
  852. }
  853. c.Writer = w
  854. c.Next()
  855. }
  856. }