openai.go 24 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987
  1. // openai package provides middleware for partial compatibility with the OpenAI REST API
  2. package openai
  3. import (
  4. "bytes"
  5. "encoding/base64"
  6. "encoding/json"
  7. "errors"
  8. "fmt"
  9. "io"
  10. "log/slog"
  11. "math/rand"
  12. "net/http"
  13. "strings"
  14. "time"
  15. "github.com/gin-gonic/gin"
  16. "github.com/ollama/ollama/api"
  17. "github.com/ollama/ollama/types/model"
  18. )
// Error mirrors the OpenAI API error object.
type Error struct {
	Message string      `json:"message"`
	Type    string      `json:"type"`
	Param   interface{} `json:"param"`
	Code    *string     `json:"code"`
}

// ErrorResponse wraps an Error as the top-level error payload.
type ErrorResponse struct {
	Error Error `json:"error"`
}
// Message is a chat message. Content is either a plain string or a
// list of typed content parts (see fromChatRequest for the accepted shapes).
type Message struct {
	Role      string     `json:"role"`
	Content   any        `json:"content"`
	ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}

// Choice is a single non-streaming chat completion choice.
type Choice struct {
	Index        int     `json:"index"`
	Message      Message `json:"message"`
	FinishReason *string `json:"finish_reason"`
}

// ChunkChoice is a single streaming chat completion choice.
type ChunkChoice struct {
	Index        int     `json:"index"`
	Delta        Message `json:"delta"`
	FinishReason *string `json:"finish_reason"`
}

// CompleteChunkChoice is a single text completion choice.
type CompleteChunkChoice struct {
	Text         string  `json:"text"`
	Index        int     `json:"index"`
	FinishReason *string `json:"finish_reason"`
}

// Usage reports prompt/completion token accounting for a request.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens     int `json:"total_tokens"`
}
// ResponseFormat selects the response format; Type is "json_object" or
// "json_schema" (see fromChatRequest).
type ResponseFormat struct {
	Type       string      `json:"type"`
	JsonSchema *JsonSchema `json:"json_schema,omitempty"`
}

// JsonSchema carries the schema used with the "json_schema" response format.
type JsonSchema struct {
	Schema map[string]any `json:"schema"`
}

// EmbedRequest is the request body for /v1/embeddings; Input may be a
// string or a list.
type EmbedRequest struct {
	Input any    `json:"input"`
	Model string `json:"model"`
}

// StreamOptions mirrors the OpenAI stream_options object.
type StreamOptions struct {
	IncludeUsage bool `json:"include_usage"`
}
// ChatCompletionRequest is the request body for /v1/chat/completions.
type ChatCompletionRequest struct {
	Model            string          `json:"model"`
	Messages         []Message       `json:"messages"`
	Stream           bool            `json:"stream"`
	StreamOptions    *StreamOptions  `json:"stream_options"`
	MaxTokens        *int            `json:"max_tokens"`
	Seed             *int            `json:"seed"`
	Stop             any             `json:"stop"` // string or []string
	Temperature      *float64        `json:"temperature"`
	FrequencyPenalty *float64        `json:"frequency_penalty"`
	PresencePenalty  *float64        `json:"presence_penalty"`
	TopP             *float64        `json:"top_p"`
	ResponseFormat   *ResponseFormat `json:"response_format"`
	Tools            []api.Tool      `json:"tools"`
}

// ChatCompletion is a non-streaming chat completion response.
type ChatCompletion struct {
	Id                string   `json:"id"`
	Object            string   `json:"object"`
	Created           int64    `json:"created"`
	Model             string   `json:"model"`
	SystemFingerprint string   `json:"system_fingerprint"`
	Choices           []Choice `json:"choices"`
	Usage             Usage    `json:"usage,omitempty"`
}

