openai.go 24 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998
  1. // openai package provides middleware for partial compatibility with the OpenAI REST API
  2. package openai
  3. import (
  4. "bytes"
  5. "encoding/base64"
  6. "encoding/json"
  7. "errors"
  8. "fmt"
  9. "io"
  10. "log/slog"
  11. "math/rand"
  12. "net/http"
  13. "strings"
  14. "time"
  15. "github.com/gin-gonic/gin"
  16. "github.com/ollama/ollama/api"
  17. "github.com/ollama/ollama/types/model"
  18. )
// Error is the OpenAI-compatible error body carried inside ErrorResponse.
type Error struct {
	Message string      `json:"message"`
	Type    string      `json:"type"`
	Param   interface{} `json:"param"`
	Code    *string     `json:"code"`
}

// ErrorResponse is the envelope the OpenAI API uses for error replies.
type ErrorResponse struct {
	Error Error `json:"error"`
}

// Message is a single chat message. Content is either a plain string or
// a list of typed content parts (see fromChatRequest for the accepted shapes).
type Message struct {
	Role      string     `json:"role"`
	Content   any        `json:"content"`
	ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}

// Choice is one completed message of a non-streaming chat completion.
type Choice struct {
	Index        int     `json:"index"`
	Message      Message `json:"message"`
	FinishReason *string `json:"finish_reason"`
}

// ChunkChoice is one delta of a streaming chat completion.
type ChunkChoice struct {
	Index        int     `json:"index"`
	Delta        Message `json:"delta"`
	FinishReason *string `json:"finish_reason"`
}

// CompleteChunkChoice is one choice of a legacy text completion,
// used for both streaming chunks and full responses.
type CompleteChunkChoice struct {
	Text         string  `json:"text"`
	Index        int     `json:"index"`
	FinishReason *string `json:"finish_reason"`
}

// Usage reports token accounting for a completion.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

// ResponseFormat selects the output format; "json_object" and
// "json_schema" are recognized (see fromChatRequest).
type ResponseFormat struct {
	Type       string      `json:"type"`
	JsonSchema *JsonSchema `json:"json_schema,omitempty"`
}

// JsonSchema carries the raw schema used when ResponseFormat.Type is "json_schema".
type JsonSchema struct {
	Schema json.RawMessage `json:"schema"`
}

// EmbedRequest is an OpenAI embeddings request; Input may be a string
// or a list (validated in EmbeddingsMiddleware).
type EmbedRequest struct {
	Input any    `json:"input"`
	Model string `json:"model"`
}

// StreamOptions controls streaming extras such as the final usage-only chunk.
type StreamOptions struct {
	IncludeUsage bool `json:"include_usage"`
}

// ChatCompletionRequest is an OpenAI /v1/chat/completions request body.
type ChatCompletionRequest struct {
	Model               string         `json:"model"`
	Messages            []Message      `json:"messages"`
	Stream              bool           `json:"stream"`
	StreamOptions       *StreamOptions `json:"stream_options"`
	MaxCompletionTokens *int           `json:"max_completion_tokens"`
	// Deprecated: Use [ChatCompletionRequest.MaxCompletionTokens]
	MaxTokens        *int            `json:"max_tokens"`
	Seed             *int            `json:"seed"`
	Stop             any             `json:"stop"`
	Temperature      *float64        `json:"temperature"`
	FrequencyPenalty *float64        `json:"frequency_penalty"`
	PresencePenalty  *float64        `json:"presence_penalty"`
	TopP             *float64        `json:"top_p"`
	ResponseFormat   *ResponseFormat `json:"response_format"`
	Tools            []api.Tool      `json:"tools"`
	ContextWindow    *int            `json:"context_window"`
}

// ChatCompletion is a non-streaming chat completion response.
type ChatCompletion struct {
	Id                string   `json:"id"`
	Object            string   `json:"object"`
	Created           int64    `json:"created"`
	Model             string   `json:"model"`
	SystemFingerprint string   `json:"system_fingerprint"`
	Choices           []Choice `json:"choices"`
	Usage             Usage    `json:"usage,omitempty"`
}

