openai.go 24 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002
  1. // openai package provides middleware for partial compatibility with the OpenAI REST API
  2. package openai
  3. import (
  4. "bytes"
  5. "encoding/base64"
  6. "encoding/json"
  7. "errors"
  8. "fmt"
  9. "io"
  10. "log/slog"
  11. "math/rand"
  12. "net/http"
  13. "strings"
  14. "time"
  15. "github.com/gin-gonic/gin"
  16. "github.com/ollama/ollama/api"
  17. "github.com/ollama/ollama/types/model"
  18. )
// Error mirrors the OpenAI API error object embedded in ErrorResponse.
type Error struct {
	Message string      `json:"message"`
	Type    string      `json:"type"`
	Param   interface{} `json:"param"`
	Code    *string     `json:"code"`
}

// ErrorResponse is the top-level error envelope returned to clients.
type ErrorResponse struct {
	Error Error `json:"error"`
}

// Message is a single chat message. Content is `any` because the OpenAI API
// accepts either a plain string or a list of typed content parts.
type Message struct {
	Role      string     `json:"role"`
	Content   any        `json:"content"`
	ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}

// Choice is one completed choice of a non-streaming chat completion.
type Choice struct {
	Index        int     `json:"index"`
	Message      Message `json:"message"`
	FinishReason *string `json:"finish_reason"`
}

// ChunkChoice is one incremental (delta) choice of a streaming chat completion.
type ChunkChoice struct {
	Index        int     `json:"index"`
	Delta        Message `json:"delta"`
	FinishReason *string `json:"finish_reason"`
}

// CompleteChunkChoice is one choice of a text completion (streaming or not).
type CompleteChunkChoice struct {
	Text         string  `json:"text"`
	Index        int     `json:"index"`
	FinishReason *string `json:"finish_reason"`
}

// Usage reports prompt/completion token accounting for a request.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

// ResponseFormat selects the requested output format (e.g. "json_object" or
// "json_schema").
type ResponseFormat struct {
	Type       string      `json:"type"`
	JsonSchema *JsonSchema `json:"json_schema,omitempty"`
}

// JsonSchema carries a raw JSON schema used to constrain model output.
type JsonSchema struct {
	Schema json.RawMessage `json:"schema"`
}

// EmbedRequest is an OpenAI-style embeddings request. Input may be a string
// or a list of inputs (handled downstream as `any`).
type EmbedRequest struct {
	Input any    `json:"input"`
	Model string `json:"model"`
}

// StreamOptions holds optional streaming behavior flags.
type StreamOptions struct {
	// IncludeUsage asks for a final usage-only chunk in streaming responses.
	IncludeUsage bool `json:"include_usage"`
}
// ChatCompletionRequest models the body of POST /v1/chat/completions.
type ChatCompletionRequest struct {
	Model               string         `json:"model"`
	Messages            []Message      `json:"messages"`
	Stream              bool           `json:"stream"`
	StreamOptions       *StreamOptions `json:"stream_options"`
	MaxCompletionTokens *int           `json:"max_completion_tokens"`
	// Deprecated: use MaxCompletionTokens instead.
	MaxTokens *int `json:"max_tokens" deprecated:"use max_completion_tokens instead"`
	Seed      *int `json:"seed"`
	// Stop may be a single string or a list of strings.
	Stop             any             `json:"stop"`
	Temperature      *float64        `json:"temperature"`
	FrequencyPenalty *float64        `json:"frequency_penalty"`
	PresencePenalty  *float64        `json:"presence_penalty"`
	TopP             *float64        `json:"top_p"`
	ResponseFormat   *ResponseFormat `json:"response_format"`
	Tools            []api.Tool      `json:"tools"`
	NumCtx           *int            `json:"num_ctx"`
}

// ChatCompletion is a non-streaming chat completion response.
type ChatCompletion struct {
	Id                string   `json:"id"`
	Object            string   `json:"object"`
	Created           int64    `json:"created"`
	Model             string   `json:"model"`
	SystemFingerprint string   `json:"system_fingerprint"`
	Choices           []Choice `json:"choices"`
	Usage             Usage    `json:"usage,omitempty"`
}