// ChatCompletionChunk is one SSE chunk of a streaming chat completion.
type ChatCompletionChunk struct {
	Id                string        `json:"id"`
	Object            string        `json:"object"`
	Created           int64         `json:"created"`
	Model             string        `json:"model"`
	SystemFingerprint string        `json:"system_fingerprint"`
	Choices           []ChunkChoice `json:"choices"`
	Usage             *Usage        `json:"usage,omitempty"`
}
// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
// CompletionRequest is the request body for /v1/completions.
type CompletionRequest struct {
	Model            string         `json:"model"`
	Prompt           string         `json:"prompt"`
	FrequencyPenalty float32        `json:"frequency_penalty"`
	MaxTokens        *int           `json:"max_tokens"`
	PresencePenalty  float32        `json:"presence_penalty"`
	Seed             *int           `json:"seed"`
	Stop             any            `json:"stop"` // string or []string
	Stream           bool           `json:"stream"`
	StreamOptions    *StreamOptions `json:"stream_options"`
	Temperature      *float32       `json:"temperature"`
	TopP             float32        `json:"top_p"`
	Suffix           string         `json:"suffix"`
}

// Completion is a non-streaming text completion response.
type Completion struct {
	Id                string                `json:"id"`
	Object            string                `json:"object"`
	Created           int64                 `json:"created"`
	Model             string                `json:"model"`
	SystemFingerprint string                `json:"system_fingerprint"`
	Choices           []CompleteChunkChoice `json:"choices"`
	Usage             Usage                 `json:"usage,omitempty"`
}

// CompletionChunk is one SSE chunk of a streaming text completion.
type CompletionChunk struct {
	Id                string                `json:"id"`
	Object            string                `json:"object"`
	Created           int64                 `json:"created"`
	Choices           []CompleteChunkChoice `json:"choices"`
	Model             string                `json:"model"`
	SystemFingerprint string                `json:"system_fingerprint"`
	Usage             *Usage                `json:"usage,omitempty"`
}
// ToolCall is a tool invocation in the OpenAI wire format; Function.Arguments
// is a JSON-encoded string (see toToolCalls / fromChatRequest).
type ToolCall struct {
	ID       string `json:"id"`
	Index    int    `json:"index"`
	Type     string `json:"type"`
	Function struct {
		Name      string `json:"name"`
		Arguments string `json:"arguments"`
	} `json:"function"`
}

// Model describes one model in /v1/models responses.
type Model struct {
	Id      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	OwnedBy string `json:"owned_by"`
}

// Embedding is a single embedding vector result.
type Embedding struct {
	Object    string    `json:"object"`
	Embedding []float32 `json:"embedding"`
	Index     int       `json:"index"`
}

// ListCompletion is the /v1/models list response.
type ListCompletion struct {
	Object string  `json:"object"`
	Data   []Model `json:"data"`
}

// EmbeddingList is the /v1/embeddings response.
type EmbeddingList struct {
	Object string         `json:"object"`
	Data   []Embedding    `json:"data"`
	Model  string         `json:"model"`
	Usage  EmbeddingUsage `json:"usage,omitempty"`
}

