openai.go 24 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978
  1. // openai package provides middleware for partial compatibility with the OpenAI REST API
  2. package openai
  3. import (
  4. "bytes"
  5. "encoding/base64"
  6. "encoding/json"
  7. "errors"
  8. "fmt"
  9. "io"
  10. "log/slog"
  11. "math/rand"
  12. "net/http"
  13. "strings"
  14. "time"
  15. "github.com/gin-gonic/gin"
  16. "github.com/ollama/ollama/api"
  17. "github.com/ollama/ollama/types/model"
  18. )
// Error is the OpenAI-compatible error payload returned to clients.
type Error struct {
	Message string      `json:"message"`
	Type    string      `json:"type"`
	Param   interface{} `json:"param"`
	Code    *string     `json:"code"`
}

// ErrorResponse is the envelope OpenAI clients expect around an Error.
type ErrorResponse struct {
	Error Error `json:"error"`
}
// Message is a single chat message. Content is `any` because OpenAI
// accepts either a plain string or a list of structured content parts.
type Message struct {
	Role      string     `json:"role"`
	Content   any        `json:"content"`
	ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}

// Choice is one (non-streaming) chat completion choice.
type Choice struct {
	Index        int     `json:"index"`
	Message      Message `json:"message"`
	FinishReason *string `json:"finish_reason"`
}

// ChunkChoice is one streaming chat completion choice; Delta carries the
// incremental message content.
type ChunkChoice struct {
	Index        int     `json:"index"`
	Delta        Message `json:"delta"`
	FinishReason *string `json:"finish_reason"`
}

// CompleteChunkChoice is one text-completion choice (used for both
// streaming and non-streaming completions).
type CompleteChunkChoice struct {
	Text         string  `json:"text"`
	Index        int     `json:"index"`
	FinishReason *string `json:"finish_reason"`
}

// Usage reports token accounting in OpenAI's field names.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}
// ChunkUsage is an alias for Usage with the ability to marshal a marker
// value as null. This is to allow omitting the field in chunks when usage
// isn't requested, and otherwise return null on non-final chunks when it
// is requested to follow OpenAI's behavior.
type ChunkUsage = Usage

// nullChunkUsage is the sentinel: any *ChunkUsage pointing at this exact
// value marshals as JSON null (identity comparison, not value comparison).
var nullChunkUsage = ChunkUsage{}

func (u *ChunkUsage) MarshalJSON() ([]byte, error) {
	if u == &nullChunkUsage {
		return []byte("null"), nil
	}
	// Marshal the dereferenced value; the non-addressable copy does not
	// re-enter this pointer-receiver method, so this does not recurse.
	return json.Marshal(*u)
}
// ResponseFormat selects the response format ("json_object" enables JSON mode).
type ResponseFormat struct {
	Type string `json:"type"`
}

// EmbedRequest is an OpenAI embeddings request; Input may be a string or
// a list of strings.
type EmbedRequest struct {
	Input any    `json:"input"`
	Model string `json:"model"`
}

// StreamOptions mirrors OpenAI's stream_options object.
type StreamOptions struct {
	IncludeUsage bool `json:"include_usage"`
}
// ChatCompletionRequest is the subset of OpenAI's chat completion request
// that this middleware supports. Stop may be a string or a list of strings.
type ChatCompletionRequest struct {
	Model            string          `json:"model"`
	Messages         []Message       `json:"messages"`
	Stream           bool            `json:"stream"`
	StreamOptions    *StreamOptions  `json:"stream_options"`
	MaxTokens        *int            `json:"max_tokens"`
	Seed             *int            `json:"seed"`
	Stop             any             `json:"stop"`
	Temperature      *float64        `json:"temperature"`
	FrequencyPenalty *float64        `json:"frequency_penalty"`
	PresencePenalty  *float64        `json:"presence_penalty"`
	TopP             *float64        `json:"top_p"`
	ResponseFormat   *ResponseFormat `json:"response_format"`
	Tools            []api.Tool      `json:"tools"`
}
// ChatCompletion is a non-streaming chat completion response body.
type ChatCompletion struct {
	Id                string   `json:"id"`
	Object            string   `json:"object"`
	Created           int64    `json:"created"`
	Model             string   `json:"model"`
	SystemFingerprint string   `json:"system_fingerprint"`
	Choices           []Choice `json:"choices"`
	Usage             Usage    `json:"usage,omitempty"`
}