// ChatCompletionChunk is one SSE event of a streaming chat completion.
type ChatCompletionChunk struct {
	Id                string        `json:"id"`
	Object            string        `json:"object"`
	Created           int64         `json:"created"`
	Model             string        `json:"model"`
	SystemFingerprint string        `json:"system_fingerprint"`
	Choices           []ChunkChoice `json:"choices"`
	Usage             *Usage        `json:"usage,omitempty"`
}

// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
// CompletionRequest is an OpenAI legacy /v1/completions request body.
type CompletionRequest struct {
	Model            string         `json:"model"`
	Prompt           string         `json:"prompt"`
	FrequencyPenalty float32        `json:"frequency_penalty"`
	MaxTokens        *int           `json:"max_tokens"`
	PresencePenalty  float32        `json:"presence_penalty"`
	Seed             *int           `json:"seed"`
	Stop             any            `json:"stop"`
	Stream           bool           `json:"stream"`
	StreamOptions    *StreamOptions `json:"stream_options"`
	Temperature      *float32       `json:"temperature"`
	TopP             float32        `json:"top_p"`
	Suffix           string         `json:"suffix"`
}

// Completion is a non-streaming legacy text completion response.
type Completion struct {
	Id                string                `json:"id"`
	Object            string                `json:"object"`
	Created           int64                 `json:"created"`
	Model             string                `json:"model"`
	SystemFingerprint string                `json:"system_fingerprint"`
	Choices           []CompleteChunkChoice `json:"choices"`
	Usage             Usage                 `json:"usage,omitempty"`
}

// CompletionChunk is one SSE event of a streaming legacy completion.
type CompletionChunk struct {
	Id                string                `json:"id"`
	Object            string                `json:"object"`
	Created           int64                 `json:"created"`
	Choices           []CompleteChunkChoice `json:"choices"`
	Model             string                `json:"model"`
	SystemFingerprint string                `json:"system_fingerprint"`
	Usage             *Usage                `json:"usage,omitempty"`
}

// ToolCall is an OpenAI tool invocation; Arguments is a JSON-encoded string.
type ToolCall struct {
	ID       string `json:"id"`
	Index    int    `json:"index"`
	Type     string `json:"type"`
	Function struct {
		Name      string `json:"name"`
		Arguments string `json:"arguments"`
	} `json:"function"`
}

// Model describes one model in list/retrieve responses.
type Model struct {
	Id      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	OwnedBy string `json:"owned_by"`
}

// Embedding is a single embedding vector and its position in the input.
type Embedding struct {
	Object    string    `json:"object"`
	Embedding []float32 `json:"embedding"`
	Index     int       `json:"index"`
}

// ListCompletion is the OpenAI model-list envelope.
type ListCompletion struct {
	Object string  `json:"object"`
	Data   []Model `json:"data"`
}

// EmbeddingList is the OpenAI embeddings response envelope.
type EmbeddingList struct {
	Object string         `json:"object"`
	Data   []Embedding    `json:"data"`
	Model  string         `json:"model"`
	Usage  EmbeddingUsage `json:"usage,omitempty"`
}