// EmbeddingUsage reports token accounting for an embeddings request.
type EmbeddingUsage struct {
	PromptTokens int `json:"prompt_tokens"`
	TotalTokens  int `json:"total_tokens"`
}
  167. func NewError(code int, message string) ErrorResponse {
  168. var etype string
  169. switch code {
  170. case http.StatusBadRequest:
  171. etype = "invalid_request_error"
  172. case http.StatusNotFound:
  173. etype = "not_found_error"
  174. default:
  175. etype = "api_error"
  176. }
  177. return ErrorResponse{Error{Type: etype, Message: message}}
  178. }
  179. func toUsage(r api.ChatResponse) Usage {
  180. return Usage{
  181. PromptTokens: r.PromptEvalCount,
  182. CompletionTokens: r.EvalCount,
  183. TotalTokens: r.PromptEvalCount + r.EvalCount,
  184. }
  185. }
  186. func toolCallId() string {
  187. const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
  188. b := make([]byte, 8)
  189. for i := range b {
  190. b[i] = letterBytes[rand.Intn(len(letterBytes))]
  191. }
  192. return "call_" + strings.ToLower(string(b))
  193. }
  194. func toToolCalls(tc []api.ToolCall) []ToolCall {
  195. toolCalls := make([]ToolCall, len(tc))
  196. for i, tc := range tc {
  197. toolCalls[i].ID = toolCallId()
  198. toolCalls[i].Type = "function"
  199. toolCalls[i].Function.Name = tc.Function.Name
  200. toolCalls[i].Index = tc.Function.Index
  201. args, err := json.Marshal(tc.Function.Arguments)
  202. if err != nil {
  203. slog.Error("could not marshall function arguments to json", "error", err)
  204. continue
  205. }
  206. toolCalls[i].Function.Arguments = string(args)
  207. }
  208. return toolCalls
  209. }
  210. func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
  211. toolCalls := toToolCalls(r.Message.ToolCalls)
  212. return ChatCompletion{
  213. Id: id,
  214. Object: "chat.completion",
  215. Created: r.CreatedAt.Unix(),
  216. Model: r.Model,
  217. SystemFingerprint: "fp_ollama",
  218. Choices: []Choice{{
  219. Index: 0,
  220. Message: Message{Role: r.Message.Role, Content: r.Message.Content, ToolCalls: toolCalls},
  221. FinishReason: func(reason string) *string {
  222. if len(toolCalls) > 0 {
  223. reason = "tool_calls"
  224. }
  225. if len(reason) > 0 {
  226. return &reason
  227. }
  228. return nil
  229. }(r.DoneReason),
  230. }},
  231. Usage: toUsage(r),
  232. }
  233. }
  234. func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
  235. toolCalls := toToolCalls(r.Message.ToolCalls)
  236. return ChatCompletionChunk{
  237. Id: id,
  238. Object: "chat.completion.chunk",
  239. Created: time.Now().Unix(),
  240. Model: r.Model,
  241. SystemFingerprint: "fp_ollama",
  242. Choices: []ChunkChoice{{
  243. Index: 0,
  244. Delta: Message{Role: "assistant", Content: r.Message.Content, ToolCalls: toolCalls},
  245. FinishReason: func(reason string) *string {
  246. if len(reason) > 0 {
  247. return &reason
  248. }
  249. return nil
  250. }(r.DoneReason),
  251. }},
  252. }
  253. }
  254. func toUsageGenerate(r api.GenerateResponse) Usage {
  255. return Usage{
  256. PromptTokens: r.PromptEvalCount,
  257. CompletionTokens: r.EvalCount,
  258. TotalTokens: r.PromptEvalCount + r.EvalCount,
  259. }
  260. }
  261. func toCompletion(id string, r api.GenerateResponse) Completion {
  262. return Completion{
  263. Id: id,
  264. Object: "text_completion",
  265. Created: r.CreatedAt.Unix(),
  266. Model: r.Model,
  267. SystemFingerprint: "fp_ollama",
  268. Choices: []CompleteChunkChoice{{
  269. Text: r.Response,
  270. Index: 0,
  271. FinishReason: func(reason string) *string {
  272. if len(reason) > 0 {
  273. return &reason
  274. }
  275. return nil
  276. }(r.DoneReason),
  277. }},
  278. Usage: toUsageGenerate(r),
  279. }
  280. }
  281. func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
  282. return CompletionChunk{
  283. Id: id,
  284. Object: "text_completion",
  285. Created: time.Now().Unix(),
  286. Model: r.Model,
  287. SystemFingerprint: "fp_ollama",
  288. Choices: []CompleteChunkChoice{{
  289. Text: r.Response,
  290. Index: 0,
  291. FinishReason: func(reason string) *string {
  292. if len(reason) > 0 {
  293. return &reason
  294. }
  295. return nil
  296. }(r.DoneReason),
  297. }},
  298. }
  299. }
  300. func toListCompletion(r api.ListResponse) ListCompletion {
  301. var data []Model
  302. for _, m := range r.Models {
  303. data = append(data, Model{
  304. Id: m.Name,
  305. Object: "model",
  306. Created: m.ModifiedAt.Unix(),
  307. OwnedBy: model.ParseName(m.Name).Namespace,
  308. })
  309. }
  310. return ListCompletion{
  311. Object: "list",
  312. Data: data,
  313. }
  314. }
  315. func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
  316. if r.Embeddings != nil {
  317. var data []Embedding
  318. for i, e := range r.Embeddings {
  319. data = append(data, Embedding{
  320. Object: "embedding",
  321. Embedding: e,
  322. Index: i,
  323. })
  324. }
  325. return EmbeddingList{
  326. Object: "list",
  327. Data: data,
  328. Model: model,
  329. Usage: EmbeddingUsage{
  330. PromptTokens: r.PromptEvalCount,
  331. TotalTokens: r.PromptEvalCount,
  332. },
  333. }
  334. }
  335. return EmbeddingList{}
  336. }
  337. func toModel(r api.ShowResponse, m string) Model {
  338. return Model{
  339. Id: m,
  340. Object: "model",
  341. Created: r.ModifiedAt.Unix(),
  342. OwnedBy: model.ParseName(m).Namespace,
  343. }
  344. }
