openai.go 23 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945
  1. // openai package provides middleware for partial compatibility with the OpenAI REST API
  2. package openai
  3. import (
  4. "bytes"
  5. "encoding/base64"
  6. "encoding/json"
  7. "errors"
  8. "fmt"
  9. "io"
  10. "log/slog"
  11. "math/rand"
  12. "net/http"
  13. "strings"
  14. "time"
  15. "github.com/gin-gonic/gin"
  16. "github.com/ollama/ollama/api"
  17. "github.com/ollama/ollama/types/model"
  18. )
// Error mirrors the OpenAI API error object.
type Error struct {
	Message string      `json:"message"`
	Type    string      `json:"type"`
	Param   interface{} `json:"param"`
	Code    *string     `json:"code"`
}

// ErrorResponse wraps an Error as the top-level body, matching the
// {"error": {...}} shape returned by the OpenAI API.
type ErrorResponse struct {
	Error Error `json:"error"`
}

// Message is a single chat message. Content is `any` because the
// OpenAI API accepts either a plain string or an array of typed
// content parts (text / image_url); see fromChatRequest.
type Message struct {
	Role      string     `json:"role"`
	Content   any        `json:"content"`
	ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}

// Choice is one completed chat choice in a non-streaming response.
type Choice struct {
	Index        int     `json:"index"`
	Message      Message `json:"message"`
	FinishReason *string `json:"finish_reason"`
}

// ChunkChoice is one incremental chat choice in a streaming response.
type ChunkChoice struct {
	Index        int     `json:"index"`
	Delta        Message `json:"delta"`
	FinishReason *string `json:"finish_reason"`
}

// CompleteChunkChoice is one choice of a text (non-chat) completion.
type CompleteChunkChoice struct {
	Text         string  `json:"text"`
	Index        int     `json:"index"`
	FinishReason *string `json:"finish_reason"`
}

// Usage reports token accounting for a completion.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

// ResponseFormat selects the output format: "json_object" or
// "json_schema" (with the schema attached).
type ResponseFormat struct {
	Type       string      `json:"type"`
	JsonSchema *JsonSchema `json:"json_schema,omitempty"`
}

// JsonSchema carries a user-supplied JSON schema for structured output.
type JsonSchema struct {
	Schema map[string]any `json:"schema"`
}
// EmbedRequest is the OpenAI embeddings request body. Input may be a
// string or a list; validation happens in EmbeddingsMiddleware.
type EmbedRequest struct {
	Input any    `json:"input"`
	Model string `json:"model"`
}

// ChatCompletionRequest is the OpenAI /v1/chat/completions request body.
type ChatCompletionRequest struct {
	Model            string          `json:"model"`
	Messages         []Message       `json:"messages"`
	Stream           bool            `json:"stream"`
	MaxTokens        *int            `json:"max_tokens"`
	Seed             *int            `json:"seed"`
	Stop             any             `json:"stop"` // string or []string
	Temperature      *float64        `json:"temperature"`
	FrequencyPenalty *float64        `json:"frequency_penalty"`
	PresencePenalty  *float64        `json:"presence_penalty"`
	TopP             *float64        `json:"top_p"`
	ResponseFormat   *ResponseFormat `json:"response_format"`
	Tools            []api.Tool      `json:"tools"`
}

// ChatCompletion is the non-streaming chat completion response.
type ChatCompletion struct {
	Id                string   `json:"id"`
	Object            string   `json:"object"`
	Created           int64    `json:"created"`
	Model             string   `json:"model"`
	SystemFingerprint string   `json:"system_fingerprint"`
	Choices           []Choice `json:"choices"`
	Usage             Usage    `json:"usage,omitempty"`
}

// ChatCompletionChunk is one SSE event of a streaming chat completion.
type ChatCompletionChunk struct {
	Id                string        `json:"id"`
	Object            string        `json:"object"`
	Created           int64         `json:"created"`
	Model             string        `json:"model"`
	SystemFingerprint string        `json:"system_fingerprint"`
	Choices           []ChunkChoice `json:"choices"`
}

// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int