// ChatCompletionChunk is one SSE chunk of a streaming chat completion.
// Usage is omitted unless requested; see ChunkUsage for null handling.
type ChatCompletionChunk struct {
	Id                string        `json:"id"`
	Object            string        `json:"object"`
	Created           int64         `json:"created"`
	Model             string        `json:"model"`
	SystemFingerprint string        `json:"system_fingerprint"`
	Choices           []ChunkChoice `json:"choices"`
	Usage             *ChunkUsage   `json:"usage,omitempty"`
}
// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
// CompletionRequest is the legacy /v1/completions request. Note that the
// penalty/top_p fields are plain values here (not pointers as in chat),
// so "unset" is indistinguishable from zero.
type CompletionRequest struct {
	Model            string         `json:"model"`
	Prompt           string         `json:"prompt"`
	FrequencyPenalty float32        `json:"frequency_penalty"`
	MaxTokens        *int           `json:"max_tokens"`
	PresencePenalty  float32        `json:"presence_penalty"`
	Seed             *int           `json:"seed"`
	Stop             any            `json:"stop"`
	Stream           bool           `json:"stream"`
	StreamOptions    *StreamOptions `json:"stream_options"`
	Temperature      *float32       `json:"temperature"`
	TopP             float32        `json:"top_p"`
	Suffix           string         `json:"suffix"`
}

// Completion is a non-streaming text completion response body.
type Completion struct {
	Id                string                `json:"id"`
	Object            string                `json:"object"`
	Created           int64                 `json:"created"`
	Model             string                `json:"model"`
	SystemFingerprint string                `json:"system_fingerprint"`
	Choices           []CompleteChunkChoice `json:"choices"`
	Usage             Usage                 `json:"usage,omitempty"`
}

// CompletionChunk is one SSE chunk of a streaming text completion.
type CompletionChunk struct {
	Id                string                `json:"id"`
	Object            string                `json:"object"`
	Created           int64                 `json:"created"`
	Choices           []CompleteChunkChoice `json:"choices"`
	Model             string                `json:"model"`
	SystemFingerprint string                `json:"system_fingerprint"`
	Usage             *ChunkUsage           `json:"usage,omitempty"`
}
// ToolCall is an OpenAI tool call; Arguments is a JSON-encoded string
// (OpenAI's wire format), not a structured object.
type ToolCall struct {
	ID       string `json:"id"`
	Type     string `json:"type"`
	Function struct {
		Name      string `json:"name"`
		Arguments string `json:"arguments"`
	} `json:"function"`
}

// Model is one entry of the /v1/models listing.
type Model struct {
	Id      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	OwnedBy string `json:"owned_by"`
}

// Embedding is one embedding vector in an EmbeddingList.
type Embedding struct {
	Object    string    `json:"object"`
	Embedding []float32 `json:"embedding"`
	Index     int       `json:"index"`
}

// ListCompletion is the /v1/models response envelope.
type ListCompletion struct {
	Object string  `json:"object"`
	Data   []Model `json:"data"`
}

// EmbeddingList is the /v1/embeddings response envelope.
type EmbeddingList struct {
	Object string         `json:"object"`
	Data   []Embedding    `json:"data"`
	Model  string         `json:"model"`
	Usage  EmbeddingUsage `json:"usage,omitempty"`
}

// EmbeddingUsage reports token accounting for embeddings (no completion
// tokens, so total equals prompt).
type EmbeddingUsage struct {
	PromptTokens int `json:"prompt_tokens"`
	TotalTokens  int `json:"total_tokens"`
}
  174. func NewError(code int, message string) ErrorResponse {
  175. var etype string
  176. switch code {
  177. case http.StatusBadRequest:
  178. etype = "invalid_request_error"
  179. case http.StatusNotFound:
  180. etype = "not_found_error"
  181. default:
  182. etype = "api_error"
  183. }
  184. return ErrorResponse{Error{Type: etype, Message: message}}
  185. }
  186. func toolCallId() string {
  187. const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
  188. b := make([]byte, 8)
  189. for i := range b {
  190. b[i] = letterBytes[rand.Intn(len(letterBytes))]
  191. }
  192. return "call_" + strings.ToLower(string(b))
  193. }
  194. func toUsage(r api.ChatResponse) Usage {
  195. return Usage{
  196. PromptTokens: r.PromptEvalCount,
  197. CompletionTokens: r.EvalCount,
  198. TotalTokens: r.PromptEvalCount + r.EvalCount,
  199. }
  200. }