// ChatCompletionChunk is one SSE chunk of a streaming chat completion.
type ChatCompletionChunk struct {
	Id                string        `json:"id"`
	Object            string        `json:"object"`
	Created           int64         `json:"created"`
	Model             string        `json:"model"`
	SystemFingerprint string        `json:"system_fingerprint"`
	Choices           []ChunkChoice `json:"choices"`
	Usage             *Usage        `json:"usage,omitempty"`
}

// CompletionRequest models the body of POST /v1/completions.
// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
type CompletionRequest struct {
	Model            string         `json:"model"`
	Prompt           string         `json:"prompt"`
	FrequencyPenalty float32        `json:"frequency_penalty"`
	MaxTokens        *int           `json:"max_tokens"`
	PresencePenalty  float32        `json:"presence_penalty"`
	Seed             *int           `json:"seed"`
	Stop             any            `json:"stop"`
	Stream           bool           `json:"stream"`
	StreamOptions    *StreamOptions `json:"stream_options"`
	Temperature      *float32       `json:"temperature"`
	TopP             float32        `json:"top_p"`
	Suffix           string         `json:"suffix"`
}

// Completion is a non-streaming text completion response.
type Completion struct {
	Id                string                `json:"id"`
	Object            string                `json:"object"`
	Created           int64                 `json:"created"`
	Model             string                `json:"model"`
	SystemFingerprint string                `json:"system_fingerprint"`
	Choices           []CompleteChunkChoice `json:"choices"`
	Usage             Usage                 `json:"usage,omitempty"`
}

// CompletionChunk is one SSE chunk of a streaming text completion.
type CompletionChunk struct {
	Id                string                `json:"id"`
	Object            string                `json:"object"`
	Created           int64                 `json:"created"`
	Choices           []CompleteChunkChoice `json:"choices"`
	Model             string                `json:"model"`
	SystemFingerprint string                `json:"system_fingerprint"`
	Usage             *Usage                `json:"usage,omitempty"`
}
// ToolCall is a model-requested function invocation in OpenAI format.
// Function.Arguments is a JSON-encoded string (not a nested object).
type ToolCall struct {
	ID       string `json:"id"`
	Index    int    `json:"index"`
	Type     string `json:"type"`
	Function struct {
		Name      string `json:"name"`
		Arguments string `json:"arguments"`
	} `json:"function"`
}

// Model describes one entry of the /v1/models listing.
type Model struct {
	Id      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	OwnedBy string `json:"owned_by"`
}

// Embedding is a single embedding vector with the index of its input.
type Embedding struct {
	Object    string    `json:"object"`
	Embedding []float32 `json:"embedding"`
	Index     int       `json:"index"`
}

// ListCompletion is the /v1/models response envelope.
type ListCompletion struct {
	Object string  `json:"object"`
	Data   []Model `json:"data"`
}

// EmbeddingList is the /v1/embeddings response envelope.
type EmbeddingList struct {
	Object string         `json:"object"`
	Data   []Embedding    `json:"data"`
	Model  string         `json:"model"`
	Usage  EmbeddingUsage `json:"usage,omitempty"`
}