// fromChatRequest translates an OpenAI chat completion request into the
// native api.ChatRequest.
//
// Message content may be a plain string or a list of typed parts
// ("text" or "image_url"). Images must be inline base64 data URLs
// (jpeg/jpg/png); remote URLs are rejected. Messages whose content is
// any other type are accepted only when they carry tool calls.
func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
	var messages []api.Message
	for _, msg := range r.Messages {
		switch content := msg.Content.(type) {
		case string:
			messages = append(messages, api.Message{Role: msg.Role, Content: content})
		case []any:
			for _, c := range content {
				data, ok := c.(map[string]any)
				if !ok {
					return nil, errors.New("invalid message format")
				}
				switch data["type"] {
				case "text":
					text, ok := data["text"].(string)
					if !ok {
						return nil, errors.New("invalid message format")
					}
					messages = append(messages, api.Message{Role: msg.Role, Content: text})
				case "image_url":
					// Accept both the OpenAI object form
					// {"image_url": {"url": "..."}} and a bare string
					// {"image_url": "..."}.
					var url string
					if urlMap, ok := data["image_url"].(map[string]any); ok {
						if url, ok = urlMap["url"].(string); !ok {
							return nil, errors.New("invalid message format")
						}
					} else {
						if url, ok = data["image_url"].(string); !ok {
							return nil, errors.New("invalid message format")
						}
					}
					// Only inline base64 data URLs of these types are valid.
					types := []string{"jpeg", "jpg", "png"}
					valid := false
					for _, t := range types {
						prefix := "data:image/" + t + ";base64,"
						if strings.HasPrefix(url, prefix) {
							url = strings.TrimPrefix(url, prefix)
							valid = true
							break
						}
					}
					if !valid {
						return nil, errors.New("invalid image input")
					}
					img, err := base64.StdEncoding.DecodeString(url)
					if err != nil {
						return nil, errors.New("invalid message format")
					}
					messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
				default:
					return nil, errors.New("invalid message format")
				}
			}
		default:
			// Non-string, non-list content is only valid for messages that
			// carry tool calls (e.g. assistant tool-call turns with null content).
			if msg.ToolCalls == nil {
				return nil, fmt.Errorf("invalid message content type: %T", content)
			}
			toolCalls := make([]api.ToolCall, len(msg.ToolCalls))
			for i, tc := range msg.ToolCalls {
				toolCalls[i].Function.Name = tc.Function.Name
				// OpenAI clients send arguments as a JSON-encoded string.
				err := json.Unmarshal([]byte(tc.Function.Arguments), &toolCalls[i].Function.Arguments)
				if err != nil {
					return nil, errors.New("invalid tool call arguments")
				}
			}
			messages = append(messages, api.Message{Role: msg.Role, ToolCalls: toolCalls})
		}
	}

	options := make(map[string]interface{})

	// "stop" may be a single string or a list; non-string list entries are
	// silently skipped.
	switch stop := r.Stop.(type) {
	case string:
		options["stop"] = []string{stop}
	case []any:
		var stops []string
		for _, s := range stop {
			if str, ok := s.(string); ok {
				stops = append(stops, str)
			}
		}
		options["stop"] = stops
	}

	if r.MaxTokens != nil {
		options["num_predict"] = *r.MaxTokens
	}

	// temperature and top_p default to 1.0 when unset, matching OpenAI.
	if r.Temperature != nil {
		options["temperature"] = *r.Temperature
	} else {
		options["temperature"] = 1.0
	}

	if r.Seed != nil {
		options["seed"] = *r.Seed
	}

	if r.FrequencyPenalty != nil {
		options["frequency_penalty"] = *r.FrequencyPenalty
	}

	if r.PresencePenalty != nil {
		options["presence_penalty"] = *r.PresencePenalty
	}

	if r.TopP != nil {
		options["top_p"] = *r.TopP
	} else {
		options["top_p"] = 1.0
	}

	var format json.RawMessage
	if r.ResponseFormat != nil {
		switch strings.ToLower(strings.TrimSpace(r.ResponseFormat.Type)) {
		// Support the old "json_object" type for OpenAI compatibility
		case "json_object":
			format = json.RawMessage(`"json"`)
		case "json_schema":
			if r.ResponseFormat.JsonSchema != nil {
				schema, err := json.Marshal(r.ResponseFormat.JsonSchema.Schema)
				if err != nil {
					return nil, fmt.Errorf("failed to marshal json schema: %w", err)
				}
				format = schema
			}
		}
	}

	return &api.ChatRequest{
		Model:    r.Model,
		Messages: messages,
		Format:   format,
		Options:  options,
		Stream:   &r.Stream,
		Tools:    r.Tools,
	}, nil
}
  472. func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
  473. options := make(map[string]any)
  474. switch stop := r.Stop.(type) {
  475. case string:
  476. options["stop"] = []string{stop}
  477. case []any:
  478. var stops []string
  479. for _, s := range stop {
  480. if str, ok := s.(string); ok {
  481. stops = append(stops, str)
  482. } else {
  483. return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", s)
  484. }
  485. }
  486. options["stop"] = stops
  487. }
  488. if r.MaxTokens != nil {
  489. options["num_predict"] = *r.MaxTokens
  490. }
  491. if r.Temperature != nil {
  492. options["temperature"] = *r.Temperature
  493. } else {
  494. options["temperature"] = 1.0
  495. }
  496. if r.Seed != nil {
  497. options["seed"] = *r.Seed
  498. }
  499. options["frequency_penalty"] = r.FrequencyPenalty
  500. options["presence_penalty"] = r.PresencePenalty
  501. if r.TopP != 0.0 {
  502. options["top_p"] = r.TopP
  503. } else {
  504. options["top_p"] = 1.0
  505. }
  506. return api.GenerateRequest{
  507. Model: r.Model,
  508. Prompt: r.Prompt,
  509. Options: options,
  510. Stream: &r.Stream,
  511. Suffix: r.Suffix,
  512. }, nil
  513. }