// CompletionRequest is the OpenAI /v1/completions (legacy text
// completion) request body.
type CompletionRequest struct {
	Model            string   `json:"model"`
	Prompt           string   `json:"prompt"`
	FrequencyPenalty float32  `json:"frequency_penalty"`
	MaxTokens        *int     `json:"max_tokens"`
	PresencePenalty  float32  `json:"presence_penalty"`
	Seed             *int     `json:"seed"`
	Stop             any      `json:"stop"` // string or []string
	Stream           bool     `json:"stream"`
	Temperature      *float32 `json:"temperature"`
	TopP             float32  `json:"top_p"`
	Suffix           string   `json:"suffix"`
}

// Completion is the non-streaming text completion response.
type Completion struct {
	Id                string                `json:"id"`
	Object            string                `json:"object"`
	Created           int64                 `json:"created"`
	Model             string                `json:"model"`
	SystemFingerprint string                `json:"system_fingerprint"`
	Choices           []CompleteChunkChoice `json:"choices"`
	Usage             Usage                 `json:"usage,omitempty"`
}

// CompletionChunk is one SSE event of a streaming text completion.
type CompletionChunk struct {
	Id                string                `json:"id"`
	Object            string                `json:"object"`
	Created           int64                 `json:"created"`
	Choices           []CompleteChunkChoice `json:"choices"`
	Model             string                `json:"model"`
	SystemFingerprint string                `json:"system_fingerprint"`
}
// ToolCall is an OpenAI-format tool (function) call. Arguments are a
// JSON-encoded string here, unlike the native API which uses a map
// (see toToolCalls / fromChatRequest for the conversion).
type ToolCall struct {
	ID       string `json:"id"`
	Index    int    `json:"index"`
	Type     string `json:"type"`
	Function struct {
		Name      string `json:"name"`
		Arguments string `json:"arguments"`
	} `json:"function"`
}

// Model describes one model in /v1/models responses.
type Model struct {
	Id      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	OwnedBy string `json:"owned_by"`
}

// Embedding is a single embedding vector with its input index.
type Embedding struct {
	Object    string    `json:"object"`
	Embedding []float32 `json:"embedding"`
	Index     int       `json:"index"`
}

// ListCompletion is the /v1/models list response.
type ListCompletion struct {
	Object string  `json:"object"`
	Data   []Model `json:"data"`
}

// EmbeddingList is the /v1/embeddings response.
type EmbeddingList struct {
	Object string         `json:"object"`
	Data   []Embedding    `json:"data"`
	Model  string         `json:"model"`
	Usage  EmbeddingUsage `json:"usage,omitempty"`
}