// toChatCompletion converts a final api.ChatResponse into a non-streaming
// OpenAI chat.completion object, assigning fresh ids to any tool calls and
// re-encoding their arguments as JSON strings.
func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
	toolCalls := make([]ToolCall, len(r.Message.ToolCalls))
	for i, tc := range r.Message.ToolCalls {
		toolCalls[i].ID = toolCallId()
		toolCalls[i].Type = "function"
		toolCalls[i].Function.Name = tc.Function.Name

		args, err := json.Marshal(tc.Function.Arguments)
		if err != nil {
			// Best effort: log and leave this call's Arguments empty
			// rather than failing the whole response.
			slog.Error("could not marshall function arguments to json", "error", err)
			continue
		}

		toolCalls[i].Function.Arguments = string(args)
	}

	return ChatCompletion{
		Id:                id,
		Object:            "chat.completion",
		Created:           r.CreatedAt.Unix(),
		Model:             r.Model,
		SystemFingerprint: "fp_ollama",
		Choices: []Choice{{
			Index:   0,
			Message: Message{Role: r.Message.Role, Content: r.Message.Content, ToolCalls: toolCalls},
			// finish_reason is "tool_calls" whenever tool calls are present,
			// otherwise the done reason, or null when that is empty.
			FinishReason: func(reason string) *string {
				if len(toolCalls) > 0 {
					reason = "tool_calls"
				}
				if len(reason) > 0 {
					return &reason
				}
				return nil
			}(r.DoneReason),
		}},
		Usage: toUsage(r),
	}
}
  236. func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
  237. return ChatCompletionChunk{
  238. Id: id,
  239. Object: "chat.completion.chunk",
  240. Created: time.Now().Unix(),
  241. Model: r.Model,
  242. SystemFingerprint: "fp_ollama",
  243. Choices: []ChunkChoice{{
  244. Index: 0,
  245. Delta: Message{Role: "assistant", Content: r.Message.Content},
  246. FinishReason: func(reason string) *string {
  247. if len(reason) > 0 {
  248. return &reason
  249. }
  250. return nil
  251. }(r.DoneReason),
  252. }},
  253. }
  254. }
  255. func toUsageGenerate(r api.GenerateResponse) Usage {
  256. return Usage{
  257. PromptTokens: r.PromptEvalCount,
  258. CompletionTokens: r.EvalCount,
  259. TotalTokens: r.PromptEvalCount + r.EvalCount,
  260. }
  261. }
  262. func toCompletion(id string, r api.GenerateResponse) Completion {
  263. return Completion{
  264. Id: id,
  265. Object: "text_completion",
  266. Created: r.CreatedAt.Unix(),
  267. Model: r.Model,
  268. SystemFingerprint: "fp_ollama",
  269. Choices: []CompleteChunkChoice{{
  270. Text: r.Response,
  271. Index: 0,
  272. FinishReason: func(reason string) *string {
  273. if len(reason) > 0 {
  274. return &reason
  275. }
  276. return nil
  277. }(r.DoneReason),
  278. }},
  279. Usage: toUsageGenerate(r),
  280. }
  281. }
  282. func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
  283. return CompletionChunk{
  284. Id: id,
  285. Object: "text_completion",
  286. Created: time.Now().Unix(),
  287. Model: r.Model,
  288. SystemFingerprint: "fp_ollama",
  289. Choices: []CompleteChunkChoice{{
  290. Text: r.Response,
  291. Index: 0,
  292. FinishReason: func(reason string) *string {
  293. if len(reason) > 0 {
  294. return &reason
  295. }
  296. return nil
  297. }(r.DoneReason),
  298. }},
  299. }
  300. }
  301. func toListCompletion(r api.ListResponse) ListCompletion {
  302. var data []Model
  303. for _, m := range r.Models {
  304. data = append(data, Model{
  305. Id: m.Name,
  306. Object: "model",
  307. Created: m.ModifiedAt.Unix(),
  308. OwnedBy: model.ParseName(m.Name).Namespace,
  309. })
  310. }
  311. return ListCompletion{
  312. Object: "list",
  313. Data: data,
  314. }
  315. }
  316. func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
  317. if r.Embeddings != nil {
  318. var data []Embedding
  319. for i, e := range r.Embeddings {
  320. data = append(data, Embedding{
  321. Object: "embedding",
  322. Embedding: e,
  323. Index: i,
  324. })
  325. }
  326. return EmbeddingList{
  327. Object: "list",
  328. Data: data,
  329. Model: model,
  330. Usage: EmbeddingUsage{
  331. PromptTokens: r.PromptEvalCount,
  332. TotalTokens: r.PromptEvalCount,
  333. },
  334. }
  335. }
  336. return EmbeddingList{}
  337. }
  338. func toModel(r api.ShowResponse, m string) Model {
  339. return Model{
  340. Id: m,
  341. Object: "model",
  342. Created: r.ModifiedAt.Unix(),
  343. OwnedBy: model.ParseName(m).Namespace,
  344. }
  345. }
