openai.go 20 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851
  1. // openai package provides middleware for partial compatibility with the OpenAI REST API
  2. package openai
  3. import (
  4. "bytes"
  5. "encoding/base64"
  6. "encoding/json"
  7. "fmt"
  8. "io"
  9. "math/rand"
  10. "net/http"
  11. "strings"
  12. "time"
  13. "github.com/gin-gonic/gin"
  14. "github.com/ollama/ollama/api"
  15. "github.com/ollama/ollama/types/model"
  16. )
// Error is the error object embedded in an OpenAI-style error response.
type Error struct {
	Message string      `json:"message"`
	Type    string      `json:"type"`
	Param   interface{} `json:"param"`
	Code    *string     `json:"code"`
}

// ErrorResponse is the top-level error envelope returned to clients.
type ErrorResponse struct {
	Error Error `json:"error"`
}

// Message is a single chat message. Content is either a plain string or a
// []any of typed content parts ("text" / "image_url"); see fromChatRequest.
type Message struct {
	Role    string `json:"role"`
	Content any    `json:"content"`
}

// Choice is one completed choice in a non-streaming chat response.
type Choice struct {
	Index        int     `json:"index"`
	Message      Message `json:"message"`
	FinishReason *string `json:"finish_reason"`
}

// ChunkChoice is one incremental delta in a streaming chat response.
type ChunkChoice struct {
	Index        int     `json:"index"`
	Delta        Message `json:"delta"`
	FinishReason *string `json:"finish_reason"`
}

// CompleteChunkChoice is one choice in a text completion response
// (used for both streaming chunks and the final completion).
type CompleteChunkChoice struct {
	Text         string  `json:"text"`
	Index        int     `json:"index"`
	FinishReason *string `json:"finish_reason"`
}

// Usage reports token accounting for a request/response pair.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

// ResponseFormat selects the response format; only "json_object" is
// recognized by fromChatRequest.
type ResponseFormat struct {
	Type string `json:"type"`
}

// EmbedRequest is the OpenAI-style embeddings request body; Input may be a
// string or a list (validated in EmbeddingsMiddleware).
type EmbedRequest struct {
	Input any    `json:"input"`
	Model string `json:"model"`
}
  57. type ChatCompletionRequest struct {
  58. Model string `json:"model"`
  59. Messages []Message `json:"messages"`
  60. Stream bool `json:"stream"`
  61. MaxTokens *int `json:"max_tokens"`
  62. Seed *int `json:"seed"`
  63. Stop any `json:"stop"`
  64. Temperature *float64 `json:"temperature"`
  65. FrequencyPenalty *float64 `json:"frequency_penalty"`
  66. PresencePenalty *float64 `json:"presence_penalty_penalty"`
  67. TopP *float64 `json:"top_p"`
  68. ResponseFormat *ResponseFormat `json:"response_format"`
  69. }
// ChatCompletion is a non-streaming chat.completion response.
type ChatCompletion struct {
	Id                string   `json:"id"`
	Object            string   `json:"object"`
	Created           int64    `json:"created"`
	Model             string   `json:"model"`
	SystemFingerprint string   `json:"system_fingerprint"`
	Choices           []Choice `json:"choices"`
	// NOTE(review): omitempty has no effect on a non-pointer struct field;
	// Usage is always emitted.
	Usage Usage `json:"usage,omitempty"`
}

// ChatCompletionChunk is one SSE chunk of a streaming chat response.
type ChatCompletionChunk struct {
	Id                string        `json:"id"`
	Object            string        `json:"object"`
	Created           int64         `json:"created"`
	Model             string        `json:"model"`
	SystemFingerprint string        `json:"system_fingerprint"`
	Choices           []ChunkChoice `json:"choices"`
}

// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
// CompletionRequest mirrors the OpenAI /v1/completions request body.
// Prompt currently supports only a single string (see TODO above).
type CompletionRequest struct {
	Model            string   `json:"model"`
	Prompt           string   `json:"prompt"`
	FrequencyPenalty float32  `json:"frequency_penalty"`
	MaxTokens        *int     `json:"max_tokens"`
	PresencePenalty  float32  `json:"presence_penalty"`
	Seed             *int     `json:"seed"`
	Stop             any      `json:"stop"`
	Stream           bool     `json:"stream"`
	Temperature      *float32 `json:"temperature"`
	TopP             float32  `json:"top_p"`
}