// EmbeddingUsage reports token usage for an embeddings request.
type EmbeddingUsage struct {
	PromptTokens int `json:"prompt_tokens"`
	TotalTokens  int `json:"total_tokens"`
}
  160. func NewError(code int, message string) ErrorResponse {
  161. var etype string
  162. switch code {
  163. case http.StatusBadRequest:
  164. etype = "invalid_request_error"
  165. case http.StatusNotFound:
  166. etype = "not_found_error"
  167. default:
  168. etype = "api_error"
  169. }
  170. return ErrorResponse{Error{Type: etype, Message: message}}
  171. }
  172. func toolCallId() string {
  173. const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
  174. b := make([]byte, 8)
  175. for i := range b {
  176. b[i] = letterBytes[rand.Intn(len(letterBytes))]
  177. }
  178. return "call_" + strings.ToLower(string(b))
  179. }
  180. func toToolCalls(tc []api.ToolCall) []ToolCall {
  181. toolCalls := make([]ToolCall, len(tc))
  182. for i, tc := range tc {
  183. toolCalls[i].ID = toolCallId()
  184. toolCalls[i].Type = "function"
  185. toolCalls[i].Function.Name = tc.Function.Name
  186. toolCalls[i].Index = tc.Function.Index
  187. args, err := json.Marshal(tc.Function.Arguments)
  188. if err != nil {
  189. slog.Error("could not marshall function arguments to json", "error", err)
  190. continue
  191. }
  192. toolCalls[i].Function.Arguments = string(args)
  193. }
  194. return toolCalls
  195. }
  196. func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
  197. toolCalls := toToolCalls(r.Message.ToolCalls)
  198. return ChatCompletion{
  199. Id: id,
  200. Object: "chat.completion",
  201. Created: r.CreatedAt.Unix(),
  202. Model: r.Model,
  203. SystemFingerprint: "fp_ollama",
  204. Choices: []Choice{{
  205. Index: 0,
  206. Message: Message{Role: r.Message.Role, Content: r.Message.Content, ToolCalls: toolCalls},
  207. FinishReason: func(reason string) *string {
  208. if len(toolCalls) > 0 {
  209. reason = "tool_calls"
  210. }
  211. if len(reason) > 0 {
  212. return &reason
  213. }
  214. return nil
  215. }(r.DoneReason),
  216. }},
  217. Usage: Usage{
  218. PromptTokens: r.PromptEvalCount,
  219. CompletionTokens: r.EvalCount,
  220. TotalTokens: r.PromptEvalCount + r.EvalCount,
  221. },
  222. }
  223. }
  224. func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
  225. toolCalls := toToolCalls(r.Message.ToolCalls)
  226. return ChatCompletionChunk{
  227. Id: id,
  228. Object: "chat.completion.chunk",
  229. Created: time.Now().Unix(),
  230. Model: r.Model,
  231. SystemFingerprint: "fp_ollama",
  232. Choices: []ChunkChoice{{
  233. Index: 0,
  234. Delta: Message{Role: "assistant", Content: r.Message.Content, ToolCalls: toolCalls},
  235. FinishReason: func(reason string) *string {
  236. if len(toolCalls) > 0 {
  237. reason = "tool_calls"
  238. }
  239. if len(reason) > 0 {
  240. return &reason
  241. }
  242. return nil
  243. }(r.DoneReason),
  244. }},
  245. }
  246. }
  247. func toCompletion(id string, r api.GenerateResponse) Completion {
  248. return Completion{
  249. Id: id,
  250. Object: "text_completion",
  251. Created: r.CreatedAt.Unix(),
  252. Model: r.Model,
  253. SystemFingerprint: "fp_ollama",
  254. Choices: []CompleteChunkChoice{{
  255. Text: r.Response,
  256. Index: 0,
  257. FinishReason: func(reason string) *string {
  258. if len(reason) > 0 {
  259. return &reason
  260. }
  261. return nil
  262. }(r.DoneReason),
  263. }},
  264. Usage: Usage{
  265. PromptTokens: r.PromptEvalCount,
  266. CompletionTokens: r.EvalCount,
  267. TotalTokens: r.PromptEvalCount + r.EvalCount,
  268. },
  269. }
  270. }
  271. func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
  272. return CompletionChunk{
  273. Id: id,
  274. Object: "text_completion",
  275. Created: time.Now().Unix(),
  276. Model: r.Model,
  277. SystemFingerprint: "fp_ollama",
  278. Choices: []CompleteChunkChoice{{
  279. Text: r.Response,
  280. Index: 0,
  281. FinishReason: func(reason string) *string {
  282. if len(reason) > 0 {
  283. return &reason
  284. }
  285. return nil
  286. }(r.DoneReason),
  287. }},
  288. }
  289. }
  290. func toListCompletion(r api.ListResponse) ListCompletion {
  291. var data []Model
  292. for _, m := range r.Models {
  293. data = append(data, Model{
  294. Id: m.Name,
  295. Object: "model",
  296. Created: m.ModifiedAt.Unix(),
  297. OwnedBy: model.ParseName(m.Name).Namespace,
  298. })
  299. }
  300. return ListCompletion{
  301. Object: "list",
  302. Data: data,
  303. }
  304. }
  305. func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
  306. if r.Embeddings != nil {
  307. var data []Embedding
  308. for i, e := range r.Embeddings {
  309. data = append(data, Embedding{
  310. Object: "embedding",
  311. Embedding: e,
  312. Index: i,
  313. })
  314. }
  315. return EmbeddingList{
  316. Object: "list",
  317. Data: data,
  318. Model: model,
  319. Usage: EmbeddingUsage{
  320. PromptTokens: r.PromptEvalCount,
  321. TotalTokens: r.PromptEvalCount,
  322. },
  323. }
  324. }
  325. return EmbeddingList{}
  326. }
  327. func toModel(r api.ShowResponse, m string) Model {
  328. return Model{
  329. Id: m,
  330. Object: "model",
  331. Created: r.ModifiedAt.Unix(),
  332. OwnedBy: model.ParseName(m).Namespace,
  333. }
  334. }