// fromChatRequest converts an OpenAI chat completion request into an
// api.ChatRequest. It flattens structured message content (text parts and
// base64 data-URI images) into individual api.Messages, decodes tool-call
// argument strings, and maps sampling parameters into the options map.
func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
	var messages []api.Message
	for _, msg := range r.Messages {
		switch content := msg.Content.(type) {
		case string:
			messages = append(messages, api.Message{Role: msg.Role, Content: content})
		case []any:
			// Structured content: a list of {"type": ...} parts. Each text
			// or image part becomes its own api.Message with the same role.
			for _, c := range content {
				data, ok := c.(map[string]any)
				if !ok {
					return nil, errors.New("invalid message format")
				}
				switch data["type"] {
				case "text":
					text, ok := data["text"].(string)
					if !ok {
						return nil, errors.New("invalid message format")
					}
					messages = append(messages, api.Message{Role: msg.Role, Content: text})
				case "image_url":
					// Accept both {"image_url": {"url": ...}} and the
					// shorthand {"image_url": "..."} forms.
					var url string
					if urlMap, ok := data["image_url"].(map[string]any); ok {
						if url, ok = urlMap["url"].(string); !ok {
							return nil, errors.New("invalid message format")
						}
					} else {
						if url, ok = data["image_url"].(string); !ok {
							return nil, errors.New("invalid message format")
						}
					}

					// Only base64 data URIs for these image types are
					// accepted; remote URLs are rejected.
					types := []string{"jpeg", "jpg", "png"}
					valid := false
					for _, t := range types {
						prefix := "data:image/" + t + ";base64,"
						if strings.HasPrefix(url, prefix) {
							url = strings.TrimPrefix(url, prefix)
							valid = true
							break
						}
					}

					if !valid {
						return nil, errors.New("invalid image input")
					}

					img, err := base64.StdEncoding.DecodeString(url)
					if err != nil {
						return nil, errors.New("invalid message format")
					}
					messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
				default:
					return nil, errors.New("invalid message format")
				}
			}
		default:
			// No usable content: only valid for assistant tool-call
			// messages, whose arguments arrive as a JSON string and are
			// decoded into the structured api representation.
			if msg.ToolCalls == nil {
				return nil, fmt.Errorf("invalid message content type: %T", content)
			}
			toolCalls := make([]api.ToolCall, len(msg.ToolCalls))
			for i, tc := range msg.ToolCalls {
				toolCalls[i].Function.Name = tc.Function.Name
				err := json.Unmarshal([]byte(tc.Function.Arguments), &toolCalls[i].Function.Arguments)
				if err != nil {
					return nil, errors.New("invalid tool call arguments")
				}
			}
			messages = append(messages, api.Message{Role: msg.Role, ToolCalls: toolCalls})
		}
	}

	options := make(map[string]interface{})

	// "stop" may be a single string or a list; non-string list entries are
	// silently dropped here (unlike fromCompleteRequest, which rejects them).
	switch stop := r.Stop.(type) {
	case string:
		options["stop"] = []string{stop}
	case []any:
		var stops []string
		for _, s := range stop {
			if str, ok := s.(string); ok {
				stops = append(stops, str)
			}
		}
		options["stop"] = stops
	}

	if r.MaxTokens != nil {
		options["num_predict"] = *r.MaxTokens
	}

	// temperature and top_p default to 1.0 (OpenAI's defaults) when unset.
	if r.Temperature != nil {
		options["temperature"] = *r.Temperature
	} else {
		options["temperature"] = 1.0
	}

	if r.Seed != nil {
		options["seed"] = *r.Seed
	}

	if r.FrequencyPenalty != nil {
		options["frequency_penalty"] = *r.FrequencyPenalty
	}

	if r.PresencePenalty != nil {
		options["presence_penalty"] = *r.PresencePenalty
	}

	if r.TopP != nil {
		options["top_p"] = *r.TopP
	} else {
		options["top_p"] = 1.0
	}

	var format string
	if r.ResponseFormat != nil && r.ResponseFormat.Type == "json_object" {
		format = "json"
	}

	return &api.ChatRequest{
		Model:    r.Model,
		Messages: messages,
		Format:   format,
		Options:  options,
		Stream:   &r.Stream,
		Tools:    r.Tools,
	}, nil
}
// fromCompleteRequest converts an OpenAI completions request into an
// api.GenerateRequest, mapping sampling parameters into the options map.
func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
	options := make(map[string]any)

	// "stop" may be a single string or a list. Unlike the chat path, a
	// non-string list entry is an error here rather than being dropped.
	switch stop := r.Stop.(type) {
	case string:
		options["stop"] = []string{stop}
	case []any:
		var stops []string
		for _, s := range stop {
			if str, ok := s.(string); ok {
				stops = append(stops, str)
			} else {
				return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", s)
			}
		}
		options["stop"] = stops
	}

	if r.MaxTokens != nil {
		options["num_predict"] = *r.MaxTokens
	}

	// temperature defaults to 1.0 (OpenAI's default) when unset.
	if r.Temperature != nil {
		options["temperature"] = *r.Temperature
	} else {
		options["temperature"] = 1.0
	}

	if r.Seed != nil {
		options["seed"] = *r.Seed
	}

	// These fields are plain values (not pointers), so an absent field is
	// indistinguishable from an explicit zero and is passed through as 0.
	options["frequency_penalty"] = r.FrequencyPenalty
	options["presence_penalty"] = r.PresencePenalty

	// top_p == 0 is treated as "unset" and replaced with OpenAI's default.
	if r.TopP != 0.0 {
		options["top_p"] = r.TopP
	} else {
		options["top_p"] = 1.0
	}

	return api.GenerateRequest{
		Model:   r.Model,
		Prompt:  r.Prompt,
		Options: options,
		Stream:  &r.Stream,
		Suffix:  r.Suffix,
	}, nil
}
// BaseWriter wraps gin's ResponseWriter and provides the shared error
// translation used by all OpenAI-compat response writers below.
type BaseWriter struct {
	gin.ResponseWriter
}