// Completion is a non-streaming text_completion response.
type Completion struct {
	Id                string                `json:"id"`
	Object            string                `json:"object"`
	Created           int64                 `json:"created"`
	Model             string                `json:"model"`
	SystemFingerprint string                `json:"system_fingerprint"`
	Choices           []CompleteChunkChoice `json:"choices"`
	// NOTE(review): omitempty has no effect on a non-pointer struct field.
	Usage Usage `json:"usage,omitempty"`
}

// CompletionChunk is one SSE chunk of a streaming text completion.
type CompletionChunk struct {
	Id                string                `json:"id"`
	Object            string                `json:"object"`
	Created           int64                 `json:"created"`
	Choices           []CompleteChunkChoice `json:"choices"`
	Model             string                `json:"model"`
	SystemFingerprint string                `json:"system_fingerprint"`
}

// Model describes one model in list/retrieve responses.
type Model struct {
	Id      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	OwnedBy string `json:"owned_by"`
}

// Embedding is one embedding vector in an embeddings response.
type Embedding struct {
	Object    string    `json:"object"`
	Embedding []float32 `json:"embedding"`
	Index     int       `json:"index"`
}

// ListCompletion is the envelope for the model-list endpoint.
type ListCompletion struct {
	Object string  `json:"object"`
	Data   []Model `json:"data"`
}

// EmbeddingList is the envelope for the embeddings endpoint.
type EmbeddingList struct {
	Object string      `json:"object"`
	Data   []Embedding `json:"data"`
	Model  string      `json:"model"`
}
  137. func NewError(code int, message string) ErrorResponse {
  138. var etype string
  139. switch code {
  140. case http.StatusBadRequest:
  141. etype = "invalid_request_error"
  142. case http.StatusNotFound:
  143. etype = "not_found_error"
  144. default:
  145. etype = "api_error"
  146. }
  147. return ErrorResponse{Error{Type: etype, Message: message}}
  148. }
  149. func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
  150. return ChatCompletion{
  151. Id: id,
  152. Object: "chat.completion",
  153. Created: r.CreatedAt.Unix(),
  154. Model: r.Model,
  155. SystemFingerprint: "fp_ollama",
  156. Choices: []Choice{{
  157. Index: 0,
  158. Message: Message{Role: r.Message.Role, Content: r.Message.Content},
  159. FinishReason: func(reason string) *string {
  160. if len(reason) > 0 {
  161. return &reason
  162. }
  163. return nil
  164. }(r.DoneReason),
  165. }},
  166. Usage: Usage{
  167. // TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
  168. PromptTokens: r.PromptEvalCount,
  169. CompletionTokens: r.EvalCount,
  170. TotalTokens: r.PromptEvalCount + r.EvalCount,
  171. },
  172. }
  173. }
  174. func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
  175. return ChatCompletionChunk{
  176. Id: id,
  177. Object: "chat.completion.chunk",
  178. Created: time.Now().Unix(),
  179. Model: r.Model,
  180. SystemFingerprint: "fp_ollama",
  181. Choices: []ChunkChoice{{
  182. Index: 0,
  183. Delta: Message{Role: "assistant", Content: r.Message.Content},
  184. FinishReason: func(reason string) *string {
  185. if len(reason) > 0 {
  186. return &reason
  187. }
  188. return nil
  189. }(r.DoneReason),
  190. }},
  191. }
  192. }
  193. func toCompletion(id string, r api.GenerateResponse) Completion {
  194. return Completion{
  195. Id: id,
  196. Object: "text_completion",
  197. Created: r.CreatedAt.Unix(),
  198. Model: r.Model,
  199. SystemFingerprint: "fp_ollama",
  200. Choices: []CompleteChunkChoice{{
  201. Text: r.Response,
  202. Index: 0,
  203. FinishReason: func(reason string) *string {
  204. if len(reason) > 0 {
  205. return &reason
  206. }
  207. return nil
  208. }(r.DoneReason),
  209. }},
  210. Usage: Usage{
  211. // TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
  212. PromptTokens: r.PromptEvalCount,
  213. CompletionTokens: r.EvalCount,
  214. TotalTokens: r.PromptEvalCount + r.EvalCount,
  215. },
  216. }
  217. }
  218. func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
  219. return CompletionChunk{
  220. Id: id,
  221. Object: "text_completion",
  222. Created: time.Now().Unix(),
  223. Model: r.Model,
  224. SystemFingerprint: "fp_ollama",
  225. Choices: []CompleteChunkChoice{{
  226. Text: r.Response,
  227. Index: 0,
  228. FinishReason: func(reason string) *string {
  229. if len(reason) > 0 {
  230. return &reason
  231. }
  232. return nil
  233. }(r.DoneReason),
  234. }},
  235. }
  236. }
  237. func toListCompletion(r api.ListResponse) ListCompletion {
  238. var data []Model
  239. for _, m := range r.Models {
  240. data = append(data, Model{
  241. Id: m.Name,
  242. Object: "model",
  243. Created: m.ModifiedAt.Unix(),
  244. OwnedBy: model.ParseName(m.Name).Namespace,
  245. })
  246. }
  247. return ListCompletion{
  248. Object: "list",
  249. Data: data,
  250. }
  251. }
  252. func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
  253. if r.Embeddings != nil {
  254. var data []Embedding
  255. for i, e := range r.Embeddings {
  256. data = append(data, Embedding{
  257. Object: "embedding",
  258. Embedding: e,
  259. Index: i,
  260. })
  261. }
  262. return EmbeddingList{
  263. Object: "list",
  264. Data: data,
  265. Model: model,
  266. }
  267. }
  268. return EmbeddingList{}
  269. }
  270. func toModel(r api.ShowResponse, m string) Model {
  271. return Model{
  272. Id: m,
  273. Object: "model",
  274. Created: r.ModifiedAt.Unix(),
  275. OwnedBy: model.ParseName(m).Namespace,
  276. }
  277. }