// EmbeddingUsage reports token accounting for an embeddings request.
type EmbeddingUsage struct {
	PromptTokens int `json:"prompt_tokens"`
	TotalTokens  int `json:"total_tokens"`
}
  169. func NewError(code int, message string) ErrorResponse {
  170. var etype string
  171. switch code {
  172. case http.StatusBadRequest:
  173. etype = "invalid_request_error"
  174. case http.StatusNotFound:
  175. etype = "not_found_error"
  176. default:
  177. etype = "api_error"
  178. }
  179. return ErrorResponse{Error{Type: etype, Message: message}}
  180. }
  181. func toUsage(r api.ChatResponse) Usage {
  182. return Usage{
  183. PromptTokens: r.PromptEvalCount,
  184. CompletionTokens: r.EvalCount,
  185. TotalTokens: r.PromptEvalCount + r.EvalCount,
  186. }
  187. }
  188. func toolCallId() string {
  189. const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
  190. b := make([]byte, 8)
  191. for i := range b {
  192. b[i] = letterBytes[rand.Intn(len(letterBytes))]
  193. }
  194. return "call_" + strings.ToLower(string(b))
  195. }
  196. func toToolCalls(tc []api.ToolCall) []ToolCall {
  197. toolCalls := make([]ToolCall, len(tc))
  198. for i, tc := range tc {
  199. toolCalls[i].ID = toolCallId()
  200. toolCalls[i].Type = "function"
  201. toolCalls[i].Function.Name = tc.Function.Name
  202. toolCalls[i].Index = tc.Function.Index
  203. args, err := json.Marshal(tc.Function.Arguments)
  204. if err != nil {
  205. slog.Error("could not marshall function arguments to json", "error", err)
  206. continue
  207. }
  208. toolCalls[i].Function.Arguments = string(args)
  209. }
  210. return toolCalls
  211. }
  212. func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
  213. toolCalls := toToolCalls(r.Message.ToolCalls)
  214. return ChatCompletion{
  215. Id: id,
  216. Object: "chat.completion",
  217. Created: r.CreatedAt.Unix(),
  218. Model: r.Model,
  219. SystemFingerprint: "fp_ollama",
  220. Choices: []Choice{{
  221. Index: 0,
  222. Message: Message{Role: r.Message.Role, Content: r.Message.Content, ToolCalls: toolCalls},
  223. FinishReason: func(reason string) *string {
  224. if len(toolCalls) > 0 {
  225. reason = "tool_calls"
  226. }
  227. if len(reason) > 0 {
  228. return &reason
  229. }
  230. return nil
  231. }(r.DoneReason),
  232. }},
  233. Usage: toUsage(r),
  234. }
  235. }
  236. func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
  237. toolCalls := toToolCalls(r.Message.ToolCalls)
  238. return ChatCompletionChunk{
  239. Id: id,
  240. Object: "chat.completion.chunk",
  241. Created: time.Now().Unix(),
  242. Model: r.Model,
  243. SystemFingerprint: "fp_ollama",
  244. Choices: []ChunkChoice{{
  245. Index: 0,
  246. Delta: Message{Role: "assistant", Content: r.Message.Content, ToolCalls: toolCalls},
  247. FinishReason: func(reason string) *string {
  248. if len(reason) > 0 {
  249. return &reason
  250. }
  251. return nil
  252. }(r.DoneReason),
  253. }},
  254. }
  255. }
  256. func toUsageGenerate(r api.GenerateResponse) Usage {
  257. return Usage{
  258. PromptTokens: r.PromptEvalCount,
  259. CompletionTokens: r.EvalCount,
  260. TotalTokens: r.PromptEvalCount + r.EvalCount,
  261. }
  262. }
  263. func toCompletion(id string, r api.GenerateResponse) Completion {
  264. return Completion{
  265. Id: id,
  266. Object: "text_completion",
  267. Created: r.CreatedAt.Unix(),
  268. Model: r.Model,
  269. SystemFingerprint: "fp_ollama",
  270. Choices: []CompleteChunkChoice{{
  271. Text: r.Response,
  272. Index: 0,
  273. FinishReason: func(reason string) *string {
  274. if len(reason) > 0 {
  275. return &reason
  276. }
  277. return nil
  278. }(r.DoneReason),
  279. }},
  280. Usage: toUsageGenerate(r),
  281. }
  282. }
  283. func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
  284. return CompletionChunk{
  285. Id: id,
  286. Object: "text_completion",
  287. Created: time.Now().Unix(),
  288. Model: r.Model,
  289. SystemFingerprint: "fp_ollama",
  290. Choices: []CompleteChunkChoice{{
  291. Text: r.Response,
  292. Index: 0,
  293. FinishReason: func(reason string) *string {
  294. if len(reason) > 0 {
  295. return &reason
  296. }
  297. return nil
  298. }(r.DoneReason),
  299. }},
  300. }
  301. }
  302. func toListCompletion(r api.ListResponse) ListCompletion {
  303. var data []Model
  304. for _, m := range r.Models {
  305. data = append(data, Model{
  306. Id: m.Name,
  307. Object: "model",
  308. Created: m.ModifiedAt.Unix(),
  309. OwnedBy: model.ParseName(m.Name).Namespace,
  310. })
  311. }
  312. return ListCompletion{
  313. Object: "list",
  314. Data: data,
  315. }
  316. }
  317. func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
  318. if r.Embeddings != nil {
  319. var data []Embedding
  320. for i, e := range r.Embeddings {
  321. data = append(data, Embedding{
  322. Object: "embedding",
  323. Embedding: e,
  324. Index: i,
  325. })
  326. }
  327. return EmbeddingList{
  328. Object: "list",
  329. Data: data,
  330. Model: model,
  331. Usage: EmbeddingUsage{
  332. PromptTokens: r.PromptEvalCount,
  333. TotalTokens: r.PromptEvalCount,
  334. },
  335. }
  336. }
  337. return EmbeddingList{}
  338. }
  339. func toModel(r api.ShowResponse, m string) Model {
  340. return Model{
  341. Id: m,
  342. Object: "model",
  343. Created: r.ModifiedAt.Unix(),
  344. OwnedBy: model.ParseName(m).Namespace,
  345. }
  346. }
  347. func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
  348. var messages []api.Message
  349. for _, msg := range r.Messages {
  350. switch content := msg.Content.(type) {
  351. case string:
  352. messages = append(messages, api.Message{Role: msg.Role, Content: content})
  353. case []any:
  354. for _, c := range content {
  355. data, ok := c.(map[string]any)
  356. if !ok {
  357. return nil, errors.New("invalid message format")
  358. }
  359. switch data["type"] {
  360. case "text":
  361. text, ok := data["text"].(string)
  362. if !ok {
  363. return nil, errors.New("invalid message format")
  364. }
  365. messages = append(messages, api.Message{Role: msg.Role, Content: text})
  366. case "image_url":
  367. var url string
  368. if urlMap, ok := data["image_url"].(map[string]any); ok {
  369. if url, ok = urlMap["url"].(string); !ok {
  370. return nil, errors.New("invalid message format")
  371. }
  372. } else {
  373. if url, ok = data["image_url"].(string); !ok {
  374. return nil, errors.New("invalid message format")
  375. }
  376. }
  377. types := []string{"jpeg", "jpg", "png"}
  378. valid := false
  379. for _, t := range types {
  380. prefix := "data:image/" + t + ";base64,"
  381. if strings.HasPrefix(url, prefix) {
  382. url = strings.TrimPrefix(url, prefix)
  383. valid = true
  384. break
  385. }
  386. }
  387. if !valid {
  388. return nil, errors.New("invalid image input")
  389. }
  390. img, err := base64.StdEncoding.DecodeString(url)
  391. if err != nil {
  392. return nil, errors.New("invalid message format")
  393. }
  394. messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
  395. default:
  396. return nil, errors.New("invalid message format")
  397. }
  398. }
  399. default:
  400. if msg.ToolCalls == nil {
  401. return nil, fmt.Errorf("invalid message content type: %T", content)
  402. }
  403. toolCalls := make([]api.ToolCall, len(msg.ToolCalls))
  404. for i, tc := range msg.ToolCalls {
  405. toolCalls[i].Function.Name = tc.Function.Name
  406. err := json.Unmarshal([]byte(tc.Function.Arguments), &toolCalls[i].Function.Arguments)
  407. if err != nil {
  408. return nil, errors.New("invalid tool call arguments")
  409. }
  410. }
  411. messages = append(messages, api.Message{Role: msg.Role, ToolCalls: toolCalls})
  412. }
  413. }
  414. options := make(map[string]interface{})
  415. switch stop := r.Stop.(type) {
  416. case string:
  417. options["stop"] = []string{stop}
  418. case []any:
  419. var stops []string
  420. for _, s := range stop {
  421. if str, ok := s.(string); ok {
  422. stops = append(stops, str)
  423. }
  424. }
  425. options["stop"] = stops
  426. }
  427. // Deprecated: MaxTokens is deprecated, use MaxCompletionTokens instead
  428. if r.MaxTokens != nil {
  429. r.MaxCompletionTokens = r.MaxTokens
  430. }
  431. if r.NumCtx != nil {
  432. options["num_ctx"] = *r.NumCtx
  433. }
  434. DEFAULT_NUM_CTX := 2048
  435. // set num_ctx to max_completion_tokens if it's greater than num_ctx
  436. if r.MaxCompletionTokens != nil {
  437. options["num_predict"] = *r.MaxCompletionTokens
  438. if r.NumCtx != nil && *r.MaxCompletionTokens > *r.NumCtx {
  439. options["num_ctx"] = *r.MaxCompletionTokens
  440. } else if *r.MaxCompletionTokens > DEFAULT_NUM_CTX {
  441. options["num_ctx"] = *r.MaxCompletionTokens
  442. }
  443. }
  444. if r.Temperature != nil {
  445. options["temperature"] = *r.Temperature
  446. } else {
  447. options["temperature"] = 1.0
  448. }
  449. if r.Seed != nil {
  450. options["seed"] = *r.Seed
  451. }
  452. if r.FrequencyPenalty != nil {
  453. options["frequency_penalty"] = *r.FrequencyPenalty
  454. }
  455. if r.PresencePenalty != nil {
  456. options["presence_penalty"] = *r.PresencePenalty
  457. }
  458. if r.TopP != nil {
  459. options["top_p"] = *r.TopP
  460. } else {
  461. options["top_p"] = 1.0
  462. }
  463. var format json.RawMessage
  464. if r.ResponseFormat != nil {
  465. switch strings.ToLower(strings.TrimSpace(r.ResponseFormat.Type)) {
  466. // Support the old "json_object" type for OpenAI compatibility
  467. case "json_object":
  468. format = json.RawMessage(`"json"`)
  469. case "json_schema":
  470. if r.ResponseFormat.JsonSchema != nil {
  471. format = r.ResponseFormat.JsonSchema.Schema
  472. }
  473. }
  474. }
  475. return &api.ChatRequest{
  476. Model: r.Model,
  477. Messages: messages,
  478. Format: format,
  479. Options: options,
  480. Stream: &r.Stream,
  481. Tools: r.Tools,
  482. }, nil
  483. }
  484. func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
  485. options := make(map[string]any)
  486. switch stop := r.Stop.(type) {
  487. case string:
  488. options["stop"] = []string{stop}
  489. case []any:
  490. var stops []string
  491. for _, s := range stop {
  492. if str, ok := s.(string); ok {
  493. stops = append(stops, str)
  494. } else {
  495. return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", s)
  496. }
  497. }
  498. options["stop"] = stops
  499. }
  500. if r.MaxTokens != nil {
  501. options["num_predict"] = *r.MaxTokens
  502. }
  503. if r.Temperature != nil {
  504. options["temperature"] = *r.Temperature
  505. } else {
  506. options["temperature"] = 1.0
  507. }
  508. if r.Seed != nil {
  509. options["seed"] = *r.Seed
  510. }
  511. options["frequency_penalty"] = r.FrequencyPenalty
  512. options["presence_penalty"] = r.PresencePenalty
  513. if r.TopP != 0.0 {
  514. options["top_p"] = r.TopP
  515. } else {
  516. options["top_p"] = 1.0
  517. }
  518. return api.GenerateRequest{
  519. Model: r.Model,
  520. Prompt: r.Prompt,
  521. Options: options,
  522. Stream: &r.Stream,
  523. Suffix: r.Suffix,
  524. }, nil
  525. }