// EmbeddingUsage reports token accounting for an embeddings request.
type EmbeddingUsage struct {
	PromptTokens int `json:"prompt_tokens"`
	TotalTokens  int `json:"total_tokens"`
}
  170. func NewError(code int, message string) ErrorResponse {
  171. var etype string
  172. switch code {
  173. case http.StatusBadRequest:
  174. etype = "invalid_request_error"
  175. case http.StatusNotFound:
  176. etype = "not_found_error"
  177. default:
  178. etype = "api_error"
  179. }
  180. return ErrorResponse{Error{Type: etype, Message: message}}
  181. }
  182. func toUsage(r api.ChatResponse) Usage {
  183. return Usage{
  184. PromptTokens: r.PromptEvalCount,
  185. CompletionTokens: r.EvalCount,
  186. TotalTokens: r.PromptEvalCount + r.EvalCount,
  187. }
  188. }
  189. func toolCallId() string {
  190. const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
  191. b := make([]byte, 8)
  192. for i := range b {
  193. b[i] = letterBytes[rand.Intn(len(letterBytes))]
  194. }
  195. return "call_" + strings.ToLower(string(b))
  196. }
  197. func toToolCalls(tc []api.ToolCall) []ToolCall {
  198. toolCalls := make([]ToolCall, len(tc))
  199. for i, tc := range tc {
  200. toolCalls[i].ID = toolCallId()
  201. toolCalls[i].Type = "function"
  202. toolCalls[i].Function.Name = tc.Function.Name
  203. toolCalls[i].Index = tc.Function.Index
  204. args, err := json.Marshal(tc.Function.Arguments)
  205. if err != nil {
  206. slog.Error("could not marshall function arguments to json", "error", err)
  207. continue
  208. }
  209. toolCalls[i].Function.Arguments = string(args)
  210. }
  211. return toolCalls
  212. }
  213. func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
  214. toolCalls := toToolCalls(r.Message.ToolCalls)
  215. return ChatCompletion{
  216. Id: id,
  217. Object: "chat.completion",
  218. Created: r.CreatedAt.Unix(),
  219. Model: r.Model,
  220. SystemFingerprint: "fp_ollama",
  221. Choices: []Choice{{
  222. Index: 0,
  223. Message: Message{Role: r.Message.Role, Content: r.Message.Content, ToolCalls: toolCalls},
  224. FinishReason: func(reason string) *string {
  225. if len(toolCalls) > 0 {
  226. reason = "tool_calls"
  227. }
  228. if len(reason) > 0 {
  229. return &reason
  230. }
  231. return nil
  232. }(r.DoneReason),
  233. }},
  234. Usage: toUsage(r),
  235. }
  236. }
  237. func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
  238. toolCalls := toToolCalls(r.Message.ToolCalls)
  239. return ChatCompletionChunk{
  240. Id: id,
  241. Object: "chat.completion.chunk",
  242. Created: time.Now().Unix(),
  243. Model: r.Model,
  244. SystemFingerprint: "fp_ollama",
  245. Choices: []ChunkChoice{{
  246. Index: 0,
  247. Delta: Message{Role: "assistant", Content: r.Message.Content, ToolCalls: toolCalls},
  248. FinishReason: func(reason string) *string {
  249. if len(reason) > 0 {
  250. return &reason
  251. }
  252. return nil
  253. }(r.DoneReason),
  254. }},
  255. }
  256. }
  257. func toUsageGenerate(r api.GenerateResponse) Usage {
  258. return Usage{
  259. PromptTokens: r.PromptEvalCount,
  260. CompletionTokens: r.EvalCount,
  261. TotalTokens: r.PromptEvalCount + r.EvalCount,
  262. }
  263. }
  264. func toCompletion(id string, r api.GenerateResponse) Completion {
  265. return Completion{
  266. Id: id,
  267. Object: "text_completion",
  268. Created: r.CreatedAt.Unix(),
  269. Model: r.Model,
  270. SystemFingerprint: "fp_ollama",
  271. Choices: []CompleteChunkChoice{{
  272. Text: r.Response,
  273. Index: 0,
  274. FinishReason: func(reason string) *string {
  275. if len(reason) > 0 {
  276. return &reason
  277. }
  278. return nil
  279. }(r.DoneReason),
  280. }},
  281. Usage: toUsageGenerate(r),
  282. }
  283. }
  284. func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
  285. return CompletionChunk{
  286. Id: id,
  287. Object: "text_completion",
  288. Created: time.Now().Unix(),
  289. Model: r.Model,
  290. SystemFingerprint: "fp_ollama",
  291. Choices: []CompleteChunkChoice{{
  292. Text: r.Response,
  293. Index: 0,
  294. FinishReason: func(reason string) *string {
  295. if len(reason) > 0 {
  296. return &reason
  297. }
  298. return nil
  299. }(r.DoneReason),
  300. }},
  301. }
  302. }
  303. func toListCompletion(r api.ListResponse) ListCompletion {
  304. var data []Model
  305. for _, m := range r.Models {
  306. data = append(data, Model{
  307. Id: m.Name,
  308. Object: "model",
  309. Created: m.ModifiedAt.Unix(),
  310. OwnedBy: model.ParseName(m.Name).Namespace,
  311. })
  312. }
  313. return ListCompletion{
  314. Object: "list",
  315. Data: data,
  316. }
  317. }
  318. func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
  319. if r.Embeddings != nil {
  320. var data []Embedding
  321. for i, e := range r.Embeddings {
  322. data = append(data, Embedding{
  323. Object: "embedding",
  324. Embedding: e,
  325. Index: i,
  326. })
  327. }
  328. return EmbeddingList{
  329. Object: "list",
  330. Data: data,
  331. Model: model,
  332. Usage: EmbeddingUsage{
  333. PromptTokens: r.PromptEvalCount,
  334. TotalTokens: r.PromptEvalCount,
  335. },
  336. }
  337. }
  338. return EmbeddingList{}
  339. }
  340. func toModel(r api.ShowResponse, m string) Model {
  341. return Model{
  342. Id: m,
  343. Object: "model",
  344. Created: r.ModifiedAt.Unix(),
  345. OwnedBy: model.ParseName(m).Namespace,
  346. }
  347. }
  348. func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
  349. var messages []api.Message
  350. for _, msg := range r.Messages {
  351. switch content := msg.Content.(type) {
  352. case string:
  353. messages = append(messages, api.Message{Role: msg.Role, Content: content})
  354. case []any:
  355. for _, c := range content {
  356. data, ok := c.(map[string]any)
  357. if !ok {
  358. return nil, errors.New("invalid message format")
  359. }
  360. switch data["type"] {
  361. case "text":
  362. text, ok := data["text"].(string)
  363. if !ok {
  364. return nil, errors.New("invalid message format")
  365. }
  366. messages = append(messages, api.Message{Role: msg.Role, Content: text})
  367. case "image_url":
  368. var url string
  369. if urlMap, ok := data["image_url"].(map[string]any); ok {
  370. if url, ok = urlMap["url"].(string); !ok {
  371. return nil, errors.New("invalid message format")
  372. }
  373. } else {
  374. if url, ok = data["image_url"].(string); !ok {
  375. return nil, errors.New("invalid message format")
  376. }
  377. }
  378. types := []string{"jpeg", "jpg", "png"}
  379. valid := false
  380. for _, t := range types {
  381. prefix := "data:image/" + t + ";base64,"
  382. if strings.HasPrefix(url, prefix) {
  383. url = strings.TrimPrefix(url, prefix)
  384. valid = true
  385. break
  386. }
  387. }
  388. if !valid {
  389. return nil, errors.New("invalid image input")
  390. }
  391. img, err := base64.StdEncoding.DecodeString(url)
  392. if err != nil {
  393. return nil, errors.New("invalid message format")
  394. }
  395. messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
  396. default:
  397. return nil, errors.New("invalid message format")
  398. }
  399. }
  400. default:
  401. if msg.ToolCalls == nil {
  402. return nil, fmt.Errorf("invalid message content type: %T", content)
  403. }
  404. toolCalls := make([]api.ToolCall, len(msg.ToolCalls))
  405. for i, tc := range msg.ToolCalls {
  406. toolCalls[i].Function.Name = tc.Function.Name
  407. err := json.Unmarshal([]byte(tc.Function.Arguments), &toolCalls[i].Function.Arguments)
  408. if err != nil {
  409. return nil, errors.New("invalid tool call arguments")
  410. }
  411. }
  412. messages = append(messages, api.Message{Role: msg.Role, ToolCalls: toolCalls})
  413. }
  414. }
  415. options := make(map[string]interface{})
  416. switch stop := r.Stop.(type) {
  417. case string:
  418. options["stop"] = []string{stop}
  419. case []any:
  420. var stops []string
  421. for _, s := range stop {
  422. if str, ok := s.(string); ok {
  423. stops = append(stops, str)
  424. }
  425. }
  426. options["stop"] = stops
  427. }
  428. if r.ContextWindow != nil {
  429. slog.Info("context_window in if", "context_window", *r.ContextWindow)
  430. options["num_ctx"] = *r.ContextWindow
  431. }
  432. // Deprecated: MaxTokens is deprecated, use MaxCompletionTokens instead
  433. if r.MaxTokens != nil {
  434. r.MaxCompletionTokens = r.MaxTokens
  435. }
  436. if r.MaxCompletionTokens != nil {
  437. options["num_predict"] = *r.MaxCompletionTokens
  438. }
  439. if r.Temperature != nil {
  440. options["temperature"] = *r.Temperature
  441. } else {
  442. options["temperature"] = 1.0
  443. }
  444. if r.Seed != nil {
  445. options["seed"] = *r.Seed
  446. }
  447. if r.FrequencyPenalty != nil {
  448. options["frequency_penalty"] = *r.FrequencyPenalty
  449. }
  450. if r.PresencePenalty != nil {
  451. options["presence_penalty"] = *r.PresencePenalty
  452. }
  453. if r.TopP != nil {
  454. options["top_p"] = *r.TopP
  455. } else {
  456. options["top_p"] = 1.0
  457. }
  458. var format json.RawMessage
  459. if r.ResponseFormat != nil {
  460. switch strings.ToLower(strings.TrimSpace(r.ResponseFormat.Type)) {
  461. // Support the old "json_object" type for OpenAI compatibility
  462. case "json_object":
  463. format = json.RawMessage(`"json"`)
  464. case "json_schema":
  465. if r.ResponseFormat.JsonSchema != nil {
  466. format = r.ResponseFormat.JsonSchema.Schema
  467. }
  468. }
  469. }
  470. return &api.ChatRequest{
  471. Model: r.Model,
  472. Messages: messages,
  473. Format: format,
  474. Options: options,
  475. Stream: &r.Stream,
  476. Tools: r.Tools,
  477. }, nil
  478. }
  479. func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
  480. options := make(map[string]any)
  481. switch stop := r.Stop.(type) {
  482. case string:
  483. options["stop"] = []string{stop}
  484. case []any:
  485. var stops []string
  486. for _, s := range stop {
  487. if str, ok := s.(string); ok {
  488. stops = append(stops, str)
  489. } else {
  490. return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", s)
  491. }
  492. }
  493. options["stop"] = stops
  494. }
  495. if r.MaxTokens != nil {
  496. options["num_predict"] = *r.MaxTokens
  497. }
  498. if r.Temperature != nil {
  499. options["temperature"] = *r.Temperature
  500. } else {
  501. options["temperature"] = 1.0
  502. }
  503. if r.Seed != nil {
  504. options["seed"] = *r.Seed
  505. }
  506. options["frequency_penalty"] = r.FrequencyPenalty
  507. options["presence_penalty"] = r.PresencePenalty
  508. if r.TopP != 0.0 {
  509. options["top_p"] = r.TopP
  510. } else {
  511. options["top_p"] = 1.0
  512. }
  513. return api.GenerateRequest{
  514. Model: r.Model,
  515. Prompt: r.Prompt,
  516. Options: options,
  517. Stream: &r.Stream,
  518. Suffix: r.Suffix,
  519. }, nil
  520. }