// fromChatRequest converts an OpenAI-style ChatCompletionRequest into the
// native api.ChatRequest consumed by the Ollama chat endpoint.
//
// Message content may be a plain string or a list of typed parts; only
// "text" and "image_url" parts are accepted, and images must be base64
// data URIs (jpeg/jpg/png) — remote URLs are rejected.
func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
	var messages []api.Message
	for _, msg := range r.Messages {
		switch content := msg.Content.(type) {
		case string:
			messages = append(messages, api.Message{Role: msg.Role, Content: content})
		case []any:
			message := api.Message{Role: msg.Role}
			for _, c := range content {
				data, ok := c.(map[string]any)
				if !ok {
					return nil, fmt.Errorf("invalid message format")
				}
				switch data["type"] {
				case "text":
					text, ok := data["text"].(string)
					if !ok {
						return nil, fmt.Errorf("invalid message format")
					}
					// NOTE(review): multiple "text" parts overwrite each
					// other; only the last text part is kept.
					message.Content = text
				case "image_url":
					// The image_url part may be either {"url": "..."} or a
					// bare string.
					var url string
					if urlMap, ok := data["image_url"].(map[string]any); ok {
						if url, ok = urlMap["url"].(string); !ok {
							return nil, fmt.Errorf("invalid message format")
						}
					} else {
						if url, ok = data["image_url"].(string); !ok {
							return nil, fmt.Errorf("invalid message format")
						}
					}
					// Strip the data-URI prefix; anything that is not an
					// inline base64 image is rejected.
					types := []string{"jpeg", "jpg", "png"}
					valid := false
					for _, t := range types {
						prefix := "data:image/" + t + ";base64,"
						if strings.HasPrefix(url, prefix) {
							url = strings.TrimPrefix(url, prefix)
							valid = true
							break
						}
					}
					if !valid {
						return nil, fmt.Errorf("invalid image input")
					}
					img, err := base64.StdEncoding.DecodeString(url)
					if err != nil {
						return nil, fmt.Errorf("invalid message format")
					}
					message.Images = append(message.Images, img)
				default:
					return nil, fmt.Errorf("invalid message format")
				}
			}
			messages = append(messages, message)
		default:
			return nil, fmt.Errorf("invalid message content type: %T", content)
		}
	}

	options := make(map[string]interface{})

	// "stop" may be a single string or a list; non-string list entries are
	// silently dropped here (contrast with fromCompleteRequest, which errors).
	switch stop := r.Stop.(type) {
	case string:
		options["stop"] = []string{stop}
	case []any:
		var stops []string
		for _, s := range stop {
			if str, ok := s.(string); ok {
				stops = append(stops, str)
			}
		}
		options["stop"] = stops
	}

	if r.MaxTokens != nil {
		options["num_predict"] = *r.MaxTokens
	}

	// NOTE(review): temperature and penalties are scaled by 2.0 —
	// presumably mapping between the OpenAI and Ollama ranges; confirm the
	// intended target range before changing.
	if r.Temperature != nil {
		options["temperature"] = *r.Temperature * 2.0
	} else {
		options["temperature"] = 1.0
	}

	if r.Seed != nil {
		options["seed"] = *r.Seed
	}

	if r.FrequencyPenalty != nil {
		options["frequency_penalty"] = *r.FrequencyPenalty * 2.0
	}

	if r.PresencePenalty != nil {
		options["presence_penalty"] = *r.PresencePenalty * 2.0
	}

	if r.TopP != nil {
		options["top_p"] = *r.TopP
	} else {
		options["top_p"] = 1.0
	}

	// Only "json_object" maps to Ollama's "json" format; anything else is
	// passed through as free-form text.
	var format string
	if r.ResponseFormat != nil && r.ResponseFormat.Type == "json_object" {
		format = "json"
	}

	return &api.ChatRequest{
		Model:    r.Model,
		Messages: messages,
		Format:   format,
		Options:  options,
		Stream:   &r.Stream,
	}, nil
}
  383. func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
  384. options := make(map[string]any)
  385. switch stop := r.Stop.(type) {
  386. case string:
  387. options["stop"] = []string{stop}
  388. case []any:
  389. var stops []string
  390. for _, s := range stop {
  391. if str, ok := s.(string); ok {
  392. stops = append(stops, str)
  393. } else {
  394. return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", s)
  395. }
  396. }
  397. options["stop"] = stops
  398. }
  399. if r.MaxTokens != nil {
  400. options["num_predict"] = *r.MaxTokens
  401. }
  402. if r.Temperature != nil {
  403. options["temperature"] = *r.Temperature * 2.0
  404. } else {
  405. options["temperature"] = 1.0
  406. }
  407. if r.Seed != nil {
  408. options["seed"] = *r.Seed
  409. }
  410. options["frequency_penalty"] = r.FrequencyPenalty * 2.0
  411. options["presence_penalty"] = r.PresencePenalty * 2.0
  412. if r.TopP != 0.0 {
  413. options["top_p"] = r.TopP
  414. } else {
  415. options["top_p"] = 1.0
  416. }
  417. return api.GenerateRequest{
  418. Model: r.Model,
  419. Prompt: r.Prompt,
  420. Options: options,
  421. Stream: &r.Stream,
  422. }, nil
  423. }