// BaseWriter wraps gin's ResponseWriter and provides shared error rendering
// for the response-translating writers below.
type BaseWriter struct {
	gin.ResponseWriter
}

// ChatWriter translates Ollama chat responses into OpenAI
// chat-completion (or chunked SSE) form as they are written.
type ChatWriter struct {
	stream        bool
	streamOptions *StreamOptions
	id            string
	BaseWriter
}

// CompleteWriter translates Ollama generate responses into OpenAI
// text-completion (or chunked SSE) form as they are written.
type CompleteWriter struct {
	stream        bool
	streamOptions *StreamOptions
	id            string
	BaseWriter
}

// ListWriter translates the model list response into OpenAI form.
type ListWriter struct {
	BaseWriter
}

// RetrieveWriter translates a show-model response for a single model.
type RetrieveWriter struct {
	BaseWriter
	model string
}

// EmbedWriter translates embedding responses into OpenAI form.
type EmbedWriter struct {
	BaseWriter
	model string
}
  552. func (w *BaseWriter) writeError(data []byte) (int, error) {
  553. var serr api.StatusError
  554. err := json.Unmarshal(data, &serr)
  555. if err != nil {
  556. return 0, err
  557. }
  558. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  559. err = json.NewEncoder(w.ResponseWriter).Encode(NewError(http.StatusInternalServerError, serr.Error()))
  560. if err != nil {
  561. return 0, err
  562. }
  563. return len(data), nil
  564. }