// ChatWriter rewrites /api/chat responses as OpenAI chat completions.
type ChatWriter struct {
	stream      bool   // client requested SSE streaming
	streamUsage bool   // stream_options.include_usage was set
	id          string // completion id echoed on every chunk
	BaseWriter
}

// CompleteWriter rewrites /api/generate responses as OpenAI completions.
type CompleteWriter struct {
	stream      bool
	streamUsage bool
	id          string
	BaseWriter
}

// ListWriter rewrites model listings as OpenAI model lists.
type ListWriter struct {
	BaseWriter
}

// RetrieveWriter rewrites a show response as a single OpenAI model object.
type RetrieveWriter struct {
	BaseWriter
	model string // model name from the request path
}

// EmbedWriter rewrites embed responses as OpenAI embedding lists.
type EmbedWriter struct {
	BaseWriter
	model string // model name echoed in the response
}
  529. func (w *BaseWriter) writeError(code int, data []byte) (int, error) {
  530. var serr api.StatusError
  531. err := json.Unmarshal(data, &serr)
  532. if err != nil {
  533. return 0, err
  534. }
  535. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  536. err = json.NewEncoder(w.ResponseWriter).Encode(NewError(http.StatusInternalServerError, serr.Error()))
  537. if err != nil {
  538. return 0, err
  539. }
  540. return len(data), nil
  541. }
