// Package openai provides middleware for partial compatibility with the OpenAI REST API.
package openai

import (
	"bytes"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"math/rand"
	"net/http"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/types/model"
)
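
// The middleware in this package translates OpenAI-compatible requests and
// responses to and from the native Ollama API. A minimal wiring sketch,
// illustrative only: the route paths and the *Handler names are assumptions
// and are not defined in this file.
//
//	r := gin.Default()
//	r.POST("/v1/chat/completions", ChatMiddleware(), chatHandler)
//	r.POST("/v1/completions", CompletionsMiddleware(), generateHandler)
//	r.POST("/v1/embeddings", EmbeddingsMiddleware(), embedHandler)
//	r.GET("/v1/models", ListMiddleware(), listHandler)
//	r.GET("/v1/models/:model", RetrieveMiddleware(), showHandler)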

type Error struct {
	Message string      `json:"message"`
	Type    string      `json:"type"`
	Param   interface{} `json:"param"`
	Code    *string     `json:"code"`
}

type ErrorResponse struct {
	Error Error `json:"error"`
}

type Message struct {
	Role      string     `json:"role"`
	Content   any        `json:"content"`
	ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}

type Choice struct {
	Index        int     `json:"index"`
	Message      Message `json:"message"`
	FinishReason *string `json:"finish_reason"`
}

type ChunkChoice struct {
	Index        int     `json:"index"`
	Delta        Message `json:"delta"`
	FinishReason *string `json:"finish_reason"`
}

type CompleteChunkChoice struct {
	Text         string  `json:"text"`
	Index        int     `json:"index"`
	FinishReason *string `json:"finish_reason"`
}

type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

type ResponseFormat struct {
	Type string `json:"type"`
}

type EmbedRequest struct {
	Input any    `json:"input"`
	Model string `json:"model"`
}

type ChatCompletionRequest struct {
	Model            string          `json:"model"`
	Messages         []Message       `json:"messages"`
	Stream           bool            `json:"stream"`
	MaxTokens        *int            `json:"max_tokens"`
	Seed             *int            `json:"seed"`
	Stop             any             `json:"stop"`
	Temperature      *float64        `json:"temperature"`
	FrequencyPenalty *float64        `json:"frequency_penalty"`
	PresencePenalty  *float64        `json:"presence_penalty"`
	TopP             *float64        `json:"top_p"`
	ResponseFormat   *ResponseFormat `json:"response_format"`
	Tools            []api.Tool      `json:"tools"`
}

type ChatCompletion struct {
	Id                string   `json:"id"`
	Object            string   `json:"object"`
	Created           int64    `json:"created"`
	Model             string   `json:"model"`
	SystemFingerprint string   `json:"system_fingerprint"`
	Choices           []Choice `json:"choices"`
	Usage             Usage    `json:"usage,omitempty"`
}

type ChatCompletionChunk struct {
	Id                string        `json:"id"`
	Object            string        `json:"object"`
	Created           int64         `json:"created"`
	Model             string        `json:"model"`
	SystemFingerprint string        `json:"system_fingerprint"`
	Choices           []ChunkChoice `json:"choices"`
}

// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
type CompletionRequest struct {
	Model            string   `json:"model"`
	Prompt           string   `json:"prompt"`
	FrequencyPenalty float32  `json:"frequency_penalty"`
	MaxTokens        *int     `json:"max_tokens"`
	PresencePenalty  float32  `json:"presence_penalty"`
	Seed             *int     `json:"seed"`
	Stop             any      `json:"stop"`
	Stream           bool     `json:"stream"`
	Temperature      *float32 `json:"temperature"`
	TopP             float32  `json:"top_p"`
	Suffix           string   `json:"suffix"`
}

type Completion struct {
	Id                string                `json:"id"`
	Object            string                `json:"object"`
	Created           int64                 `json:"created"`
	Model             string                `json:"model"`
	SystemFingerprint string                `json:"system_fingerprint"`
	Choices           []CompleteChunkChoice `json:"choices"`
	Usage             Usage                 `json:"usage,omitempty"`
}

type CompletionChunk struct {
	Id                string                `json:"id"`
	Object            string                `json:"object"`
	Created           int64                 `json:"created"`
	Choices           []CompleteChunkChoice `json:"choices"`
	Model             string                `json:"model"`
	SystemFingerprint string                `json:"system_fingerprint"`
}

type ToolCall struct {
	ID       string `json:"id"`
	Index    int    `json:"index"`
	Type     string `json:"type"`
	Function struct {
		Name      string `json:"name"`
		Arguments string `json:"arguments"`
	} `json:"function"`
}

type Model struct {
	Id      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	OwnedBy string `json:"owned_by"`
}

type Embedding struct {
	Object    string    `json:"object"`
	Embedding []float32 `json:"embedding"`
	Index     int       `json:"index"`
}

type ListCompletion struct {
	Object string  `json:"object"`
	Data   []Model `json:"data"`
}

type EmbeddingList struct {
	Object string         `json:"object"`
	Data   []Embedding    `json:"data"`
	Model  string         `json:"model"`
	Usage  EmbeddingUsage `json:"usage,omitempty"`
}

type EmbeddingUsage struct {
	PromptTokens int `json:"prompt_tokens"`
	TotalTokens  int `json:"total_tokens"`
}

func NewError(code int, message string) ErrorResponse {
	var etype string
	switch code {
	case http.StatusBadRequest:
		etype = "invalid_request_error"
	case http.StatusNotFound:
		etype = "not_found_error"
	default:
		etype = "api_error"
	}

	return ErrorResponse{Error{Type: etype, Message: message}}
}

func toolCallId() string {
	const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
	b := make([]byte, 8)
	for i := range b {
		b[i] = letterBytes[rand.Intn(len(letterBytes))]
	}
	return "call_" + strings.ToLower(string(b))
}

// toToolCalls converts native tool calls to their OpenAI representation,
// assigning each one a random call id and JSON-encoding its arguments.
func toToolCalls(tc []api.ToolCall) []ToolCall {
	toolCalls := make([]ToolCall, len(tc))
	for i, tc := range tc {
		toolCalls[i].ID = toolCallId()
		toolCalls[i].Type = "function"
		toolCalls[i].Function.Name = tc.Function.Name
		toolCalls[i].Index = tc.Function.Index

		args, err := json.Marshal(tc.Function.Arguments)
		if err != nil {
			slog.Error("could not marshal function arguments to json", "error", err)
			continue
		}

		toolCalls[i].Function.Arguments = string(args)
	}

	return toolCalls
}

func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
	toolCalls := toToolCalls(r.Message.ToolCalls)
	return ChatCompletion{
		Id:                id,
		Object:            "chat.completion",
		Created:           r.CreatedAt.Unix(),
		Model:             r.Model,
		SystemFingerprint: "fp_ollama",
		Choices: []Choice{{
			Index:   0,
			Message: Message{Role: r.Message.Role, Content: r.Message.Content, ToolCalls: toolCalls},
			FinishReason: func(reason string) *string {
				if len(toolCalls) > 0 {
					reason = "tool_calls"
				}
				if len(reason) > 0 {
					return &reason
				}
				return nil
			}(r.DoneReason),
		}},
		Usage: Usage{
			PromptTokens:     r.PromptEvalCount,
			CompletionTokens: r.EvalCount,
			TotalTokens:      r.PromptEvalCount + r.EvalCount,
		},
	}
}

func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
	toolCalls := toToolCalls(r.Message.ToolCalls)
	return ChatCompletionChunk{
		Id:                id,
		Object:            "chat.completion.chunk",
		Created:           time.Now().Unix(),
		Model:             r.Model,
		SystemFingerprint: "fp_ollama",
		Choices: []ChunkChoice{{
			Index: 0,
			Delta: Message{Role: "assistant", Content: r.Message.Content, ToolCalls: toolCalls},
			FinishReason: func(reason string) *string {
				if len(reason) > 0 {
					return &reason
				}
				return nil
			}(r.DoneReason),
		}},
	}
}

func toCompletion(id string, r api.GenerateResponse) Completion {
	return Completion{
		Id:                id,
		Object:            "text_completion",
		Created:           r.CreatedAt.Unix(),
		Model:             r.Model,
		SystemFingerprint: "fp_ollama",
		Choices: []CompleteChunkChoice{{
			Text:  r.Response,
			Index: 0,
			FinishReason: func(reason string) *string {
				if len(reason) > 0 {
					return &reason
				}
				return nil
			}(r.DoneReason),
		}},
		Usage: Usage{
			PromptTokens:     r.PromptEvalCount,
			CompletionTokens: r.EvalCount,
			TotalTokens:      r.PromptEvalCount + r.EvalCount,
		},
	}
}

func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
	return CompletionChunk{
		Id:                id,
		Object:            "text_completion",
		Created:           time.Now().Unix(),
		Model:             r.Model,
		SystemFingerprint: "fp_ollama",
		Choices: []CompleteChunkChoice{{
			Text:  r.Response,
			Index: 0,
			FinishReason: func(reason string) *string {
				if len(reason) > 0 {
					return &reason
				}
				return nil
			}(r.DoneReason),
		}},
	}
}

func toListCompletion(r api.ListResponse) ListCompletion {
	var data []Model
	for _, m := range r.Models {
		data = append(data, Model{
			Id:      m.Name,
			Object:  "model",
			Created: m.ModifiedAt.Unix(),
			OwnedBy: model.ParseName(m.Name).Namespace,
		})
	}

	return ListCompletion{
		Object: "list",
		Data:   data,
	}
}

func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
	if r.Embeddings != nil {
		var data []Embedding
		for i, e := range r.Embeddings {
			data = append(data, Embedding{
				Object:    "embedding",
				Embedding: e,
				Index:     i,
			})
		}

		return EmbeddingList{
			Object: "list",
			Data:   data,
			Model:  model,
			Usage: EmbeddingUsage{
				PromptTokens: r.PromptEvalCount,
				TotalTokens:  r.PromptEvalCount,
			},
		}
	}

	return EmbeddingList{}
}

func toModel(r api.ShowResponse, m string) Model {
	return Model{
		Id:      m,
		Object:  "model",
		Created: r.ModifiedAt.Unix(),
		OwnedBy: model.ParseName(m).Namespace,
	}
}
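
// An illustrative OpenAI-style request body that fromChatRequest accepts.
// The model name and message text are invented for the example; image parts
// must be base64 data URLs, as validated below.
//
//	{
//	  "model": "llama3",
//	  "stream": true,
//	  "messages": [
//	    {"role": "user", "content": "describe this image"},
//	    {"role": "user", "content": [
//	      {"type": "text", "text": "what is in this picture?"},
//	      {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}
//	    ]}
//	  ]
//	}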

// fromChatRequest translates an OpenAI chat completion request into a native
// api.ChatRequest, flattening multi-part message content and decoding any
// base64 image attachments.
func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
	var messages []api.Message
	for _, msg := range r.Messages {
		switch content := msg.Content.(type) {
		case string:
			messages = append(messages, api.Message{Role: msg.Role, Content: content})
		case []any:
			for _, c := range content {
				data, ok := c.(map[string]any)
				if !ok {
					return nil, errors.New("invalid message format")
				}
				switch data["type"] {
				case "text":
					text, ok := data["text"].(string)
					if !ok {
						return nil, errors.New("invalid message format")
					}
					messages = append(messages, api.Message{Role: msg.Role, Content: text})
				case "image_url":
					var url string
					if urlMap, ok := data["image_url"].(map[string]any); ok {
						if url, ok = urlMap["url"].(string); !ok {
							return nil, errors.New("invalid message format")
						}
					} else {
						if url, ok = data["image_url"].(string); !ok {
							return nil, errors.New("invalid message format")
						}
					}

					types := []string{"jpeg", "jpg", "png"}
					valid := false
					for _, t := range types {
						prefix := "data:image/" + t + ";base64,"
						if strings.HasPrefix(url, prefix) {
							url = strings.TrimPrefix(url, prefix)
							valid = true
							break
						}
					}

					if !valid {
						return nil, errors.New("invalid image input")
					}

					img, err := base64.StdEncoding.DecodeString(url)
					if err != nil {
						return nil, errors.New("invalid message format")
					}

					messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
				default:
					return nil, errors.New("invalid message format")
				}
			}
		default:
			if msg.ToolCalls == nil {
				return nil, fmt.Errorf("invalid message content type: %T", content)
			}

			toolCalls := make([]api.ToolCall, len(msg.ToolCalls))
			for i, tc := range msg.ToolCalls {
				toolCalls[i].Function.Name = tc.Function.Name
				err := json.Unmarshal([]byte(tc.Function.Arguments), &toolCalls[i].Function.Arguments)
				if err != nil {
					return nil, errors.New("invalid tool call arguments")
				}
			}
			messages = append(messages, api.Message{Role: msg.Role, ToolCalls: toolCalls})
		}
	}

	options := make(map[string]interface{})

	switch stop := r.Stop.(type) {
	case string:
		options["stop"] = []string{stop}
	case []any:
		var stops []string
		for _, s := range stop {
			if str, ok := s.(string); ok {
				stops = append(stops, str)
			}
		}
		options["stop"] = stops
	}

	if r.MaxTokens != nil {
		options["num_predict"] = *r.MaxTokens
	}

	if r.Temperature != nil {
		options["temperature"] = *r.Temperature
	} else {
		options["temperature"] = 1.0
	}

	if r.Seed != nil {
		options["seed"] = *r.Seed
	}

	if r.FrequencyPenalty != nil {
		options["frequency_penalty"] = *r.FrequencyPenalty
	}

	if r.PresencePenalty != nil {
		options["presence_penalty"] = *r.PresencePenalty
	}

	if r.TopP != nil {
		options["top_p"] = *r.TopP
	} else {
		options["top_p"] = 1.0
	}

	var format string
	if r.ResponseFormat != nil && r.ResponseFormat.Type == "json_object" {
		format = "json"
	}

	return &api.ChatRequest{
		Model:    r.Model,
		Messages: messages,
		Format:   format,
		Options:  options,
		Stream:   &r.Stream,
		Tools:    r.Tools,
	}, nil
}

// fromCompleteRequest translates an OpenAI legacy completion request into a
// native api.GenerateRequest.
func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
	options := make(map[string]any)

	switch stop := r.Stop.(type) {
	case string:
		options["stop"] = []string{stop}
	case []any:
		var stops []string
		for _, s := range stop {
			if str, ok := s.(string); ok {
				stops = append(stops, str)
			} else {
				return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", s)
			}
		}
		options["stop"] = stops
	}

	if r.MaxTokens != nil {
		options["num_predict"] = *r.MaxTokens
	}

	if r.Temperature != nil {
		options["temperature"] = *r.Temperature
	} else {
		options["temperature"] = 1.0
	}

	if r.Seed != nil {
		options["seed"] = *r.Seed
	}

	options["frequency_penalty"] = r.FrequencyPenalty

	options["presence_penalty"] = r.PresencePenalty

	if r.TopP != 0.0 {
		options["top_p"] = r.TopP
	} else {
		options["top_p"] = 1.0
	}

	return api.GenerateRequest{
		Model:   r.Model,
		Prompt:  r.Prompt,
		Options: options,
		Stream:  &r.Stream,
		Suffix:  r.Suffix,
	}, nil
}

// BaseWriter wraps gin's ResponseWriter so the writers below can translate
// native API responses into OpenAI-compatible ones before they are sent.
type BaseWriter struct {
	gin.ResponseWriter
}

type ChatWriter struct {
	stream bool
	id     string
	BaseWriter
}

type CompleteWriter struct {
	stream bool
	id     string
	BaseWriter
}

type ListWriter struct {
	BaseWriter
}

type RetrieveWriter struct {
	BaseWriter
	model string
}

type EmbedWriter struct {
	BaseWriter
	model string
}

func (w *BaseWriter) writeError(data []byte) (int, error) {
	var serr api.StatusError
	err := json.Unmarshal(data, &serr)
	if err != nil {
		return 0, err
	}

	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w.ResponseWriter).Encode(NewError(http.StatusInternalServerError, serr.Error()))
	if err != nil {
		return 0, err
	}

	return len(data), nil
}
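
// When streaming, the writers below re-frame each native response as a
// server-sent event: every chunk is written as "data: {json}\n\n" and the
// final chunk is followed by "data: [DONE]\n\n". An abbreviated illustration
// of what a client receives (the payload values are invented and several
// fields are omitted for brevity):
//
//	data: {"id":"chatcmpl-42","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"role":"assistant","content":"Hi"}}]}
//
//	data: [DONE]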

func (w *ChatWriter) writeResponse(data []byte) (int, error) {
	var chatResponse api.ChatResponse
	err := json.Unmarshal(data, &chatResponse)
	if err != nil {
		return 0, err
	}

	// chat chunk
	if w.stream {
		d, err := json.Marshal(toChunk(w.id, chatResponse))
		if err != nil {
			return 0, err
		}

		w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
		_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
		if err != nil {
			return 0, err
		}

		if chatResponse.Done {
			_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
			if err != nil {
				return 0, err
			}
		}

		return len(data), nil
	}

	// chat completion
	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w.ResponseWriter).Encode(toChatCompletion(w.id, chatResponse))
	if err != nil {
		return 0, err
	}

	return len(data), nil
}

func (w *ChatWriter) Write(data []byte) (int, error) {
	code := w.ResponseWriter.Status()
	if code != http.StatusOK {
		return w.writeError(data)
	}

	return w.writeResponse(data)
}

func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
	var generateResponse api.GenerateResponse
	err := json.Unmarshal(data, &generateResponse)
	if err != nil {
		return 0, err
	}

	// completion chunk
	if w.stream {
		d, err := json.Marshal(toCompleteChunk(w.id, generateResponse))
		if err != nil {
			return 0, err
		}

		w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
		_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
		if err != nil {
			return 0, err
		}

		if generateResponse.Done {
			_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
			if err != nil {
				return 0, err
			}
		}

		return len(data), nil
	}

	// completion
	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w.ResponseWriter).Encode(toCompletion(w.id, generateResponse))
	if err != nil {
		return 0, err
	}

	return len(data), nil
}

func (w *CompleteWriter) Write(data []byte) (int, error) {
	code := w.ResponseWriter.Status()
	if code != http.StatusOK {
		return w.writeError(data)
	}

	return w.writeResponse(data)
}

func (w *ListWriter) writeResponse(data []byte) (int, error) {
	var listResponse api.ListResponse
	err := json.Unmarshal(data, &listResponse)
	if err != nil {
		return 0, err
	}

	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w.ResponseWriter).Encode(toListCompletion(listResponse))
	if err != nil {
		return 0, err
	}

	return len(data), nil
}

func (w *ListWriter) Write(data []byte) (int, error) {
	code := w.ResponseWriter.Status()
	if code != http.StatusOK {
		return w.writeError(data)
	}

	return w.writeResponse(data)
}

func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
	var showResponse api.ShowResponse
	err := json.Unmarshal(data, &showResponse)
	if err != nil {
		return 0, err
	}

	// retrieve completion
	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w.ResponseWriter).Encode(toModel(showResponse, w.model))
	if err != nil {
		return 0, err
	}

	return len(data), nil
}

func (w *RetrieveWriter) Write(data []byte) (int, error) {
	code := w.ResponseWriter.Status()
	if code != http.StatusOK {
		return w.writeError(data)
	}

	return w.writeResponse(data)
}

func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
	var embedResponse api.EmbedResponse
	err := json.Unmarshal(data, &embedResponse)
	if err != nil {
		return 0, err
	}

	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse))
	if err != nil {
		return 0, err
	}

	return len(data), nil
}

func (w *EmbedWriter) Write(data []byte) (int, error) {
	code := w.ResponseWriter.Status()
	if code != http.StatusOK {
		return w.writeError(data)
	}

	return w.writeResponse(data)
}
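
// Each middleware below follows the same pattern: decode the OpenAI-shaped
// request, re-encode it as the equivalent native request into c.Request.Body,
// and swap c.Writer for one of the translating writers above. For example,
// EmbeddingsMiddleware rewrites a body like the following (values invented
// for illustration)
//
//	{"model": "all-minilm", "input": "why is the sky blue?"}
//
// into an api.EmbedRequest carrying the same model and input, so the native
// embeddings handler can serve it unchanged.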

// ListMiddleware converts the native model list response into an OpenAI-style
// model list.
func ListMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		w := &ListWriter{
			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
		}

		c.Writer = w

		c.Next()
	}
}

// RetrieveMiddleware rewrites the request as an api.ShowRequest for the model
// named in the path and converts the response into an OpenAI model object.
func RetrieveMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		var b bytes.Buffer
		if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
			return
		}

		c.Request.Body = io.NopCloser(&b)

		// response writer
		w := &RetrieveWriter{
			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
			model:      c.Param("model"),
		}

		c.Writer = w

		c.Next()
	}
}

// CompletionsMiddleware translates OpenAI legacy completion requests and
// responses for the native generate handler.
func CompletionsMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		var req CompletionRequest
		err := c.ShouldBindJSON(&req)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
			return
		}

		var b bytes.Buffer
		genReq, err := fromCompleteRequest(req)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
			return
		}

		if err := json.NewEncoder(&b).Encode(genReq); err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
			return
		}

		c.Request.Body = io.NopCloser(&b)

		w := &CompleteWriter{
			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
			stream:     req.Stream,
			id:         fmt.Sprintf("cmpl-%d", rand.Intn(999)),
		}

		c.Writer = w

		c.Next()
	}
}

// EmbeddingsMiddleware translates OpenAI embedding requests and responses for
// the native embed handler.
func EmbeddingsMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		var req EmbedRequest
		err := c.ShouldBindJSON(&req)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
			return
		}

		if req.Input == "" {
			req.Input = []string{""}
		}

		if req.Input == nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
			return
		}

		if v, ok := req.Input.([]any); ok && len(v) == 0 {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
			return
		}

		var b bytes.Buffer
		if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input}); err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
			return
		}

		c.Request.Body = io.NopCloser(&b)

		w := &EmbedWriter{
			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
			model:      req.Model,
		}

		c.Writer = w

		c.Next()
	}
}

// ChatMiddleware translates OpenAI chat completion requests and responses for
// the native chat handler.
func ChatMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		var req ChatCompletionRequest
		err := c.ShouldBindJSON(&req)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
			return
		}

		if len(req.Messages) == 0 {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "[] is too short - 'messages'"))
			return
		}

		var b bytes.Buffer

		chatReq, err := fromChatRequest(req)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
			return
		}

		if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
			return
		}

		c.Request.Body = io.NopCloser(&b)

		w := &ChatWriter{
			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
			stream:     req.Stream,
			id:         fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
		}

		c.Writer = w

		c.Next()
	}
}