// writeResponse translates one api.ChatResponse JSON payload from upstream
// into the OpenAI chat wire format: an SSE "data:" chunk when streaming, or
// a single chat.completion JSON object otherwise. Returns len(data) so the
// caller sees the input as fully consumed.
func (w *ChatWriter) writeResponse(data []byte) (int, error) {
	var chatResponse api.ChatResponse
	err := json.Unmarshal(data, &chatResponse)
	if err != nil {
		return 0, err
	}

	// chat chunk
	if w.stream {
		c := toChunk(w.id, chatResponse)
		d, err := json.Marshal(c)
		if err != nil {
			return 0, err
		}

		w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
		_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
		if err != nil {
			return 0, err
		}

		if chatResponse.Done {
			// On the final chunk: optionally emit a usage-only chunk (empty
			// choices) when stream_options.include_usage was requested, then
			// the OpenAI [DONE] sentinel.
			if w.streamOptions != nil && w.streamOptions.IncludeUsage {
				u := toUsage(chatResponse)
				c.Usage = &u
				c.Choices = []ChunkChoice{}
				d, err := json.Marshal(c)
				if err != nil {
					return 0, err
				}
				_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
				if err != nil {
					return 0, err
				}
			}
			_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
			if err != nil {
				return 0, err
			}
		}

		return len(data), nil
	}

	// chat completion
	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w.ResponseWriter).Encode(toChatCompletion(w.id, chatResponse))
	if err != nil {
		return 0, err
	}

	return len(data), nil
}
  612. func (w *ChatWriter) Write(data []byte) (int, error) {
  613. code := w.ResponseWriter.Status()
  614. if code != http.StatusOK {
  615. return w.writeError(data)
  616. }
  617. return w.writeResponse(data)
  618. }
