// Package openai provides middleware for partial compatibility with the OpenAI REST API.
package openai

import (
	"bytes"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"log/slog"
	"math/rand"
	"net/http"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/types/model"
)

type Error struct {
	Message string `json:"message"`
	Type string `json:"type"`
	Param interface{} `json:"param"`
	Code *string `json:"code"`
}

type ErrorResponse struct {
	Error Error `json:"error"`
}

type Message struct {
	Role string `json:"role"`
	Content any `json:"content"`
	ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}

type Choice struct {
	Index int `json:"index"`
	Message Message `json:"message"`
	FinishReason *string `json:"finish_reason"`
}

type ChunkChoice struct {
	Index int `json:"index"`
	Delta Message `json:"delta"`
	FinishReason *string `json:"finish_reason"`
}

type CompleteChunkChoice struct {
	Text string `json:"text"`
	Index int `json:"index"`
	FinishReason *string `json:"finish_reason"`
}

type Usage struct {
	PromptTokens int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens int `json:"total_tokens"`
}

type ResponseFormat struct {
	Type string `json:"type"`
}

type EmbedRequest struct {
	Input any `json:"input"`
	Model string `json:"model"`
}

type ChatCompletionRequest struct {
	Model string `json:"model"`
	Messages []Message `json:"messages"`
	Stream bool `json:"stream"`
	MaxTokens *int `json:"max_tokens"`
	Seed *int `json:"seed"`
	Stop any `json:"stop"`
	Temperature *float64 `json:"temperature"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	PresencePenalty *float64 `json:"presence_penalty"`
	TopP *float64 `json:"top_p"`
	ResponseFormat *ResponseFormat `json:"response_format"`
	Tools []api.Tool `json:"tools"`
}

type ChatCompletion struct {
	Id string `json:"id"`
	Object string `json:"object"`
	Created int64 `json:"created"`
	Model string `json:"model"`
	SystemFingerprint string `json:"system_fingerprint"`
	Choices []Choice `json:"choices"`
	Usage Usage `json:"usage,omitempty"`
}

type ChatCompletionChunk struct {
	Id string `json:"id"`
	Object string `json:"object"`
	Created int64 `json:"created"`
	Model string `json:"model"`
	SystemFingerprint string `json:"system_fingerprint"`
	Choices []ChunkChoice `json:"choices"`
}

// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
type CompletionRequest struct {
	Model string `json:"model"`
	Prompt string `json:"prompt"`
	FrequencyPenalty float32 `json:"frequency_penalty"`
	MaxTokens *int `json:"max_tokens"`
	PresencePenalty float32 `json:"presence_penalty"`
	Seed *int `json:"seed"`
	Stop any `json:"stop"`
	Stream bool `json:"stream"`
	Temperature *float32 `json:"temperature"`
	TopP float32 `json:"top_p"`
	Suffix string `json:"suffix"`
}

type Completion struct {
	Id string `json:"id"`
	Object string `json:"object"`
	Created int64 `json:"created"`
	Model string `json:"model"`
	SystemFingerprint string `json:"system_fingerprint"`
	Choices []CompleteChunkChoice `json:"choices"`
	Usage Usage `json:"usage,omitempty"`
}

type CompletionChunk struct {
	Id string `json:"id"`
	Object string `json:"object"`
	Created int64 `json:"created"`
	Choices []CompleteChunkChoice `json:"choices"`
	Model string `json:"model"`
	SystemFingerprint string `json:"system_fingerprint"`
}

type ToolCall struct {
	ID string `json:"id"`
	Type string `json:"type"`
	Function struct {
		Name string `json:"name"`
		Arguments string `json:"arguments"`
	} `json:"function"`
}

type Model struct {
	Id string `json:"id"`
	Object string `json:"object"`
	Created int64 `json:"created"`
	OwnedBy string `json:"owned_by"`
}

type Embedding struct {
	Object string `json:"object"`
	Embedding []float32 `json:"embedding"`
	Index int `json:"index"`
}

type ListCompletion struct {
	Object string `json:"object"`
	Data []Model `json:"data"`
}

type EmbeddingList struct {
	Object string `json:"object"`
	Data []Embedding `json:"data"`
	Model string `json:"model"`
	Usage EmbeddingUsage `json:"usage,omitempty"`
}

type EmbeddingUsage struct {
	PromptTokens int `json:"prompt_tokens"`
	TotalTokens int `json:"total_tokens"`
}
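
// NewError wraps a message in an OpenAI-style error envelope, mapping the
// HTTP status code to the closest OpenAI error type.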
func NewError(code int, message string) ErrorResponse {
	var etype string
	switch code {
	case http.StatusBadRequest:
		etype = "invalid_request_error"
	case http.StatusNotFound:
		etype = "not_found_error"
	default:
		etype = "api_error"
	}

	return ErrorResponse{Error{Type: etype, Message: message}}
}
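
// toolCallId returns a pseudo-random identifier of the form "call_xxxxxxxx"
// used for tool calls in chat completion responses. It only mimics the shape
// of OpenAI tool call IDs and is not guaranteed to be unique.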
func toolCallId() string {
	const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
	b := make([]byte, 8)
	for i := range b {
		b[i] = letterBytes[rand.Intn(len(letterBytes))]
	}
	return "call_" + strings.ToLower(string(b))
}
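
// toChatCompletion converts Ollama's native ChatResponse into an OpenAI
// chat.completion object. Tool call arguments are re-marshalled into a JSON
// string, and finish_reason is reported as "tool_calls" whenever tool calls
// are present.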
func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
	toolCalls := make([]ToolCall, len(r.Message.ToolCalls))
	for i, tc := range r.Message.ToolCalls {
		toolCalls[i].ID = toolCallId()
		toolCalls[i].Type = "function"
		toolCalls[i].Function.Name = tc.Function.Name

		args, err := json.Marshal(tc.Function.Arguments)
		if err != nil {
			slog.Error("could not marshal function arguments to json", "error", err)
			continue
		}

		toolCalls[i].Function.Arguments = string(args)
	}

	return ChatCompletion{
		Id: id,
		Object: "chat.completion",
		Created: r.CreatedAt.Unix(),
		Model: r.Model,
		SystemFingerprint: "fp_ollama",
		Choices: []Choice{{
			Index: 0,
			Message: Message{Role: r.Message.Role, Content: r.Message.Content, ToolCalls: toolCalls},
			FinishReason: func(reason string) *string {
				if len(toolCalls) > 0 {
					reason = "tool_calls"
				}
				if len(reason) > 0 {
					return &reason
				}
				return nil
			}(r.DoneReason),
		}},
		Usage: Usage{
			PromptTokens: r.PromptEvalCount,
			CompletionTokens: r.EvalCount,
			TotalTokens: r.PromptEvalCount + r.EvalCount,
		},
	}
}

func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
	return ChatCompletionChunk{
		Id: id,
		Object: "chat.completion.chunk",
		Created: time.Now().Unix(),
		Model: r.Model,
		SystemFingerprint: "fp_ollama",
		Choices: []ChunkChoice{{
			Index: 0,
			Delta: Message{Role: "assistant", Content: r.Message.Content},
			FinishReason: func(reason string) *string {
				if len(reason) > 0 {
					return &reason
				}
				return nil
			}(r.DoneReason),
		}},
	}
}

func toCompletion(id string, r api.GenerateResponse) Completion {
	return Completion{
		Id: id,
		Object: "text_completion",
		Created: r.CreatedAt.Unix(),
		Model: r.Model,
		SystemFingerprint: "fp_ollama",
		Choices: []CompleteChunkChoice{{
			Text: r.Response,
			Index: 0,
			FinishReason: func(reason string) *string {
				if len(reason) > 0 {
					return &reason
				}
				return nil
			}(r.DoneReason),
		}},
		Usage: Usage{
			PromptTokens: r.PromptEvalCount,
			CompletionTokens: r.EvalCount,
			TotalTokens: r.PromptEvalCount + r.EvalCount,
		},
	}
}

func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
	return CompletionChunk{
		Id: id,
		Object: "text_completion",
		Created: time.Now().Unix(),
		Model: r.Model,
		SystemFingerprint: "fp_ollama",
		Choices: []CompleteChunkChoice{{
			Text: r.Response,
			Index: 0,
			FinishReason: func(reason string) *string {
				if len(reason) > 0 {
					return &reason
				}
				return nil
			}(r.DoneReason),
		}},
	}
}

func toListCompletion(r api.ListResponse) ListCompletion {
	var data []Model
	for _, m := range r.Models {
		data = append(data, Model{
			Id: m.Name,
			Object: "model",
			Created: m.ModifiedAt.Unix(),
			OwnedBy: model.ParseName(m.Name).Namespace,
		})
	}

	return ListCompletion{
		Object: "list",
		Data: data,
	}
}

func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
	if r.Embeddings != nil {
		var data []Embedding
		for i, e := range r.Embeddings {
			data = append(data, Embedding{
				Object: "embedding",
				Embedding: e,
				Index: i,
			})
		}

		return EmbeddingList{
			Object: "list",
			Data: data,
			Model: model,
			Usage: EmbeddingUsage{
				PromptTokens: r.PromptEvalCount,
				TotalTokens: r.PromptEvalCount,
			},
		}
	}

	return EmbeddingList{}
}

func toModel(r api.ShowResponse, m string) Model {
	return Model{
		Id: m,
		Object: "model",
		Created: r.ModifiedAt.Unix(),
		OwnedBy: model.ParseName(m).Namespace,
	}
}
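
// fromChatRequest translates an OpenAI-style ChatCompletionRequest into an
// api.ChatRequest. Message content may be a plain string or a list of typed
// parts ("text" or "image_url"); image URLs must be base64 data URIs with a
// jpeg, jpg, or png prefix. Sampling options are copied into Ollama's options
// map: temperature, frequency_penalty, and presence_penalty are scaled by 2.0,
// and temperature and top_p default to 1.0 when unset.
//
// For example, a content list such as the following is accepted (illustrative
// payload only, not taken from this file):
//
//	{"role": "user", "content": [
//	    {"type": "text", "text": "What is in this picture?"},
//	    {"type": "image_url", "image_url": {"url": "data:image/png;base64,...."}}
//	]}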
func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
	var messages []api.Message
	for _, msg := range r.Messages {
		switch content := msg.Content.(type) {
		case string:
			messages = append(messages, api.Message{Role: msg.Role, Content: content})
		case []any:
			for _, c := range content {
				data, ok := c.(map[string]any)
				if !ok {
					return nil, fmt.Errorf("invalid message format")
				}
				switch data["type"] {
				case "text":
					text, ok := data["text"].(string)
					if !ok {
						return nil, fmt.Errorf("invalid message format")
					}
					messages = append(messages, api.Message{Role: msg.Role, Content: text})
				case "image_url":
					var url string
					if urlMap, ok := data["image_url"].(map[string]any); ok {
						if url, ok = urlMap["url"].(string); !ok {
							return nil, fmt.Errorf("invalid message format")
						}
					} else {
						if url, ok = data["image_url"].(string); !ok {
							return nil, fmt.Errorf("invalid message format")
						}
					}

					types := []string{"jpeg", "jpg", "png"}
					valid := false
					for _, t := range types {
						prefix := "data:image/" + t + ";base64,"
						if strings.HasPrefix(url, prefix) {
							url = strings.TrimPrefix(url, prefix)
							valid = true
							break
						}
					}

					if !valid {
						return nil, fmt.Errorf("invalid image input")
					}

					img, err := base64.StdEncoding.DecodeString(url)
					if err != nil {
						return nil, fmt.Errorf("invalid message format")
					}

					messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
				default:
					return nil, fmt.Errorf("invalid message format")
				}
			}
		default:
			if msg.ToolCalls == nil {
				return nil, fmt.Errorf("invalid message content type: %T", content)
			}

			toolCalls := make([]api.ToolCall, len(msg.ToolCalls))
			for i, tc := range msg.ToolCalls {
				toolCalls[i].Function.Name = tc.Function.Name
				err := json.Unmarshal([]byte(tc.Function.Arguments), &toolCalls[i].Function.Arguments)
				if err != nil {
					return nil, fmt.Errorf("invalid tool call arguments")
				}
			}
			messages = append(messages, api.Message{Role: msg.Role, ToolCalls: toolCalls})
		}
	}

	options := make(map[string]interface{})

	switch stop := r.Stop.(type) {
	case string:
		options["stop"] = []string{stop}
	case []any:
		var stops []string
		for _, s := range stop {
			if str, ok := s.(string); ok {
				stops = append(stops, str)
			}
		}
		options["stop"] = stops
	}

	if r.MaxTokens != nil {
		options["num_predict"] = *r.MaxTokens
	}

	if r.Temperature != nil {
		options["temperature"] = *r.Temperature * 2.0
	} else {
		options["temperature"] = 1.0
	}

	if r.Seed != nil {
		options["seed"] = *r.Seed
	}

	if r.FrequencyPenalty != nil {
		options["frequency_penalty"] = *r.FrequencyPenalty * 2.0
	}

	if r.PresencePenalty != nil {
		options["presence_penalty"] = *r.PresencePenalty * 2.0
	}

	if r.TopP != nil {
		options["top_p"] = *r.TopP
	} else {
		options["top_p"] = 1.0
	}

	var format string
	if r.ResponseFormat != nil && r.ResponseFormat.Type == "json_object" {
		format = "json"
	}

	return &api.ChatRequest{
		Model: r.Model,
		Messages: messages,
		Format: format,
		Options: options,
		Stream: &r.Stream,
		Tools: r.Tools,
	}, nil
}
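
// fromCompleteRequest translates a legacy OpenAI completions request into an
// api.GenerateRequest, applying the same option scaling as fromChatRequest.
// Unlike the chat path, a non-string entry in "stop" is rejected with an error.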
func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
	options := make(map[string]any)

	switch stop := r.Stop.(type) {
	case string:
		options["stop"] = []string{stop}
	case []any:
		var stops []string
		for _, s := range stop {
			if str, ok := s.(string); ok {
				stops = append(stops, str)
			} else {
				return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", s)
			}
		}
		options["stop"] = stops
	}

	if r.MaxTokens != nil {
		options["num_predict"] = *r.MaxTokens
	}

	if r.Temperature != nil {
		options["temperature"] = *r.Temperature * 2.0
	} else {
		options["temperature"] = 1.0
	}

	if r.Seed != nil {
		options["seed"] = *r.Seed
	}

	options["frequency_penalty"] = r.FrequencyPenalty * 2.0

	options["presence_penalty"] = r.PresencePenalty * 2.0

	if r.TopP != 0.0 {
		options["top_p"] = r.TopP
	} else {
		options["top_p"] = 1.0
	}

	return api.GenerateRequest{
		Model: r.Model,
		Prompt: r.Prompt,
		Options: options,
		Stream: &r.Stream,
		Suffix: r.Suffix,
	}, nil
}
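
// The writer types below wrap gin's ResponseWriter. Each middleware swaps
// c.Writer for one of these so that Ollama's native JSON responses are
// rewritten into the corresponding OpenAI-shaped payloads (and errors into
// OpenAI error envelopes) as the handler writes them.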
type BaseWriter struct {
	gin.ResponseWriter
}

type ChatWriter struct {
	stream bool
	id string
	BaseWriter
}

type CompleteWriter struct {
	stream bool
	id string
	BaseWriter
}

type ListWriter struct {
	BaseWriter
}

type RetrieveWriter struct {
	BaseWriter
	model string
}

type EmbedWriter struct {
	BaseWriter
	model string
}
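
// writeError re-encodes an api.StatusError carried in the response body as an
// OpenAI error envelope. It returns the length of the original data rather
// than the number of bytes actually written.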
func (w *BaseWriter) writeError(code int, data []byte) (int, error) {
	var serr api.StatusError
	err := json.Unmarshal(data, &serr)
	if err != nil {
		return 0, err
	}

	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w.ResponseWriter).Encode(NewError(http.StatusInternalServerError, serr.Error()))
	if err != nil {
		return 0, err
	}

	return len(data), nil
}

func (w *ChatWriter) writeResponse(data []byte) (int, error) {
	var chatResponse api.ChatResponse
	err := json.Unmarshal(data, &chatResponse)
	if err != nil {
		return 0, err
	}

	// chat chunk
	if w.stream {
		d, err := json.Marshal(toChunk(w.id, chatResponse))
		if err != nil {
			return 0, err
		}

		w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
		_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
		if err != nil {
			return 0, err
		}

		if chatResponse.Done {
			_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
			if err != nil {
				return 0, err
			}
		}

		return len(data), nil
	}

	// chat completion
	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w.ResponseWriter).Encode(toChatCompletion(w.id, chatResponse))
	if err != nil {
		return 0, err
	}

	return len(data), nil
}

func (w *ChatWriter) Write(data []byte) (int, error) {
	code := w.ResponseWriter.Status()
	if code != http.StatusOK {
		return w.writeError(code, data)
	}

	return w.writeResponse(data)
}

func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
	var generateResponse api.GenerateResponse
	err := json.Unmarshal(data, &generateResponse)
	if err != nil {
		return 0, err
	}

	// completion chunk
	if w.stream {
		d, err := json.Marshal(toCompleteChunk(w.id, generateResponse))
		if err != nil {
			return 0, err
		}

		w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
		_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
		if err != nil {
			return 0, err
		}

		if generateResponse.Done {
			_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
			if err != nil {
				return 0, err
			}
		}

		return len(data), nil
	}

	// completion
	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w.ResponseWriter).Encode(toCompletion(w.id, generateResponse))
	if err != nil {
		return 0, err
	}

	return len(data), nil
}

func (w *CompleteWriter) Write(data []byte) (int, error) {
	code := w.ResponseWriter.Status()
	if code != http.StatusOK {
		return w.writeError(code, data)
	}

	return w.writeResponse(data)
}

func (w *ListWriter) writeResponse(data []byte) (int, error) {
	var listResponse api.ListResponse
	err := json.Unmarshal(data, &listResponse)
	if err != nil {
		return 0, err
	}

	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w.ResponseWriter).Encode(toListCompletion(listResponse))
	if err != nil {
		return 0, err
	}

	return len(data), nil
}

func (w *ListWriter) Write(data []byte) (int, error) {
	code := w.ResponseWriter.Status()
	if code != http.StatusOK {
		return w.writeError(code, data)
	}

	return w.writeResponse(data)
}

func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
	var showResponse api.ShowResponse
	err := json.Unmarshal(data, &showResponse)
	if err != nil {
		return 0, err
	}

	// retrieve completion
	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w.ResponseWriter).Encode(toModel(showResponse, w.model))
	if err != nil {
		return 0, err
	}

	return len(data), nil
}

func (w *RetrieveWriter) Write(data []byte) (int, error) {
	code := w.ResponseWriter.Status()
	if code != http.StatusOK {
		return w.writeError(code, data)
	}

	return w.writeResponse(data)
}

func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
	var embedResponse api.EmbedResponse
	err := json.Unmarshal(data, &embedResponse)
	if err != nil {
		return 0, err
	}

	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse))
	if err != nil {
		return 0, err
	}

	return len(data), nil
}

func (w *EmbedWriter) Write(data []byte) (int, error) {
	code := w.ResponseWriter.Status()
	if code != http.StatusOK {
		return w.writeError(code, data)
	}

	return w.writeResponse(data)
}
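
// The middleware below adapts OpenAI-compatible endpoints onto Ollama's native
// handlers: each one rewrites the incoming request body into the native API
// shape and installs a response writer that translates the result back. A
// minimal wiring sketch (route paths and handler names are illustrative
// assumptions, not taken from this file):
//
//	r := gin.Default()
//	r.GET("/v1/models", openai.ListMiddleware(), listHandler)
//	r.POST("/v1/chat/completions", openai.ChatMiddleware(), chatHandler)
//	r.POST("/v1/completions", openai.CompletionsMiddleware(), generateHandler)
//	r.POST("/v1/embeddings", openai.EmbeddingsMiddleware(), embedHandler)

// ListMiddleware installs a ListWriter so that the model list returned by the
// native handler is rendered as an OpenAI model list.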
func ListMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		w := &ListWriter{
			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
		}

		c.Writer = w

		c.Next()
	}
}

func RetrieveMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		var b bytes.Buffer
		if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
			return
		}

		c.Request.Body = io.NopCloser(&b)

		// response writer
		w := &RetrieveWriter{
			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
			model: c.Param("model"),
		}

		c.Writer = w

		c.Next()
	}
}
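
// CompletionsMiddleware binds an OpenAI completions request, re-encodes it as
// an api.GenerateRequest on the request body, and installs a CompleteWriter
// keyed by a random "cmpl-" id.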
func CompletionsMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		var req CompletionRequest
		err := c.ShouldBindJSON(&req)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
			return
		}

		var b bytes.Buffer
		genReq, err := fromCompleteRequest(req)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
			return
		}

		if err := json.NewEncoder(&b).Encode(genReq); err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
			return
		}

		c.Request.Body = io.NopCloser(&b)

		w := &CompleteWriter{
			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
			stream: req.Stream,
			id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
		}

		c.Writer = w

		c.Next()
	}
}
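
// EmbeddingsMiddleware accepts either a single string or a list as "input",
// rejects nil and empty-list inputs, and forwards the request body as an
// api.EmbedRequest.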
func EmbeddingsMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		var req EmbedRequest
		err := c.ShouldBindJSON(&req)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
			return
		}

		if req.Input == "" {
			req.Input = []string{""}
		}

		if req.Input == nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
			return
		}

		if v, ok := req.Input.([]any); ok && len(v) == 0 {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
			return
		}

		var b bytes.Buffer
		if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input}); err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
			return
		}

		c.Request.Body = io.NopCloser(&b)

		w := &EmbedWriter{
			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
			model: req.Model,
		}

		c.Writer = w

		c.Next()
	}
}
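
// ChatMiddleware is the chat-completions counterpart of CompletionsMiddleware:
// it requires at least one message, converts the request via fromChatRequest,
// and installs a ChatWriter with a random "chatcmpl-" id.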
func ChatMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		var req ChatCompletionRequest
		err := c.ShouldBindJSON(&req)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
			return
		}

		if len(req.Messages) == 0 {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "[] is too short - 'messages'"))
			return
		}

		var b bytes.Buffer

		chatReq, err := fromChatRequest(req)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
			return
		}

		if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
			return
		}

		c.Request.Body = io.NopCloser(&b)

		w := &ChatWriter{
			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
			stream: req.Stream,
			id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
		}

		c.Writer = w

		c.Next()
	}
}