// BaseWriter wraps gin's ResponseWriter and provides the shared error
// translation used by all OpenAI-compat response writers below.
type BaseWriter struct {
	gin.ResponseWriter
}

// ChatWriter rewrites ollama chat responses into OpenAI chat completions.
type ChatWriter struct {
	stream        bool
	streamOptions *StreamOptions
	id            string
	BaseWriter
}

// CompleteWriter rewrites ollama generate responses into OpenAI legacy
// text completions.
type CompleteWriter struct {
	stream        bool
	streamOptions *StreamOptions
	id            string
	BaseWriter
}

// ListWriter rewrites ollama list responses into OpenAI model lists.
type ListWriter struct {
	BaseWriter
}

// RetrieveWriter rewrites ollama show responses into OpenAI model objects.
type RetrieveWriter struct {
	BaseWriter
	model string
}

// EmbedWriter rewrites ollama embed responses into OpenAI embedding lists.
type EmbedWriter struct {
	BaseWriter
	model string
}
  547. func (w *BaseWriter) writeError(data []byte) (int, error) {
  548. var serr api.StatusError
  549. err := json.Unmarshal(data, &serr)
  550. if err != nil {
  551. return 0, err
  552. }
  553. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  554. err = json.NewEncoder(w.ResponseWriter).Encode(NewError(http.StatusInternalServerError, serr.Error()))
  555. if err != nil {
  556. return 0, err
  557. }
  558. return len(data), nil
  559. }