// writeResponse translates one api.GenerateResponse JSON payload from
// upstream into the OpenAI text-completion wire format: an SSE "data:" chunk
// when streaming, or a single text_completion JSON object otherwise.
// Returns len(data) so the caller sees the input as fully consumed.
func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
	var generateResponse api.GenerateResponse
	err := json.Unmarshal(data, &generateResponse)
	if err != nil {
		return 0, err
	}

	// completion chunk
	if w.stream {
		c := toCompleteChunk(w.id, generateResponse)
		// When usage reporting was requested, intermediate chunks carry an
		// empty (zero) usage object; the real totals go on the final chunk.
		if w.streamOptions != nil && w.streamOptions.IncludeUsage {
			c.Usage = &Usage{}
		}
		d, err := json.Marshal(c)
		if err != nil {
			return 0, err
		}

		w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
		_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
		if err != nil {
			return 0, err
		}

		if generateResponse.Done {
			// On the final chunk: optionally emit a usage-only chunk (empty
			// choices), then the OpenAI [DONE] sentinel.
			if w.streamOptions != nil && w.streamOptions.IncludeUsage {
				u := toUsageGenerate(generateResponse)
				c.Usage = &u
				c.Choices = []CompleteChunkChoice{}
				d, err := json.Marshal(c)
				if err != nil {
					return 0, err
				}
				_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
				if err != nil {
					return 0, err
				}
			}
			_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
			if err != nil {
				return 0, err
			}
		}

		return len(data), nil
	}

	// completion
	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w.ResponseWriter).Encode(toCompletion(w.id, generateResponse))
	if err != nil {
		return 0, err
	}

	return len(data), nil
}
  669. func (w *CompleteWriter) Write(data []byte) (int, error) {
  670. code := w.ResponseWriter.Status()
  671. if code != http.StatusOK {
  672. return w.writeError(data)
  673. }
  674. return w.writeResponse(data)
  675. }
  676. func (w *ListWriter) writeResponse(data []byte) (int, error) {
  677. var listResponse api.ListResponse
  678. err := json.Unmarshal(data, &listResponse)
  679. if err != nil {
  680. return 0, err
  681. }
  682. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  683. err = json.NewEncoder(w.ResponseWriter).Encode(toListCompletion(listResponse))
  684. if err != nil {
  685. return 0, err
  686. }
  687. return len(data), nil
  688. }
  689. func (w *ListWriter) Write(data []byte) (int, error) {
  690. code := w.ResponseWriter.Status()
  691. if code != http.StatusOK {
  692. return w.writeError(data)
  693. }
  694. return w.writeResponse(data)
  695. }
  696. func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
  697. var showResponse api.ShowResponse
  698. err := json.Unmarshal(data, &showResponse)
  699. if err != nil {
  700. return 0, err
  701. }
  702. // retrieve completion
  703. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  704. err = json.NewEncoder(w.ResponseWriter).Encode(toModel(showResponse, w.model))
  705. if err != nil {
  706. return 0, err
  707. }
  708. return len(data), nil
  709. }
  710. func (w *RetrieveWriter) Write(data []byte) (int, error) {
  711. code := w.ResponseWriter.Status()
  712. if code != http.StatusOK {
  713. return w.writeError(data)
  714. }
  715. return w.writeResponse(data)
  716. }
  717. func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
  718. var embedResponse api.EmbedResponse
  719. err := json.Unmarshal(data, &embedResponse)
  720. if err != nil {
  721. return 0, err
  722. }
  723. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  724. err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse))
  725. if err != nil {
  726. return 0, err
  727. }
  728. return len(data), nil
  729. }
  730. func (w *EmbedWriter) Write(data []byte) (int, error) {
  731. code := w.ResponseWriter.Status()
  732. if code != http.StatusOK {
  733. return w.writeError(data)
  734. }
  735. return w.writeResponse(data)
  736. }
  737. func ListMiddleware() gin.HandlerFunc {
  738. return func(c *gin.Context) {
  739. w := &ListWriter{
  740. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  741. }
  742. c.Writer = w
  743. c.Next()
  744. }
  745. }
  746. func RetrieveMiddleware() gin.HandlerFunc {
  747. return func(c *gin.Context) {
  748. var b bytes.Buffer
  749. if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
  750. c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
  751. return
  752. }
  753. c.Request.Body = io.NopCloser(&b)
  754. // response writer
  755. w := &RetrieveWriter{
  756. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  757. model: c.Param("model"),
  758. }
  759. c.Writer = w
  760. c.Next()
  761. }
  762. }
  763. func CompletionsMiddleware() gin.HandlerFunc {
  764. return func(c *gin.Context) {
  765. var req CompletionRequest
  766. err := c.ShouldBindJSON(&req)
  767. if err != nil {
  768. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
  769. return
  770. }
  771. var b bytes.Buffer
  772. genReq, err := fromCompleteRequest(req)
  773. if err != nil {
  774. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
  775. return
  776. }
  777. if err := json.NewEncoder(&b).Encode(genReq); err != nil {
  778. c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
  779. return
  780. }
  781. c.Request.Body = io.NopCloser(&b)
  782. w := &CompleteWriter{
  783. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  784. stream: req.Stream,
  785. id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
  786. streamOptions: req.StreamOptions,
  787. }
  788. c.Writer = w
  789. c.Next()
  790. }
  791. }
  792. func EmbeddingsMiddleware() gin.HandlerFunc {
  793. return func(c *gin.Context) {
  794. var req EmbedRequest
  795. err := c.ShouldBindJSON(&req)
  796. if err != nil {
  797. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
  798. return
  799. }
  800. if req.Input == "" {
  801. req.Input = []string{""}
  802. }
  803. if req.Input == nil {
  804. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
  805. return
  806. }
  807. if v, ok := req.Input.([]any); ok && len(v) == 0 {
  808. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
  809. return
  810. }
  811. var b bytes.Buffer
  812. if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input}); err != nil {
  813. c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
  814. return
  815. }
  816. c.Request.Body = io.NopCloser(&b)
  817. w := &EmbedWriter{
  818. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  819. model: req.Model,
  820. }
  821. c.Writer = w
  822. c.Next()
  823. }
  824. }
  825. func ChatMiddleware() gin.HandlerFunc {
  826. return func(c *gin.Context) {
  827. var req ChatCompletionRequest
  828. err := c.ShouldBindJSON(&req)
  829. if err != nil {
  830. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
  831. return
  832. }
  833. if len(req.Messages) == 0 {
  834. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "[] is too short - 'messages'"))
  835. return
  836. }
  837. var b bytes.Buffer
  838. chatReq, err := fromChatRequest(req)
  839. if err != nil {
  840. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
  841. return
  842. }
  843. if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
  844. c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
  845. return
  846. }
  847. c.Request.Body = io.NopCloser(&b)
  848. w := &ChatWriter{
  849. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  850. stream: req.Stream,
  851. id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
  852. streamOptions: req.StreamOptions,
  853. }
  854. c.Writer = w
  855. c.Next()
  856. }
  857. }