// fromChatRequest translates an OpenAI chat completion request into a
// native api.ChatRequest: it flattens string/array message content,
// decodes inline base64 images, converts tool-call arguments from
// JSON strings to maps, and maps sampling parameters into the
// Options map. Returns an error on any malformed message part.
func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
	var messages []api.Message
	for _, msg := range r.Messages {
		switch content := msg.Content.(type) {
		case string:
			// Plain string content maps 1:1.
			messages = append(messages, api.Message{Role: msg.Role, Content: content})
		case []any:
			// Structured content: a list of {"type": ...} parts.
			for _, c := range content {
				data, ok := c.(map[string]any)
				if !ok {
					return nil, errors.New("invalid message format")
				}
				switch data["type"] {
				case "text":
					text, ok := data["text"].(string)
					if !ok {
						return nil, errors.New("invalid message format")
					}
					// NOTE(review): each part becomes its own message
					// rather than being merged into one — confirm the
					// downstream handler expects this.
					messages = append(messages, api.Message{Role: msg.Role, Content: text})
				case "image_url":
					// The image may arrive as {"url": "..."} or as a
					// bare string under "image_url".
					var url string
					if urlMap, ok := data["image_url"].(map[string]any); ok {
						if url, ok = urlMap["url"].(string); !ok {
							return nil, errors.New("invalid message format")
						}
					} else {
						if url, ok = data["image_url"].(string); !ok {
							return nil, errors.New("invalid message format")
						}
					}

					// Only inline base64 data URIs for jpeg/jpg/png are
					// accepted; anything else (e.g. remote URLs) is
					// rejected.
					types := []string{"jpeg", "jpg", "png"}
					valid := false
					for _, t := range types {
						prefix := "data:image/" + t + ";base64,"
						if strings.HasPrefix(url, prefix) {
							url = strings.TrimPrefix(url, prefix)
							valid = true
							break
						}
					}

					if !valid {
						return nil, errors.New("invalid image input")
					}

					img, err := base64.StdEncoding.DecodeString(url)
					if err != nil {
						return nil, errors.New("invalid message format")
					}

					messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
				default:
					return nil, errors.New("invalid message format")
				}
			}
		default:
			// Content is neither string nor list: only legal for
			// assistant tool-call messages, whose arguments arrive as
			// a JSON-encoded string and are decoded into a map here.
			if msg.ToolCalls == nil {
				return nil, fmt.Errorf("invalid message content type: %T", content)
			}

			toolCalls := make([]api.ToolCall, len(msg.ToolCalls))
			for i, tc := range msg.ToolCalls {
				toolCalls[i].Function.Name = tc.Function.Name
				err := json.Unmarshal([]byte(tc.Function.Arguments), &toolCalls[i].Function.Arguments)
				if err != nil {
					return nil, errors.New("invalid tool call arguments")
				}
			}
			messages = append(messages, api.Message{Role: msg.Role, ToolCalls: toolCalls})
		}
	}

	options := make(map[string]interface{})

	// "stop" may be a single string or a list; non-string list
	// entries are silently skipped here (unlike fromCompleteRequest,
	// which rejects them).
	switch stop := r.Stop.(type) {
	case string:
		options["stop"] = []string{stop}
	case []any:
		var stops []string
		for _, s := range stop {
			if str, ok := s.(string); ok {
				stops = append(stops, str)
			}
		}
		options["stop"] = stops
	}

	if r.MaxTokens != nil {
		options["num_predict"] = *r.MaxTokens
	}

	// Temperature and top_p default to 1.0 when unset, matching the
	// OpenAI API defaults.
	if r.Temperature != nil {
		options["temperature"] = *r.Temperature
	} else {
		options["temperature"] = 1.0
	}

	if r.Seed != nil {
		options["seed"] = *r.Seed
	}

	if r.FrequencyPenalty != nil {
		options["frequency_penalty"] = *r.FrequencyPenalty
	}

	if r.PresencePenalty != nil {
		options["presence_penalty"] = *r.PresencePenalty
	}

	if r.TopP != nil {
		options["top_p"] = *r.TopP
	} else {
		options["top_p"] = 1.0
	}

	// response_format: "json_object" maps to the native "json" format;
	// "json_schema" passes the raw schema through.
	var format json.RawMessage
	if r.ResponseFormat != nil {
		switch strings.ToLower(strings.TrimSpace(r.ResponseFormat.Type)) {
		// Support the old "json_object" type for OpenAI compatibility
		case "json_object":
			format = json.RawMessage(`"json"`)
		case "json_schema":
			if r.ResponseFormat.JsonSchema != nil {
				schema, err := json.Marshal(r.ResponseFormat.JsonSchema.Schema)
				if err != nil {
					return nil, fmt.Errorf("failed to marshal json schema: %w", err)
				}
				format = schema
			}
		}
	}

	return &api.ChatRequest{
		Model:    r.Model,
		Messages: messages,
		Format:   format,
		Options:  options,
		Stream:   &r.Stream,
		Tools:    r.Tools,
	}, nil
}
// fromCompleteRequest translates an OpenAI text completion request
// into a native api.GenerateRequest. Unlike fromChatRequest, a
// non-string entry in a "stop" list is an error rather than being
// skipped.
func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
	options := make(map[string]any)

	// "stop" may be a single string or a list of strings.
	switch stop := r.Stop.(type) {
	case string:
		options["stop"] = []string{stop}
	case []any:
		var stops []string
		for _, s := range stop {
			if str, ok := s.(string); ok {
				stops = append(stops, str)
			} else {
				return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", s)
			}
		}
		options["stop"] = stops
	}

	if r.MaxTokens != nil {
		options["num_predict"] = *r.MaxTokens
	}

	// Temperature defaults to 1.0 when unset, matching OpenAI.
	if r.Temperature != nil {
		options["temperature"] = *r.Temperature
	} else {
		options["temperature"] = 1.0
	}

	if r.Seed != nil {
		options["seed"] = *r.Seed
	}

	// These fields are plain (non-pointer) floats, so they are always
	// forwarded, defaulting to 0 when absent from the request.
	options["frequency_penalty"] = r.FrequencyPenalty
	options["presence_penalty"] = r.PresencePenalty

	// A top_p of 0 is treated as "unset" and replaced with 1.0.
	if r.TopP != 0.0 {
		options["top_p"] = r.TopP
	} else {
		options["top_p"] = 1.0
	}

	return api.GenerateRequest{
		Model:   r.Model,
		Prompt:  r.Prompt,
		Options: options,
		Stream:  &r.Stream,
		Suffix:  r.Suffix,
	}, nil
}
// BaseWriter wraps gin's ResponseWriter and provides shared error
// translation for the OpenAI-compat response writers below.
type BaseWriter struct {
	gin.ResponseWriter
}