// writeResponse converts one ollama chat response payload into either an
// SSE chunk (when streaming) or a single JSON chat completion. It is
// called once per upstream write, so each call emits at most one data
// chunk plus, on the final response, the optional usage chunk and the
// [DONE] sentinel.
func (w *ChatWriter) writeResponse(data []byte) (int, error) {
	var chatResponse api.ChatResponse
	err := json.Unmarshal(data, &chatResponse)
	if err != nil {
		return 0, err
	}

	// chat chunk
	if w.stream {
		c := toChunk(w.id, chatResponse)
		d, err := json.Marshal(c)
		if err != nil {
			return 0, err
		}
		w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
		_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
		if err != nil {
			return 0, err
		}
		if chatResponse.Done {
			// Per the stream_options contract, the usage-only chunk has an
			// empty choices list and precedes the [DONE] sentinel.
			if w.streamOptions != nil && w.streamOptions.IncludeUsage {
				u := toUsage(chatResponse)
				c.Usage = &u
				c.Choices = []ChunkChoice{}
				d, err := json.Marshal(c)
				if err != nil {
					return 0, err
				}
				_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
				if err != nil {
					return 0, err
				}
			}
			// OpenAI streams terminate with a literal [DONE] event.
			_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
			if err != nil {
				return 0, err
			}
		}
		return len(data), nil
	}

	// chat completion
	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w.ResponseWriter).Encode(toChatCompletion(w.id, chatResponse))
	if err != nil {
		return 0, err
	}
	return len(data), nil
}
  607. func (w *ChatWriter) Write(data []byte) (int, error) {
  608. code := w.ResponseWriter.Status()
  609. if code != http.StatusOK {
  610. return w.writeError(data)
  611. }
  612. return w.writeResponse(data)
  613. }