// BaseWriter wraps the gin response writer and provides the shared
// error-translation path used by all compatibility writers below.
type BaseWriter struct {
	gin.ResponseWriter
}

// ChatWriter rewrites native chat responses into OpenAI chat
// completions (streaming or single-shot).
type ChatWriter struct {
	stream        bool           // emit SSE chunks when true
	streamOptions *StreamOptions // stream_options from the request, may be nil
	id            string         // response id, e.g. "chatcmpl-42"
	BaseWriter
}

// CompleteWriter rewrites native generate responses into OpenAI text
// completions (streaming or single-shot).
type CompleteWriter struct {
	stream        bool           // emit SSE chunks when true
	streamOptions *StreamOptions // stream_options from the request, may be nil
	id            string         // response id, e.g. "cmpl-42"
	BaseWriter
}

// ListWriter rewrites native model-list responses.
type ListWriter struct {
	BaseWriter
}

// RetrieveWriter rewrites native show-model responses.
type RetrieveWriter struct {
	BaseWriter
	model string // model name echoed back in the response
}

// EmbedWriter rewrites native embed responses.
type EmbedWriter struct {
	BaseWriter
	model string // model name echoed back in the response
}
  540. func (w *BaseWriter) writeError(data []byte) (int, error) {
  541. var serr api.StatusError
  542. err := json.Unmarshal(data, &serr)
  543. if err != nil {
  544. return 0, err
  545. }
  546. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  547. err = json.NewEncoder(w.ResponseWriter).Encode(NewError(http.StatusInternalServerError, serr.Error()))
  548. if err != nil {
  549. return 0, err
  550. }
  551. return len(data), nil
  552. }