// ChatWriter rewrites chat responses into OpenAI chat-completion
// format, as SSE chunks when stream is set.
type ChatWriter struct {
	stream bool
	// finished guards against emitting further choice chunks after
	// completion. NOTE(review): it is read in writeResponse but not
	// set anywhere in this view — confirm where it is updated.
	finished bool
	id       string
	BaseWriter
}

// CompleteWriter rewrites generate responses into OpenAI text
// completion format.
type CompleteWriter struct {
	stream bool
	id     string
	BaseWriter
}

// ListWriter rewrites list responses into /v1/models format.
type ListWriter struct {
	BaseWriter
}

// RetrieveWriter rewrites show responses into a single /v1/models entry.
type RetrieveWriter struct {
	BaseWriter
	model string
}

// EmbedWriter rewrites embed responses into /v1/embeddings format.
type EmbedWriter struct {
	BaseWriter
	model string
}
// writeError re-encodes an upstream error payload (an api.StatusError
// JSON body) as an OpenAI-style ErrorResponse and reports len(data)
// as written so the caller treats the payload as fully consumed.
// Note: NewError is always called with StatusInternalServerError, so
// the emitted error type is always "api_error" regardless of the
// actual upstream status.
func (w *BaseWriter) writeError(data []byte) (int, error) {
	var serr api.StatusError
	err := json.Unmarshal(data, &serr)
	if err != nil {
		return 0, err
	}

	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w.ResponseWriter).Encode(NewError(http.StatusInternalServerError, serr.Error()))
	if err != nil {
		return 0, err
	}

	return len(data), nil
}
// writeResponse decodes a native chat response payload and re-emits
// it in OpenAI format: SSE "data:" events (terminated by
// "data: [DONE]") when streaming, or a single chat.completion JSON
// object otherwise. It returns len(data) so the caller observes a
// complete write of the original payload.
func (w *ChatWriter) writeResponse(data []byte) (int, error) {
	var chatResponse api.ChatResponse
	err := json.Unmarshal(data, &chatResponse)
	if err != nil {
		return 0, err
	}

	// chat chunk
	if w.stream {
		// If we've already finished, don't send any more chunks with choices.
		// NOTE(review): nothing in this method sets w.finished —
		// confirm it is updated elsewhere.
		if !w.finished {
			chunk := toChunk(w.id, chatResponse)
			d, err := json.Marshal(chunk)
			if err != nil {
				return 0, err
			}

			w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
			_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
			if err != nil {
				return 0, err
			}
		}

		// The final response additionally emits the SSE terminator.
		if chatResponse.Done {
			_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
			if err != nil {
				return 0, err
			}
		}
		return len(data), nil
	}

	// chat completion
	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w.ResponseWriter).Encode(toChatCompletion(w.id, chatResponse))
	if err != nil {
		return 0, err
	}

	return len(data), nil
}
  579. func (w *ChatWriter) Write(data []byte) (int, error) {
  580. code := w.ResponseWriter.Status()
  581. if code != http.StatusOK {
  582. return w.writeError(data)
  583. }
  584. return w.writeResponse(data)
  585. }