// writeResponse converts one ollama generate response payload into either
// an SSE completion chunk (when streaming) or a single JSON completion.
// Called once per upstream write.
func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
	var generateResponse api.GenerateResponse
	err := json.Unmarshal(data, &generateResponse)
	if err != nil {
		return 0, err
	}

	// completion chunk
	if w.stream {
		c := toCompleteChunk(w.id, generateResponse)
		// With include_usage, intermediate chunks carry an empty (zeroed)
		// usage object; the real totals go in the final usage-only chunk.
		if w.streamOptions != nil && w.streamOptions.IncludeUsage {
			c.Usage = &Usage{}
		}
		d, err := json.Marshal(c)
		if err != nil {
			return 0, err
		}
		w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
		_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
		if err != nil {
			return 0, err
		}
		if generateResponse.Done {
			// Final usage-only chunk (choices emptied), then the [DONE] sentinel.
			if w.streamOptions != nil && w.streamOptions.IncludeUsage {
				u := toUsageGenerate(generateResponse)
				c.Usage = &u
				c.Choices = []CompleteChunkChoice{}
				d, err := json.Marshal(c)
				if err != nil {
					return 0, err
				}
				_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
				if err != nil {
					return 0, err
				}
			}
			_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
			if err != nil {
				return 0, err
			}
		}
		return len(data), nil
	}

	// completion
	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w.ResponseWriter).Encode(toCompletion(w.id, generateResponse))
	if err != nil {
		return 0, err
	}
	return len(data), nil
}
  664. func (w *CompleteWriter) Write(data []byte) (int, error) {
  665. code := w.ResponseWriter.Status()
  666. if code != http.StatusOK {
  667. return w.writeError(data)
  668. }
  669. return w.writeResponse(data)
  670. }
  671. func (w *ListWriter) writeResponse(data []byte) (int, error) {
  672. var listResponse api.ListResponse
  673. err := json.Unmarshal(data, &listResponse)
  674. if err != nil {
  675. return 0, err
  676. }
  677. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  678. err = json.NewEncoder(w.ResponseWriter).Encode(toListCompletion(listResponse))
  679. if err != nil {
  680. return 0, err
  681. }
  682. return len(data), nil
  683. }
  684. func (w *ListWriter) Write(data []byte) (int, error) {
  685. code := w.ResponseWriter.Status()
  686. if code != http.StatusOK {
  687. return w.writeError(data)
  688. }
  689. return w.writeResponse(data)
  690. }
  691. func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
  692. var showResponse api.ShowResponse
  693. err := json.Unmarshal(data, &showResponse)
  694. if err != nil {
  695. return 0, err
  696. }
  697. // retrieve completion
  698. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  699. err = json.NewEncoder(w.ResponseWriter).Encode(toModel(showResponse, w.model))
  700. if err != nil {
  701. return 0, err
  702. }
  703. return len(data), nil
  704. }
  705. func (w *RetrieveWriter) Write(data []byte) (int, error) {
  706. code := w.ResponseWriter.Status()
  707. if code != http.StatusOK {
  708. return w.writeError(data)
  709. }
  710. return w.writeResponse(data)
  711. }
  712. func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
  713. var embedResponse api.EmbedResponse
  714. err := json.Unmarshal(data, &embedResponse)
  715. if err != nil {
  716. return 0, err
  717. }
  718. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  719. err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse))
  720. if err != nil {
  721. return 0, err
  722. }
  723. return len(data), nil
  724. }
  725. func (w *EmbedWriter) Write(data []byte) (int, error) {
  726. code := w.ResponseWriter.Status()
  727. if code != http.StatusOK {
  728. return w.writeError(data)
  729. }
  730. return w.writeResponse(data)
  731. }
  732. func ListMiddleware() gin.HandlerFunc {
  733. return func(c *gin.Context) {
  734. w := &ListWriter{
  735. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  736. }
  737. c.Writer = w
  738. c.Next()
  739. }
  740. }
  741. func RetrieveMiddleware() gin.HandlerFunc {
  742. return func(c *gin.Context) {
  743. var b bytes.Buffer
  744. if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
  745. c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
  746. return
  747. }
  748. c.Request.Body = io.NopCloser(&b)
  749. // response writer
  750. w := &RetrieveWriter{
  751. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  752. model: c.Param("model"),
  753. }
  754. c.Writer = w
  755. c.Next()
  756. }
  757. }
  758. func CompletionsMiddleware() gin.HandlerFunc {
  759. return func(c *gin.Context) {
  760. var req CompletionRequest
  761. err := c.ShouldBindJSON(&req)
  762. if err != nil {
  763. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
  764. return
  765. }
  766. var b bytes.Buffer
  767. genReq, err := fromCompleteRequest(req)
  768. if err != nil {
  769. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
  770. return
  771. }
  772. if err := json.NewEncoder(&b).Encode(genReq); err != nil {
  773. c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
  774. return
  775. }
  776. c.Request.Body = io.NopCloser(&b)
  777. w := &CompleteWriter{
  778. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  779. stream: req.Stream,
  780. id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
  781. streamOptions: req.StreamOptions,
  782. }
  783. c.Writer = w
  784. c.Next()
  785. }
  786. }
  787. func EmbeddingsMiddleware() gin.HandlerFunc {
  788. return func(c *gin.Context) {
  789. var req EmbedRequest
  790. err := c.ShouldBindJSON(&req)
  791. if err != nil {
  792. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
  793. return
  794. }
  795. if req.Input == "" {
  796. req.Input = []string{""}
  797. }
  798. if req.Input == nil {
  799. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
  800. return
  801. }
  802. if v, ok := req.Input.([]any); ok && len(v) == 0 {
  803. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
  804. return
  805. }
  806. var b bytes.Buffer
  807. if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input}); err != nil {
  808. c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
  809. return
  810. }
  811. c.Request.Body = io.NopCloser(&b)
  812. w := &EmbedWriter{
  813. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  814. model: req.Model,
  815. }
  816. c.Writer = w
  817. c.Next()
  818. }
  819. }
  820. func ChatMiddleware() gin.HandlerFunc {
  821. return func(c *gin.Context) {
  822. var req ChatCompletionRequest
  823. err := c.ShouldBindJSON(&req)
  824. if err != nil {
  825. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
  826. return
  827. }
  828. if len(req.Messages) == 0 {
  829. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "[] is too short - 'messages'"))
  830. return
  831. }
  832. var b bytes.Buffer
  833. chatReq, err := fromChatRequest(req)
  834. if err != nil {
  835. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
  836. return
  837. }
  838. slog.Info("num_ctx", "num_ctx", chatReq.Options["num_ctx"])
  839. if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
  840. c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
  841. return
  842. }
  843. c.Request.Body = io.NopCloser(&b)
  844. w := &ChatWriter{
  845. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  846. stream: req.Stream,
  847. id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
  848. streamOptions: req.StreamOptions,
  849. }
  850. c.Writer = w
  851. c.Next()
  852. }
  853. }