// writeResponse translates one JSON-encoded api.ChatResponse from the
// Ollama handler into OpenAI chat format. In streaming mode it emits SSE
// "data:" chunks, then (on the final message) an optional usage-only chunk
// and the [DONE] sentinel; otherwise it writes a single chat.completion.
func (w *ChatWriter) writeResponse(data []byte) (int, error) {
	var chatResponse api.ChatResponse
	err := json.Unmarshal(data, &chatResponse)
	if err != nil {
		return 0, err
	}

	// chat chunk
	if w.stream {
		c := toChunk(w.id, chatResponse)
		if w.streamUsage {
			// Sentinel marshals as null on non-final chunks when
			// stream_options.include_usage was requested.
			c.Usage = &nullChunkUsage
		}
		d, err := json.Marshal(c)
		if err != nil {
			return 0, err
		}

		w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
		_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
		if err != nil {
			return 0, err
		}

		if chatResponse.Done {
			if w.streamUsage {
				// Extra final chunk carrying only the usage totals,
				// matching OpenAI's include_usage behavior.
				u := toUsage(chatResponse)
				d, err := json.Marshal(ChatCompletionChunk{Usage: &u})
				if err != nil {
					return 0, err
				}
				_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
				if err != nil {
					return 0, err
				}
			}
			_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
			if err != nil {
				return 0, err
			}
		}

		return len(data), nil
	}

	// chat completion
	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w.ResponseWriter).Encode(toChatCompletion(w.id, chatResponse))
	if err != nil {
		return 0, err
	}

	return len(data), nil
}
  590. func (w *ChatWriter) Write(data []byte) (int, error) {
  591. code := w.ResponseWriter.Status()
  592. if code != http.StatusOK {
  593. return w.writeError(code, data)
  594. }
  595. return w.writeResponse(data)
  596. }
// writeResponse translates one JSON-encoded api.GenerateResponse into
// OpenAI completion format. In streaming mode it emits SSE "data:" chunks,
// then (on the final message) an optional usage-only chunk and the [DONE]
// sentinel; otherwise it writes a single text_completion object.
func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
	var generateResponse api.GenerateResponse
	err := json.Unmarshal(data, &generateResponse)
	if err != nil {
		return 0, err
	}

	// completion chunk
	if w.stream {
		c := toCompleteChunk(w.id, generateResponse)
		if w.streamUsage {
			// Sentinel marshals as null on non-final chunks when
			// stream_options.include_usage was requested.
			c.Usage = &nullChunkUsage
		}
		d, err := json.Marshal(c)
		if err != nil {
			return 0, err
		}

		w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
		_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
		if err != nil {
			return 0, err
		}

		if generateResponse.Done {
			if w.streamUsage {
				// Extra final chunk carrying only the usage totals.
				u := toUsageGenerate(generateResponse)
				d, err := json.Marshal(CompletionChunk{Usage: &u})
				if err != nil {
					return 0, err
				}
				_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
				if err != nil {
					return 0, err
				}
			}
			_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
			if err != nil {
				return 0, err
			}
		}

		return len(data), nil
	}

	// completion
	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w.ResponseWriter).Encode(toCompletion(w.id, generateResponse))
	if err != nil {
		return 0, err
	}

	return len(data), nil
}
  645. func (w *CompleteWriter) Write(data []byte) (int, error) {
  646. code := w.ResponseWriter.Status()
  647. if code != http.StatusOK {
  648. return w.writeError(code, data)
  649. }
  650. return w.writeResponse(data)
  651. }
  652. func (w *ListWriter) writeResponse(data []byte) (int, error) {
  653. var listResponse api.ListResponse
  654. err := json.Unmarshal(data, &listResponse)
  655. if err != nil {
  656. return 0, err
  657. }
  658. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  659. err = json.NewEncoder(w.ResponseWriter).Encode(toListCompletion(listResponse))
  660. if err != nil {
  661. return 0, err
  662. }
  663. return len(data), nil
  664. }
  665. func (w *ListWriter) Write(data []byte) (int, error) {
  666. code := w.ResponseWriter.Status()
  667. if code != http.StatusOK {
  668. return w.writeError(code, data)
  669. }
  670. return w.writeResponse(data)
  671. }
  672. func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
  673. var showResponse api.ShowResponse
  674. err := json.Unmarshal(data, &showResponse)
  675. if err != nil {
  676. return 0, err
  677. }
  678. // retrieve completion
  679. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  680. err = json.NewEncoder(w.ResponseWriter).Encode(toModel(showResponse, w.model))
  681. if err != nil {
  682. return 0, err
  683. }
  684. return len(data), nil
  685. }
  686. func (w *RetrieveWriter) Write(data []byte) (int, error) {
  687. code := w.ResponseWriter.Status()
  688. if code != http.StatusOK {
  689. return w.writeError(code, data)
  690. }
  691. return w.writeResponse(data)
  692. }
  693. func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
  694. var embedResponse api.EmbedResponse
  695. err := json.Unmarshal(data, &embedResponse)
  696. if err != nil {
  697. return 0, err
  698. }
  699. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  700. err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse))
  701. if err != nil {
  702. return 0, err
  703. }
  704. return len(data), nil
  705. }
  706. func (w *EmbedWriter) Write(data []byte) (int, error) {
  707. code := w.ResponseWriter.Status()
  708. if code != http.StatusOK {
  709. return w.writeError(code, data)
  710. }
  711. return w.writeResponse(data)
  712. }
  713. func ListMiddleware() gin.HandlerFunc {
  714. return func(c *gin.Context) {
  715. w := &ListWriter{
  716. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  717. }
  718. c.Writer = w
  719. c.Next()
  720. }
  721. }
  722. func RetrieveMiddleware() gin.HandlerFunc {
  723. return func(c *gin.Context) {
  724. var b bytes.Buffer
  725. if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
  726. c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
  727. return
  728. }
  729. c.Request.Body = io.NopCloser(&b)
  730. // response writer
  731. w := &RetrieveWriter{
  732. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  733. model: c.Param("model"),
  734. }
  735. c.Writer = w
  736. c.Next()
  737. }
  738. }