// BaseWriter wraps gin's ResponseWriter and provides the shared error
// translation used by all OpenAI-compat response writers below.
type BaseWriter struct {
	gin.ResponseWriter
}

// ChatWriter rewrites native chat responses into OpenAI chat completions.
type ChatWriter struct {
	stream bool   // true when the client requested SSE streaming
	id     string // response id, e.g. "chatcmpl-123"
	BaseWriter
}

// CompleteWriter rewrites native generate responses into OpenAI completions.
type CompleteWriter struct {
	stream bool   // true when the client requested SSE streaming
	id     string // response id, e.g. "cmpl-123"
	BaseWriter
}

// ListWriter rewrites native list responses into the OpenAI model list.
type ListWriter struct {
	BaseWriter
}

// RetrieveWriter rewrites native show responses into a single OpenAI model.
type RetrieveWriter struct {
	BaseWriter
	model string // model name echoed back in the response
}

// EmbedWriter rewrites native embed responses into OpenAI embeddings.
type EmbedWriter struct {
	BaseWriter
	model string // model name echoed back in the response
}
  448. func (w *BaseWriter) writeError(code int, data []byte) (int, error) {
  449. var serr api.StatusError
  450. err := json.Unmarshal(data, &serr)
  451. if err != nil {
  452. return 0, err
  453. }
  454. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  455. err = json.NewEncoder(w.ResponseWriter).Encode(NewError(http.StatusInternalServerError, serr.Error()))
  456. if err != nil {
  457. return 0, err
  458. }
  459. return len(data), nil
  460. }
  461. func (w *ChatWriter) writeResponse(data []byte) (int, error) {
  462. var chatResponse api.ChatResponse
  463. err := json.Unmarshal(data, &chatResponse)
  464. if err != nil {
  465. return 0, err
  466. }
  467. // chat chunk
  468. if w.stream {
  469. d, err := json.Marshal(toChunk(w.id, chatResponse))
  470. if err != nil {
  471. return 0, err
  472. }
  473. w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
  474. _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
  475. if err != nil {
  476. return 0, err
  477. }
  478. if chatResponse.Done {
  479. _, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
  480. if err != nil {
  481. return 0, err
  482. }
  483. }
  484. return len(data), nil
  485. }
  486. // chat completion
  487. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  488. err = json.NewEncoder(w.ResponseWriter).Encode(toChatCompletion(w.id, chatResponse))
  489. if err != nil {
  490. return 0, err
  491. }
  492. return len(data), nil
  493. }
  494. func (w *ChatWriter) Write(data []byte) (int, error) {
  495. code := w.ResponseWriter.Status()
  496. if code != http.StatusOK {
  497. return w.writeError(code, data)
  498. }
  499. return w.writeResponse(data)
  500. }
  501. func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
  502. var generateResponse api.GenerateResponse
  503. err := json.Unmarshal(data, &generateResponse)
  504. if err != nil {
  505. return 0, err
  506. }
  507. // completion chunk
  508. if w.stream {
  509. d, err := json.Marshal(toCompleteChunk(w.id, generateResponse))
  510. if err != nil {
  511. return 0, err
  512. }
  513. w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
  514. _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
  515. if err != nil {
  516. return 0, err
  517. }
  518. if generateResponse.Done {
  519. _, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
  520. if err != nil {
  521. return 0, err
  522. }
  523. }
  524. return len(data), nil
  525. }
  526. // completion
  527. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  528. err = json.NewEncoder(w.ResponseWriter).Encode(toCompletion(w.id, generateResponse))
  529. if err != nil {
  530. return 0, err
  531. }
  532. return len(data), nil
  533. }
  534. func (w *CompleteWriter) Write(data []byte) (int, error) {
  535. code := w.ResponseWriter.Status()
  536. if code != http.StatusOK {
  537. return w.writeError(code, data)
  538. }
  539. return w.writeResponse(data)
  540. }
  541. func (w *ListWriter) writeResponse(data []byte) (int, error) {
  542. var listResponse api.ListResponse
  543. err := json.Unmarshal(data, &listResponse)
  544. if err != nil {
  545. return 0, err
  546. }
  547. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  548. err = json.NewEncoder(w.ResponseWriter).Encode(toListCompletion(listResponse))
  549. if err != nil {
  550. return 0, err
  551. }
  552. return len(data), nil
  553. }
  554. func (w *ListWriter) Write(data []byte) (int, error) {
  555. code := w.ResponseWriter.Status()
  556. if code != http.StatusOK {
  557. return w.writeError(code, data)
  558. }
  559. return w.writeResponse(data)
  560. }
  561. func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
  562. var showResponse api.ShowResponse
  563. err := json.Unmarshal(data, &showResponse)
  564. if err != nil {
  565. return 0, err
  566. }
  567. // retrieve completion
  568. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  569. err = json.NewEncoder(w.ResponseWriter).Encode(toModel(showResponse, w.model))
  570. if err != nil {
  571. return 0, err
  572. }
  573. return len(data), nil
  574. }
  575. func (w *RetrieveWriter) Write(data []byte) (int, error) {
  576. code := w.ResponseWriter.Status()
  577. if code != http.StatusOK {
  578. return w.writeError(code, data)
  579. }
  580. return w.writeResponse(data)
  581. }
  582. func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
  583. var embedResponse api.EmbedResponse
  584. err := json.Unmarshal(data, &embedResponse)
  585. if err != nil {
  586. return 0, err
  587. }
  588. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  589. err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse))
  590. if err != nil {
  591. return 0, err
  592. }
  593. return len(data), nil
  594. }
  595. func (w *EmbedWriter) Write(data []byte) (int, error) {
  596. code := w.ResponseWriter.Status()
  597. if code != http.StatusOK {
  598. return w.writeError(code, data)
  599. }
  600. return w.writeResponse(data)
  601. }
  602. func ListMiddleware() gin.HandlerFunc {
  603. return func(c *gin.Context) {
  604. w := &ListWriter{
  605. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  606. }
  607. c.Writer = w
  608. c.Next()
  609. }
  610. }
  611. func RetrieveMiddleware() gin.HandlerFunc {
  612. return func(c *gin.Context) {
  613. var b bytes.Buffer
  614. if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
  615. c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
  616. return
  617. }
  618. c.Request.Body = io.NopCloser(&b)
  619. // response writer
  620. w := &RetrieveWriter{
  621. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  622. model: c.Param("model"),
  623. }
  624. c.Writer = w
  625. c.Next()
  626. }
  627. }
  628. func CompletionsMiddleware() gin.HandlerFunc {
  629. return func(c *gin.Context) {
  630. var req CompletionRequest
  631. err := c.ShouldBindJSON(&req)
  632. if err != nil {
  633. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
  634. return
  635. }
  636. var b bytes.Buffer
  637. genReq, err := fromCompleteRequest(req)
  638. if err != nil {
  639. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
  640. return
  641. }
  642. if err := json.NewEncoder(&b).Encode(genReq); err != nil {
  643. c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
  644. return
  645. }
  646. c.Request.Body = io.NopCloser(&b)
  647. w := &CompleteWriter{
  648. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  649. stream: req.Stream,
  650. id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
  651. }
  652. c.Writer = w
  653. c.Next()
  654. }
  655. }
  656. func EmbeddingsMiddleware() gin.HandlerFunc {
  657. return func(c *gin.Context) {
  658. var req EmbedRequest
  659. err := c.ShouldBindJSON(&req)
  660. if err != nil {
  661. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
  662. return
  663. }
  664. if req.Input == "" {
  665. req.Input = []string{""}
  666. }
  667. if req.Input == nil {
  668. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
  669. return
  670. }
  671. if v, ok := req.Input.([]any); ok && len(v) == 0 {
  672. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
  673. return
  674. }
  675. var b bytes.Buffer
  676. if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input}); err != nil {
  677. c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
  678. return
  679. }
  680. c.Request.Body = io.NopCloser(&b)
  681. w := &EmbedWriter{
  682. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  683. model: req.Model,
  684. }
  685. c.Writer = w
  686. c.Next()
  687. }
  688. }
  689. func ChatMiddleware() gin.HandlerFunc {
  690. return func(c *gin.Context) {
  691. var req ChatCompletionRequest
  692. err := c.ShouldBindJSON(&req)
  693. if err != nil {
  694. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
  695. return
  696. }
  697. if len(req.Messages) == 0 {
  698. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "[] is too short - 'messages'"))
  699. return
  700. }
  701. var b bytes.Buffer
  702. chatReq, err := fromChatRequest(req)
  703. if err != nil {
  704. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
  705. }
  706. if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
  707. c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
  708. return
  709. }
  710. c.Request.Body = io.NopCloser(&b)
  711. w := &ChatWriter{
  712. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  713. stream: req.Stream,
  714. id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
  715. }
  716. c.Writer = w
  717. c.Next()
  718. }
  719. }