// writeResponse translates one native chat response body into either an
// SSE chunk (streaming) or a single JSON chat completion.
func (w *ChatWriter) writeResponse(data []byte) (int, error) {
	var chatResponse api.ChatResponse
	err := json.Unmarshal(data, &chatResponse)
	if err != nil {
		return 0, err
	}

	// chat chunk
	if w.stream {
		c := toChunk(w.id, chatResponse)
		// With include_usage, per-chunk usage is zeroed; the real totals
		// go out in a dedicated final chunk below.
		if w.streamOptions != nil && w.streamOptions.IncludeUsage {
			c.Usage = &Usage{}
		}
		d, err := json.Marshal(c)
		if err != nil {
			return 0, err
		}

		w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
		_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
		if err != nil {
			return 0, err
		}

		if chatResponse.Done {
			// Usage-only chunk with empty choices, per stream_options.
			if w.streamOptions != nil && w.streamOptions.IncludeUsage {
				u := toUsage(chatResponse)
				d, err := json.Marshal(ChatCompletionChunk{Choices: []ChunkChoice{}, Usage: &u})
				if err != nil {
					return 0, err
				}
				_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
				if err != nil {
					return 0, err
				}
			}
			// OpenAI streams terminate with a literal [DONE] sentinel.
			_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
			if err != nil {
				return 0, err
			}
		}

		return len(data), nil
	}

	// chat completion
	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w.ResponseWriter).Encode(toChatCompletion(w.id, chatResponse))
	if err != nil {
		return 0, err
	}

	return len(data), nil
}
  601. func (w *ChatWriter) Write(data []byte) (int, error) {
  602. code := w.ResponseWriter.Status()
  603. if code != http.StatusOK {
  604. return w.writeError(data)
  605. }
  606. return w.writeResponse(data)
  607. }