// CompletionsMiddleware adapts an OpenAI /v1/completions request into an
// Ollama generate request: it rebinds the body and installs a
// CompleteWriter that translates the response (and any error) back.
func CompletionsMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		var req CompletionRequest
		err := c.ShouldBindJSON(&req)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
			return
		}

		var b bytes.Buffer
		genReq, err := fromCompleteRequest(req)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
			return
		}

		if err := json.NewEncoder(&b).Encode(genReq); err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
			return
		}

		// Replace the request body with the translated generate request.
		c.Request.Body = io.NopCloser(&b)

		w := &CompleteWriter{
			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
			stream:     req.Stream,
			// Short random suffix mimics OpenAI's completion ids.
			id:          fmt.Sprintf("cmpl-%d", rand.Intn(999)),
			streamUsage: req.StreamOptions != nil && req.StreamOptions.IncludeUsage,
		}

		c.Writer = w

		c.Next()
	}
}
// EmbeddingsMiddleware adapts an OpenAI /v1/embeddings request into an
// Ollama embed request, validating the input field and installing an
// EmbedWriter to translate the response.
func EmbeddingsMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		var req EmbedRequest
		err := c.ShouldBindJSON(&req)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
			return
		}

		// An empty string is coerced to a one-element list so the backend
		// still returns a (single, empty-input) embedding.
		if req.Input == "" {
			req.Input = []string{""}
		}

		// Missing input or an empty list is rejected outright.
		if req.Input == nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
			return
		}

		if v, ok := req.Input.([]any); ok && len(v) == 0 {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
			return
		}

		var b bytes.Buffer
		if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input}); err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
			return
		}

		c.Request.Body = io.NopCloser(&b)

		w := &EmbedWriter{
			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
			model:      req.Model,
		}

		c.Writer = w

		c.Next()
	}
}
// ChatMiddleware adapts an OpenAI /v1/chat/completions request into an
// Ollama chat request: it validates the messages list, rebinds the body,
// and installs a ChatWriter that translates the response back.
func ChatMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		var req ChatCompletionRequest
		err := c.ShouldBindJSON(&req)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
			return
		}

		if len(req.Messages) == 0 {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "[] is too short - 'messages'"))
			return
		}

		var b bytes.Buffer

		chatReq, err := fromChatRequest(req)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
			return
		}

		if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
			return
		}

		// Replace the request body with the translated chat request.
		c.Request.Body = io.NopCloser(&b)

		w := &ChatWriter{
			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
			stream:     req.Stream,
			// Short random suffix mimics OpenAI's chatcmpl ids.
			id:          fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
			streamUsage: req.StreamOptions != nil && req.StreamOptions.IncludeUsage,
		}

		c.Writer = w

		c.Next()
	}
}