// writeResponse decodes a native generate response payload and
// re-emits it in OpenAI format: SSE "data:" events (terminated by
// "data: [DONE]") when streaming, or a single text_completion JSON
// object otherwise. It returns len(data) so the caller observes a
// complete write of the original payload.
func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
	var generateResponse api.GenerateResponse
	err := json.Unmarshal(data, &generateResponse)
	if err != nil {
		return 0, err
	}

	// completion chunk
	if w.stream {
		d, err := json.Marshal(toCompleteChunk(w.id, generateResponse))
		if err != nil {
			return 0, err
		}

		w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
		_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
		if err != nil {
			return 0, err
		}

		// The final response additionally emits the SSE terminator.
		if generateResponse.Done {
			_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
			if err != nil {
				return 0, err
			}
		}

		return len(data), nil
	}

	// completion
	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w.ResponseWriter).Encode(toCompletion(w.id, generateResponse))
	if err != nil {
		return 0, err
	}

	return len(data), nil
}
  619. func (w *CompleteWriter) Write(data []byte) (int, error) {
  620. code := w.ResponseWriter.Status()
  621. if code != http.StatusOK {
  622. return w.writeError(data)
  623. }
  624. return w.writeResponse(data)
  625. }
  626. func (w *ListWriter) writeResponse(data []byte) (int, error) {
  627. var listResponse api.ListResponse
  628. err := json.Unmarshal(data, &listResponse)
  629. if err != nil {
  630. return 0, err
  631. }
  632. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  633. err = json.NewEncoder(w.ResponseWriter).Encode(toListCompletion(listResponse))
  634. if err != nil {
  635. return 0, err
  636. }
  637. return len(data), nil
  638. }
  639. func (w *ListWriter) Write(data []byte) (int, error) {
  640. code := w.ResponseWriter.Status()
  641. if code != http.StatusOK {
  642. return w.writeError(data)
  643. }
  644. return w.writeResponse(data)
  645. }
  646. func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
  647. var showResponse api.ShowResponse
  648. err := json.Unmarshal(data, &showResponse)
  649. if err != nil {
  650. return 0, err
  651. }
  652. // retrieve completion
  653. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  654. err = json.NewEncoder(w.ResponseWriter).Encode(toModel(showResponse, w.model))
  655. if err != nil {
  656. return 0, err
  657. }
  658. return len(data), nil
  659. }
  660. func (w *RetrieveWriter) Write(data []byte) (int, error) {
  661. code := w.ResponseWriter.Status()
  662. if code != http.StatusOK {
  663. return w.writeError(data)
  664. }
  665. return w.writeResponse(data)
  666. }
  667. func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
  668. var embedResponse api.EmbedResponse
  669. err := json.Unmarshal(data, &embedResponse)
  670. if err != nil {
  671. return 0, err
  672. }
  673. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  674. err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse))
  675. if err != nil {
  676. return 0, err
  677. }
  678. return len(data), nil
  679. }
  680. func (w *EmbedWriter) Write(data []byte) (int, error) {
  681. code := w.ResponseWriter.Status()
  682. if code != http.StatusOK {
  683. return w.writeError(data)
  684. }
  685. return w.writeResponse(data)
  686. }
  687. func ListMiddleware() gin.HandlerFunc {
  688. return func(c *gin.Context) {
  689. w := &ListWriter{
  690. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  691. }
  692. c.Writer = w
  693. c.Next()
  694. }
  695. }
  696. func RetrieveMiddleware() gin.HandlerFunc {
  697. return func(c *gin.Context) {
  698. var b bytes.Buffer
  699. if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
  700. c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
  701. return
  702. }
  703. c.Request.Body = io.NopCloser(&b)
  704. // response writer
  705. w := &RetrieveWriter{
  706. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  707. model: c.Param("model"),
  708. }
  709. c.Writer = w
  710. c.Next()
  711. }
  712. }