// writeResponse translates one native generate response body into
// either an SSE chunk (streaming) or a single JSON text completion.
func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
	var generateResponse api.GenerateResponse
	err := json.Unmarshal(data, &generateResponse)
	if err != nil {
		return 0, err
	}

	// completion chunk
	if w.stream {
		c := toCompleteChunk(w.id, generateResponse)
		// With include_usage, per-chunk usage is zeroed; the real totals
		// go out in a dedicated final chunk below.
		if w.streamOptions != nil && w.streamOptions.IncludeUsage {
			c.Usage = &Usage{}
		}
		d, err := json.Marshal(c)
		if err != nil {
			return 0, err
		}

		w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
		_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
		if err != nil {
			return 0, err
		}

		if generateResponse.Done {
			// Usage-only chunk with empty choices, per stream_options.
			if w.streamOptions != nil && w.streamOptions.IncludeUsage {
				u := toUsageGenerate(generateResponse)
				d, err := json.Marshal(CompletionChunk{Choices: []CompleteChunkChoice{}, Usage: &u})
				if err != nil {
					return 0, err
				}
				_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
				if err != nil {
					return 0, err
				}
			}
			// OpenAI streams terminate with a literal [DONE] sentinel.
			_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
			if err != nil {
				return 0, err
			}
		}

		return len(data), nil
	}

	// completion
	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w.ResponseWriter).Encode(toCompletion(w.id, generateResponse))
	if err != nil {
		return 0, err
	}

	return len(data), nil
}
  656. func (w *CompleteWriter) Write(data []byte) (int, error) {
  657. code := w.ResponseWriter.Status()
  658. if code != http.StatusOK {
  659. return w.writeError(data)
  660. }
  661. return w.writeResponse(data)
  662. }
  663. func (w *ListWriter) writeResponse(data []byte) (int, error) {
  664. var listResponse api.ListResponse
  665. err := json.Unmarshal(data, &listResponse)
  666. if err != nil {
  667. return 0, err
  668. }
  669. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  670. err = json.NewEncoder(w.ResponseWriter).Encode(toListCompletion(listResponse))
  671. if err != nil {
  672. return 0, err
  673. }
  674. return len(data), nil
  675. }
  676. func (w *ListWriter) Write(data []byte) (int, error) {
  677. code := w.ResponseWriter.Status()
  678. if code != http.StatusOK {
  679. return w.writeError(data)
  680. }
  681. return w.writeResponse(data)
  682. }
  683. func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
  684. var showResponse api.ShowResponse
  685. err := json.Unmarshal(data, &showResponse)
  686. if err != nil {
  687. return 0, err
  688. }
  689. // retrieve completion
  690. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  691. err = json.NewEncoder(w.ResponseWriter).Encode(toModel(showResponse, w.model))
  692. if err != nil {
  693. return 0, err
  694. }
  695. return len(data), nil
  696. }
  697. func (w *RetrieveWriter) Write(data []byte) (int, error) {
  698. code := w.ResponseWriter.Status()
  699. if code != http.StatusOK {
  700. return w.writeError(data)
  701. }
  702. return w.writeResponse(data)
  703. }
  704. func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
  705. var embedResponse api.EmbedResponse
  706. err := json.Unmarshal(data, &embedResponse)
  707. if err != nil {
  708. return 0, err
  709. }
  710. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  711. err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse))
  712. if err != nil {
  713. return 0, err
  714. }
  715. return len(data), nil
  716. }
  717. func (w *EmbedWriter) Write(data []byte) (int, error) {
  718. code := w.ResponseWriter.Status()
  719. if code != http.StatusOK {
  720. return w.writeError(data)
  721. }
  722. return w.writeResponse(data)
  723. }
  724. func ListMiddleware() gin.HandlerFunc {
  725. return func(c *gin.Context) {
  726. w := &ListWriter{
  727. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  728. }
  729. c.Writer = w
  730. c.Next()
  731. }
  732. }
  733. func RetrieveMiddleware() gin.HandlerFunc {
  734. return func(c *gin.Context) {
  735. var b bytes.Buffer
  736. if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
  737. c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
  738. return
  739. }
  740. c.Request.Body = io.NopCloser(&b)
  741. // response writer
  742. w := &RetrieveWriter{
  743. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  744. model: c.Param("model"),
  745. }
  746. c.Writer = w
  747. c.Next()
  748. }
  749. }
  750. func CompletionsMiddleware() gin.HandlerFunc {
  751. return func(c *gin.Context) {
  752. var req CompletionRequest
  753. err := c.ShouldBindJSON(&req)
  754. if err != nil {
  755. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
  756. return
  757. }
  758. var b bytes.Buffer
  759. genReq, err := fromCompleteRequest(req)
  760. if err != nil {
  761. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
  762. return
  763. }
  764. if err := json.NewEncoder(&b).Encode(genReq); err != nil {
  765. c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
  766. return
  767. }
  768. c.Request.Body = io.NopCloser(&b)
  769. w := &CompleteWriter{
  770. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  771. stream: req.Stream,
  772. id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
  773. streamOptions: req.StreamOptions,
  774. }
  775. c.Writer = w
  776. c.Next()
  777. }
  778. }
  779. func EmbeddingsMiddleware() gin.HandlerFunc {
  780. return func(c *gin.Context) {
  781. var req EmbedRequest
  782. err := c.ShouldBindJSON(&req)
  783. if err != nil {
  784. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
  785. return
  786. }
  787. if req.Input == "" {
  788. req.Input = []string{""}
  789. }
  790. if req.Input == nil {
  791. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
  792. return
  793. }
  794. if v, ok := req.Input.([]any); ok && len(v) == 0 {
  795. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
  796. return
  797. }
  798. var b bytes.Buffer
  799. if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input}); err != nil {
  800. c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
  801. return
  802. }
  803. c.Request.Body = io.NopCloser(&b)
  804. w := &EmbedWriter{
  805. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  806. model: req.Model,
  807. }
  808. c.Writer = w
  809. c.Next()
  810. }
  811. }
  812. func ChatMiddleware() gin.HandlerFunc {
  813. return func(c *gin.Context) {
  814. var req ChatCompletionRequest
  815. err := c.ShouldBindJSON(&req)
  816. if err != nil {
  817. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
  818. return
  819. }
  820. if len(req.Messages) == 0 {
  821. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "[] is too short - 'messages'"))
  822. return
  823. }
  824. var b bytes.Buffer
  825. chatReq, err := fromChatRequest(req)
  826. if err != nil {
  827. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
  828. return
  829. }
  830. if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
  831. c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
  832. return
  833. }
  834. c.Request.Body = io.NopCloser(&b)
  835. w := &ChatWriter{
  836. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  837. stream: req.Stream,
  838. id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
  839. streamOptions: req.StreamOptions,
  840. }
  841. c.Writer = w
  842. c.Next()
  843. }
  844. }