openai.go 22 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933
// Package openai provides middleware for partial compatibility with the OpenAI REST API.
  2. package openai
  3. import (
  4. "bytes"
  5. "encoding/base64"
  6. "encoding/json"
  7. "errors"
  8. "fmt"
  9. "io"
  10. "log/slog"
  11. "math/rand"
  12. "net/http"
  13. "strings"
  14. "time"
  15. "github.com/gin-gonic/gin"
  16. "github.com/ollama/ollama/api"
  17. "github.com/ollama/ollama/types/model"
  18. )
  19. type Error struct {
  20. Message string `json:"message"`
  21. Type string `json:"type"`
  22. Param interface{} `json:"param"`
  23. Code *string `json:"code"`
  24. }
  25. type ErrorResponse struct {
  26. Error Error `json:"error"`
  27. }
  28. type Message struct {
  29. Role string `json:"role,omitempty"`
  30. Content any `json:"content"`
  31. ToolCalls []ToolCall `json:"tool_calls,omitempty"`
  32. }
  33. type Choice struct {
  34. Index int `json:"index"`
  35. Message Message `json:"message"`
  36. FinishReason *string `json:"finish_reason"`
  37. }
  38. type ChunkChoice struct {
  39. Index int `json:"index"`
  40. Delta Message `json:"delta"`
  41. FinishReason *string `json:"finish_reason"`
  42. }
  43. type CompleteChunkChoice struct {
  44. Text string `json:"text"`
  45. Index int `json:"index"`
  46. FinishReason *string `json:"finish_reason"`
  47. }
  48. type Usage struct {
  49. PromptTokens int `json:"prompt_tokens"`
  50. CompletionTokens int `json:"completion_tokens"`
  51. TotalTokens int `json:"total_tokens"`
  52. }
  53. type ResponseFormat struct {
  54. Type string `json:"type"`
  55. }
  56. type EmbedRequest struct {
  57. Input any `json:"input"`
  58. Model string `json:"model"`
  59. }
  60. type ChatCompletionRequest struct {
  61. Model string `json:"model"`
  62. Messages []Message `json:"messages"`
  63. Stream bool `json:"stream"`
  64. MaxTokens *int `json:"max_tokens"`
  65. Seed *int `json:"seed"`
  66. Stop any `json:"stop"`
  67. Temperature *float64 `json:"temperature"`
  68. FrequencyPenalty *float64 `json:"frequency_penalty"`
  69. PresencePenalty *float64 `json:"presence_penalty"`
  70. TopP *float64 `json:"top_p"`
  71. ResponseFormat *ResponseFormat `json:"response_format"`
  72. Tools []api.Tool `json:"tools"`
  73. }
  74. type ChatCompletion struct {
  75. Id string `json:"id"`
  76. Object string `json:"object"`
  77. Created int64 `json:"created"`
  78. Model string `json:"model"`
  79. SystemFingerprint string `json:"system_fingerprint"`
  80. Choices []Choice `json:"choices"`
  81. Usage Usage `json:"usage,omitempty"`
  82. }
  83. type ChatCompletionChunk struct {
  84. Id string `json:"id"`
  85. Object string `json:"object"`
  86. Created int64 `json:"created"`
  87. Model string `json:"model"`
  88. SystemFingerprint string `json:"system_fingerprint"`
  89. Choices []ChunkChoice `json:"choices"`
  90. }
  91. // TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
  92. type CompletionRequest struct {
  93. Model string `json:"model"`
  94. Prompt string `json:"prompt"`
  95. FrequencyPenalty float32 `json:"frequency_penalty"`
  96. MaxTokens *int `json:"max_tokens"`
  97. PresencePenalty float32 `json:"presence_penalty"`
  98. Seed *int `json:"seed"`
  99. Stop any `json:"stop"`
  100. Stream bool `json:"stream"`
  101. Temperature *float32 `json:"temperature"`
  102. TopP float32 `json:"top_p"`
  103. Suffix string `json:"suffix"`
  104. }
  105. type Completion struct {
  106. Id string `json:"id"`
  107. Object string `json:"object"`
  108. Created int64 `json:"created"`
  109. Model string `json:"model"`
  110. SystemFingerprint string `json:"system_fingerprint"`
  111. Choices []CompleteChunkChoice `json:"choices"`
  112. Usage Usage `json:"usage,omitempty"`
  113. }
  114. type CompletionChunk struct {
  115. Id string `json:"id"`
  116. Object string `json:"object"`
  117. Created int64 `json:"created"`
  118. Choices []CompleteChunkChoice `json:"choices"`
  119. Model string `json:"model"`
  120. SystemFingerprint string `json:"system_fingerprint"`
  121. }
  122. type ToolCall struct {
  123. ID string `json:"id"`
  124. Type string `json:"type"`
  125. Function struct {
  126. Name string `json:"name"`
  127. Arguments string `json:"arguments"`
  128. } `json:"function"`
  129. }
  130. type Model struct {
  131. Id string `json:"id"`
  132. Object string `json:"object"`
  133. Created int64 `json:"created"`
  134. OwnedBy string `json:"owned_by"`
  135. }
  136. type Embedding struct {
  137. Object string `json:"object"`
  138. Embedding []float32 `json:"embedding"`
  139. Index int `json:"index"`
  140. }
  141. type ListCompletion struct {
  142. Object string `json:"object"`
  143. Data []Model `json:"data"`
  144. }
  145. type EmbeddingList struct {
  146. Object string `json:"object"`
  147. Data []Embedding `json:"data"`
  148. Model string `json:"model"`
  149. Usage EmbeddingUsage `json:"usage,omitempty"`
  150. }
  151. type EmbeddingUsage struct {
  152. PromptTokens int `json:"prompt_tokens"`
  153. TotalTokens int `json:"total_tokens"`
  154. }
  155. func NewError(code int, message string) ErrorResponse {
  156. var etype string
  157. switch code {
  158. case http.StatusBadRequest:
  159. etype = "invalid_request_error"
  160. case http.StatusNotFound:
  161. etype = "not_found_error"
  162. default:
  163. etype = "api_error"
  164. }
  165. return ErrorResponse{Error{Type: etype, Message: message}}
  166. }
  167. func toolCallId() string {
  168. const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
  169. b := make([]byte, 8)
  170. for i := range b {
  171. b[i] = letterBytes[rand.Intn(len(letterBytes))]
  172. }
  173. return "call_" + strings.ToLower(string(b))
  174. }
  175. func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
  176. toolCalls := make([]ToolCall, len(r.Message.ToolCalls))
  177. for i, tc := range r.Message.ToolCalls {
  178. toolCalls[i].ID = toolCallId()
  179. toolCalls[i].Type = "function"
  180. toolCalls[i].Function.Name = tc.Function.Name
  181. args, err := json.Marshal(tc.Function.Arguments)
  182. if err != nil {
  183. slog.Error("could not marshall function arguments to json", "error", err)
  184. continue
  185. }
  186. toolCalls[i].Function.Arguments = string(args)
  187. }
  188. return ChatCompletion{
  189. Id: id,
  190. Object: "chat.completion",
  191. Created: r.CreatedAt.Unix(),
  192. Model: r.Model,
  193. SystemFingerprint: "fp_ollama",
  194. Choices: []Choice{{
  195. Index: 0,
  196. Message: Message{Role: r.Message.Role, Content: r.Message.Content, ToolCalls: toolCalls},
  197. FinishReason: func(reason string) *string {
  198. if len(toolCalls) > 0 {
  199. reason = "tool_calls"
  200. }
  201. if len(reason) > 0 {
  202. return &reason
  203. }
  204. return nil
  205. }(r.DoneReason),
  206. }},
  207. Usage: Usage{
  208. PromptTokens: r.PromptEvalCount,
  209. CompletionTokens: r.EvalCount,
  210. TotalTokens: r.PromptEvalCount + r.EvalCount,
  211. },
  212. }
  213. }
  214. func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
  215. return ChatCompletionChunk{
  216. Id: id,
  217. Object: "chat.completion.chunk",
  218. Created: time.Now().Unix(),
  219. Model: r.Model,
  220. SystemFingerprint: "fp_ollama",
  221. Choices: []ChunkChoice{{
  222. Index: 0,
  223. Delta: Message{Content: r.Message.Content},
  224. FinishReason: func(reason string) *string {
  225. if len(reason) > 0 {
  226. return &reason
  227. }
  228. return nil
  229. }(r.DoneReason),
  230. }},
  231. }
  232. }
  233. func toCompletion(id string, r api.GenerateResponse) Completion {
  234. return Completion{
  235. Id: id,
  236. Object: "text_completion",
  237. Created: r.CreatedAt.Unix(),
  238. Model: r.Model,
  239. SystemFingerprint: "fp_ollama",
  240. Choices: []CompleteChunkChoice{{
  241. Text: r.Response,
  242. Index: 0,
  243. FinishReason: func(reason string) *string {
  244. if len(reason) > 0 {
  245. return &reason
  246. }
  247. return nil
  248. }(r.DoneReason),
  249. }},
  250. Usage: Usage{
  251. PromptTokens: r.PromptEvalCount,
  252. CompletionTokens: r.EvalCount,
  253. TotalTokens: r.PromptEvalCount + r.EvalCount,
  254. },
  255. }
  256. }
  257. func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
  258. return CompletionChunk{
  259. Id: id,
  260. Object: "text_completion",
  261. Created: time.Now().Unix(),
  262. Model: r.Model,
  263. SystemFingerprint: "fp_ollama",
  264. Choices: []CompleteChunkChoice{{
  265. Text: r.Response,
  266. Index: 0,
  267. FinishReason: func(reason string) *string {
  268. if len(reason) > 0 {
  269. return &reason
  270. }
  271. return nil
  272. }(r.DoneReason),
  273. }},
  274. }
  275. }
  276. func toListCompletion(r api.ListResponse) ListCompletion {
  277. var data []Model
  278. for _, m := range r.Models {
  279. data = append(data, Model{
  280. Id: m.Name,
  281. Object: "model",
  282. Created: m.ModifiedAt.Unix(),
  283. OwnedBy: model.ParseName(m.Name).Namespace,
  284. })
  285. }
  286. return ListCompletion{
  287. Object: "list",
  288. Data: data,
  289. }
  290. }
  291. func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
  292. if r.Embeddings != nil {
  293. var data []Embedding
  294. for i, e := range r.Embeddings {
  295. data = append(data, Embedding{
  296. Object: "embedding",
  297. Embedding: e,
  298. Index: i,
  299. })
  300. }
  301. return EmbeddingList{
  302. Object: "list",
  303. Data: data,
  304. Model: model,
  305. Usage: EmbeddingUsage{
  306. PromptTokens: r.PromptEvalCount,
  307. TotalTokens: r.PromptEvalCount,
  308. },
  309. }
  310. }
  311. return EmbeddingList{}
  312. }
  313. func toModel(r api.ShowResponse, m string) Model {
  314. return Model{
  315. Id: m,
  316. Object: "model",
  317. Created: r.ModifiedAt.Unix(),
  318. OwnedBy: model.ParseName(m).Namespace,
  319. }
  320. }
  321. func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
  322. var messages []api.Message
  323. for _, msg := range r.Messages {
  324. switch content := msg.Content.(type) {
  325. case string:
  326. messages = append(messages, api.Message{Role: msg.Role, Content: content})
  327. case []any:
  328. for _, c := range content {
  329. data, ok := c.(map[string]any)
  330. if !ok {
  331. return nil, errors.New("invalid message format")
  332. }
  333. switch data["type"] {
  334. case "text":
  335. text, ok := data["text"].(string)
  336. if !ok {
  337. return nil, errors.New("invalid message format")
  338. }
  339. messages = append(messages, api.Message{Role: msg.Role, Content: text})
  340. case "image_url":
  341. var url string
  342. if urlMap, ok := data["image_url"].(map[string]any); ok {
  343. if url, ok = urlMap["url"].(string); !ok {
  344. return nil, errors.New("invalid message format")
  345. }
  346. } else {
  347. if url, ok = data["image_url"].(string); !ok {
  348. return nil, errors.New("invalid message format")
  349. }
  350. }
  351. types := []string{"jpeg", "jpg", "png"}
  352. valid := false
  353. for _, t := range types {
  354. prefix := "data:image/" + t + ";base64,"
  355. if strings.HasPrefix(url, prefix) {
  356. url = strings.TrimPrefix(url, prefix)
  357. valid = true
  358. break
  359. }
  360. }
  361. if !valid {
  362. return nil, errors.New("invalid image input")
  363. }
  364. img, err := base64.StdEncoding.DecodeString(url)
  365. if err != nil {
  366. return nil, errors.New("invalid message format")
  367. }
  368. messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
  369. default:
  370. return nil, errors.New("invalid message format")
  371. }
  372. }
  373. default:
  374. if msg.ToolCalls == nil {
  375. return nil, fmt.Errorf("invalid message content type: %T", content)
  376. }
  377. toolCalls := make([]api.ToolCall, len(msg.ToolCalls))
  378. for i, tc := range msg.ToolCalls {
  379. toolCalls[i].Function.Name = tc.Function.Name
  380. err := json.Unmarshal([]byte(tc.Function.Arguments), &toolCalls[i].Function.Arguments)
  381. if err != nil {
  382. return nil, errors.New("invalid tool call arguments")
  383. }
  384. }
  385. messages = append(messages, api.Message{Role: msg.Role, ToolCalls: toolCalls})
  386. }
  387. }
  388. options := make(map[string]interface{})
  389. switch stop := r.Stop.(type) {
  390. case string:
  391. options["stop"] = []string{stop}
  392. case []any:
  393. var stops []string
  394. for _, s := range stop {
  395. if str, ok := s.(string); ok {
  396. stops = append(stops, str)
  397. }
  398. }
  399. options["stop"] = stops
  400. }
  401. if r.MaxTokens != nil {
  402. options["num_predict"] = *r.MaxTokens
  403. }
  404. if r.Temperature != nil {
  405. options["temperature"] = *r.Temperature
  406. } else {
  407. options["temperature"] = 1.0
  408. }
  409. if r.Seed != nil {
  410. options["seed"] = *r.Seed
  411. }
  412. if r.FrequencyPenalty != nil {
  413. options["frequency_penalty"] = *r.FrequencyPenalty
  414. }
  415. if r.PresencePenalty != nil {
  416. options["presence_penalty"] = *r.PresencePenalty
  417. }
  418. if r.TopP != nil {
  419. options["top_p"] = *r.TopP
  420. } else {
  421. options["top_p"] = 1.0
  422. }
  423. var format string
  424. if r.ResponseFormat != nil && r.ResponseFormat.Type == "json_object" {
  425. format = "json"
  426. }
  427. return &api.ChatRequest{
  428. Model: r.Model,
  429. Messages: messages,
  430. Format: format,
  431. Options: options,
  432. Stream: &r.Stream,
  433. Tools: r.Tools,
  434. }, nil
  435. }
  436. func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
  437. options := make(map[string]any)
  438. switch stop := r.Stop.(type) {
  439. case string:
  440. options["stop"] = []string{stop}
  441. case []any:
  442. var stops []string
  443. for _, s := range stop {
  444. if str, ok := s.(string); ok {
  445. stops = append(stops, str)
  446. } else {
  447. return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", s)
  448. }
  449. }
  450. options["stop"] = stops
  451. }
  452. if r.MaxTokens != nil {
  453. options["num_predict"] = *r.MaxTokens
  454. }
  455. if r.Temperature != nil {
  456. options["temperature"] = *r.Temperature
  457. } else {
  458. options["temperature"] = 1.0
  459. }
  460. if r.Seed != nil {
  461. options["seed"] = *r.Seed
  462. }
  463. options["frequency_penalty"] = r.FrequencyPenalty
  464. options["presence_penalty"] = r.PresencePenalty
  465. if r.TopP != 0.0 {
  466. options["top_p"] = r.TopP
  467. } else {
  468. options["top_p"] = 1.0
  469. }
  470. return api.GenerateRequest{
  471. Model: r.Model,
  472. Prompt: r.Prompt,
  473. Options: options,
  474. Stream: &r.Stream,
  475. Suffix: r.Suffix,
  476. }, nil
  477. }
  478. type BaseWriter struct {
  479. gin.ResponseWriter
  480. }
  481. type ChatWriter struct {
  482. stream bool
  483. started bool
  484. id string
  485. BaseWriter
  486. }
  487. type CompleteWriter struct {
  488. stream bool
  489. id string
  490. BaseWriter
  491. }
  492. type ListWriter struct {
  493. BaseWriter
  494. }
  495. type RetrieveWriter struct {
  496. BaseWriter
  497. model string
  498. }
  499. type EmbedWriter struct {
  500. BaseWriter
  501. model string
  502. }
  503. func (w *BaseWriter) writeError(code int, data []byte) (int, error) {
  504. var serr api.StatusError
  505. err := json.Unmarshal(data, &serr)
  506. if err != nil {
  507. return 0, err
  508. }
  509. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  510. err = json.NewEncoder(w.ResponseWriter).Encode(NewError(http.StatusInternalServerError, serr.Error()))
  511. if err != nil {
  512. return 0, err
  513. }
  514. return len(data), nil
  515. }
  516. func (w *ChatWriter) writeResponse(data []byte) (int, error) {
  517. var chatResponse api.ChatResponse
  518. err := json.Unmarshal(data, &chatResponse)
  519. if err != nil {
  520. return 0, err
  521. }
  522. if w.stream {
  523. // The first chunk always has empty content so we
  524. // copy the first chunk and set the content to
  525. // empty, and send it first.
  526. if !w.started {
  527. first := chatResponse
  528. first.Message = api.Message{Role: "assistant"}
  529. d, err := json.Marshal(toChunk(w.id, first))
  530. if err != nil {
  531. return 0, err
  532. }
  533. w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
  534. _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
  535. if err != nil {
  536. return 0, err
  537. }
  538. w.started = true
  539. }
  540. d, err := json.Marshal(toChunk(w.id, chatResponse))
  541. if err != nil {
  542. return 0, err
  543. }
  544. w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
  545. _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
  546. if err != nil {
  547. return 0, err
  548. }
  549. if chatResponse.Done {
  550. _, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
  551. if err != nil {
  552. return 0, err
  553. }
  554. }
  555. return len(data), nil
  556. }
  557. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  558. err = json.NewEncoder(w.ResponseWriter).Encode(toChatCompletion(w.id, chatResponse))
  559. if err != nil {
  560. return 0, err
  561. }
  562. return len(data), nil
  563. }
  564. func (w *ChatWriter) Write(data []byte) (int, error) {
  565. code := w.ResponseWriter.Status()
  566. if code != http.StatusOK {
  567. return w.writeError(code, data)
  568. }
  569. return w.writeResponse(data)
  570. }
  571. func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
  572. var generateResponse api.GenerateResponse
  573. err := json.Unmarshal(data, &generateResponse)
  574. if err != nil {
  575. return 0, err
  576. }
  577. // completion chunk
  578. if w.stream {
  579. d, err := json.Marshal(toCompleteChunk(w.id, generateResponse))
  580. if err != nil {
  581. return 0, err
  582. }
  583. w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
  584. _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
  585. if err != nil {
  586. return 0, err
  587. }
  588. if generateResponse.Done {
  589. _, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
  590. if err != nil {
  591. return 0, err
  592. }
  593. }
  594. return len(data), nil
  595. }
  596. // completion
  597. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  598. err = json.NewEncoder(w.ResponseWriter).Encode(toCompletion(w.id, generateResponse))
  599. if err != nil {
  600. return 0, err
  601. }
  602. return len(data), nil
  603. }
  604. func (w *CompleteWriter) Write(data []byte) (int, error) {
  605. code := w.ResponseWriter.Status()
  606. if code != http.StatusOK {
  607. return w.writeError(code, data)
  608. }
  609. return w.writeResponse(data)
  610. }
  611. func (w *ListWriter) writeResponse(data []byte) (int, error) {
  612. var listResponse api.ListResponse
  613. err := json.Unmarshal(data, &listResponse)
  614. if err != nil {
  615. return 0, err
  616. }
  617. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  618. err = json.NewEncoder(w.ResponseWriter).Encode(toListCompletion(listResponse))
  619. if err != nil {
  620. return 0, err
  621. }
  622. return len(data), nil
  623. }
  624. func (w *ListWriter) Write(data []byte) (int, error) {
  625. code := w.ResponseWriter.Status()
  626. if code != http.StatusOK {
  627. return w.writeError(code, data)
  628. }
  629. return w.writeResponse(data)
  630. }
  631. func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
  632. var showResponse api.ShowResponse
  633. err := json.Unmarshal(data, &showResponse)
  634. if err != nil {
  635. return 0, err
  636. }
  637. // retrieve completion
  638. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  639. err = json.NewEncoder(w.ResponseWriter).Encode(toModel(showResponse, w.model))
  640. if err != nil {
  641. return 0, err
  642. }
  643. return len(data), nil
  644. }
  645. func (w *RetrieveWriter) Write(data []byte) (int, error) {
  646. code := w.ResponseWriter.Status()
  647. if code != http.StatusOK {
  648. return w.writeError(code, data)
  649. }
  650. return w.writeResponse(data)
  651. }
  652. func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
  653. var embedResponse api.EmbedResponse
  654. err := json.Unmarshal(data, &embedResponse)
  655. if err != nil {
  656. return 0, err
  657. }
  658. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  659. err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse))
  660. if err != nil {
  661. return 0, err
  662. }
  663. return len(data), nil
  664. }
  665. func (w *EmbedWriter) Write(data []byte) (int, error) {
  666. code := w.ResponseWriter.Status()
  667. if code != http.StatusOK {
  668. return w.writeError(code, data)
  669. }
  670. return w.writeResponse(data)
  671. }
  672. func ListMiddleware() gin.HandlerFunc {
  673. return func(c *gin.Context) {
  674. w := &ListWriter{
  675. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  676. }
  677. c.Writer = w
  678. c.Next()
  679. }
  680. }
  681. func RetrieveMiddleware() gin.HandlerFunc {
  682. return func(c *gin.Context) {
  683. var b bytes.Buffer
  684. if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
  685. c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
  686. return
  687. }
  688. c.Request.Body = io.NopCloser(&b)
  689. // response writer
  690. w := &RetrieveWriter{
  691. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  692. model: c.Param("model"),
  693. }
  694. c.Writer = w
  695. c.Next()
  696. }
  697. }
  698. func CompletionsMiddleware() gin.HandlerFunc {
  699. return func(c *gin.Context) {
  700. var req CompletionRequest
  701. err := c.ShouldBindJSON(&req)
  702. if err != nil {
  703. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
  704. return
  705. }
  706. var b bytes.Buffer
  707. genReq, err := fromCompleteRequest(req)
  708. if err != nil {
  709. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
  710. return
  711. }
  712. if err := json.NewEncoder(&b).Encode(genReq); err != nil {
  713. c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
  714. return
  715. }
  716. c.Request.Body = io.NopCloser(&b)
  717. w := &CompleteWriter{
  718. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  719. stream: req.Stream,
  720. id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
  721. }
  722. c.Writer = w
  723. c.Next()
  724. }
  725. }
  726. func EmbeddingsMiddleware() gin.HandlerFunc {
  727. return func(c *gin.Context) {
  728. var req EmbedRequest
  729. err := c.ShouldBindJSON(&req)
  730. if err != nil {
  731. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
  732. return
  733. }
  734. if req.Input == "" {
  735. req.Input = []string{""}
  736. }
  737. if req.Input == nil {
  738. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
  739. return
  740. }
  741. if v, ok := req.Input.([]any); ok && len(v) == 0 {
  742. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
  743. return
  744. }
  745. var b bytes.Buffer
  746. if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input}); err != nil {
  747. c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
  748. return
  749. }
  750. c.Request.Body = io.NopCloser(&b)
  751. w := &EmbedWriter{
  752. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  753. model: req.Model,
  754. }
  755. c.Writer = w
  756. c.Next()
  757. }
  758. }
  759. func ChatMiddleware() gin.HandlerFunc {
  760. return func(c *gin.Context) {
  761. var req ChatCompletionRequest
  762. err := c.ShouldBindJSON(&req)
  763. if err != nil {
  764. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
  765. return
  766. }
  767. if len(req.Messages) == 0 {
  768. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "[] is too short - 'messages'"))
  769. return
  770. }
  771. var b bytes.Buffer
  772. chatReq, err := fromChatRequest(req)
  773. if err != nil {
  774. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
  775. return
  776. }
  777. if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
  778. c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
  779. return
  780. }
  781. c.Request.Body = io.NopCloser(&b)
  782. w := &ChatWriter{
  783. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  784. stream: req.Stream,
  785. id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
  786. }
  787. c.Writer = w
  788. c.Next()
  789. }
  790. }