// CompletionsMiddleware adapts OpenAI /v1/completions requests: it
// translates the body into an api.GenerateRequest for the native
// handler and installs a CompleteWriter to convert the response back
// to OpenAI format. Translation errors abort with an OpenAI-style
// error body.
func CompletionsMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		var req CompletionRequest
		err := c.ShouldBindJSON(&req)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
			return
		}

		var b bytes.Buffer
		genReq, err := fromCompleteRequest(req)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
			return
		}

		if err := json.NewEncoder(&b).Encode(genReq); err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
			return
		}

		// Replace the request body with the translated native request.
		c.Request.Body = io.NopCloser(&b)

		w := &CompleteWriter{
			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
			stream:     req.Stream,
			// A small random suffix stands in for OpenAI's opaque
			// completion ids.
			id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
		}

		c.Writer = w
		c.Next()
	}
}
// EmbeddingsMiddleware adapts OpenAI /v1/embeddings requests: it
// validates Input, rewrites the body into an api.EmbedRequest, and
// installs an EmbedWriter to convert the response back to OpenAI
// format.
func EmbeddingsMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		var req EmbedRequest
		err := c.ShouldBindJSON(&req)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
			return
		}

		// Interface equality: this matches only when Input is the
		// empty string (not an empty list). An empty string is
		// normalized to a one-element list so the backend still
		// returns a single embedding.
		if req.Input == "" {
			req.Input = []string{""}
		}

		// A missing "input" field is rejected.
		if req.Input == nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
			return
		}

		// An empty list of inputs is also rejected.
		if v, ok := req.Input.([]any); ok && len(v) == 0 {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
			return
		}

		// Replace the request body with the translated native request.
		var b bytes.Buffer
		if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input}); err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
			return
		}

		c.Request.Body = io.NopCloser(&b)

		w := &EmbedWriter{
			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
			model:      req.Model,
		}

		c.Writer = w

		c.Next()
	}
}
// ChatMiddleware adapts OpenAI /v1/chat/completions requests: it
// requires at least one message, translates the body into an
// api.ChatRequest for the native handler, and installs a ChatWriter
// to convert the response back to OpenAI format (streaming or not).
func ChatMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		var req ChatCompletionRequest
		err := c.ShouldBindJSON(&req)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
			return
		}

		// OpenAI rejects an empty messages array; mirror that error.
		if len(req.Messages) == 0 {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "[] is too short - 'messages'"))
			return
		}

		var b bytes.Buffer

		chatReq, err := fromChatRequest(req)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
			return
		}

		if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
			return
		}

		// Replace the request body with the translated native request.
		c.Request.Body = io.NopCloser(&b)

		w := &ChatWriter{
			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
			stream:     req.Stream,
			// A small random suffix stands in for OpenAI's opaque
			// chat completion ids.
			id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
		}

		c.Writer = w

		c.Next()
	}
}