openai.go 18 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740
  1. // openai package provides middleware for partial compatibility with the OpenAI REST API
  2. package openai
  3. import (
  4. "bytes"
  5. "encoding/base64"
  6. "encoding/json"
  7. "fmt"
  8. "io"
  9. "math/rand"
  10. "net/http"
  11. "strings"
  12. "time"
  13. "github.com/gin-gonic/gin"
  14. "github.com/ollama/ollama/api"
  15. "github.com/ollama/ollama/types/model"
  16. )
  17. type Error struct {
  18. Message string `json:"message"`
  19. Type string `json:"type"`
  20. Param interface{} `json:"param"`
  21. Code *string `json:"code"`
  22. }
  23. type ErrorResponse struct {
  24. Error Error `json:"error"`
  25. }
// Message is a single chat message. Content is any because the OpenAI API
// accepts either a plain string or a list of typed content parts; the parts
// form is decoded in fromChatRequest.
type Message struct {
	Role    string `json:"role"`
	Content any    `json:"content"`
}

// Choice is one completed chat choice in a non-streaming chat response.
type Choice struct {
	Index        int     `json:"index"`
	Message      Message `json:"message"`
	FinishReason *string `json:"finish_reason"`
}

// ChunkChoice is one incremental delta in a streaming chat response.
type ChunkChoice struct {
	Index        int     `json:"index"`
	Delta        Message `json:"delta"`
	FinishReason *string `json:"finish_reason"`
}

// CompleteChunkChoice is one text-completion choice; it is used for both
// streaming and non-streaming /v1/completions responses.
type CompleteChunkChoice struct {
	Text         string  `json:"text"`
	Index        int     `json:"index"`
	FinishReason *string `json:"finish_reason"`
}

// Usage reports prompt/completion token counts for a response.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

// ResponseFormat selects the output format; only Type == "json_object" is
// acted upon (see fromChatRequest).
type ResponseFormat struct {
	Type string `json:"type"`
}
  53. type ChatCompletionRequest struct {
  54. Model string `json:"model"`
  55. Messages []Message `json:"messages"`
  56. Stream bool `json:"stream"`
  57. MaxTokens *int `json:"max_tokens"`
  58. Seed *int `json:"seed"`
  59. Stop any `json:"stop"`
  60. Temperature *float64 `json:"temperature"`
  61. FrequencyPenalty *float64 `json:"frequency_penalty"`
  62. PresencePenalty *float64 `json:"presence_penalty_penalty"`
  63. TopP *float64 `json:"top_p"`
  64. ResponseFormat *ResponseFormat `json:"response_format"`
  65. }
// ChatCompletion is the OpenAI chat.completion response object.
type ChatCompletion struct {
	Id                string   `json:"id"`
	Object            string   `json:"object"`
	Created           int64    `json:"created"`
	Model             string   `json:"model"`
	SystemFingerprint string   `json:"system_fingerprint"`
	Choices           []Choice `json:"choices"`
	// NOTE(review): omitempty has no effect on a struct-typed field in
	// encoding/json, so Usage is always emitted.
	Usage Usage `json:"usage,omitempty"`
}

// ChatCompletionChunk is one server-sent-event chunk of a streaming chat
// response.
type ChatCompletionChunk struct {
	Id                string        `json:"id"`
	Object            string        `json:"object"`
	Created           int64         `json:"created"`
	Model             string        `json:"model"`
	SystemFingerprint string        `json:"system_fingerprint"`
	Choices           []ChunkChoice `json:"choices"`
}
// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int

// CompletionRequest is the request body for /v1/completions. Unlike
// ChatCompletionRequest, the penalty and top_p fields are plain values, so an
// unset field is indistinguishable from an explicit zero.
type CompletionRequest struct {
	Model            string   `json:"model"`
	Prompt           string   `json:"prompt"`
	FrequencyPenalty float32  `json:"frequency_penalty"`
	MaxTokens        *int     `json:"max_tokens"`
	PresencePenalty  float32  `json:"presence_penalty"`
	Seed             *int     `json:"seed"`
	Stop             any      `json:"stop"`
	Stream           bool     `json:"stream"`
	Temperature      *float32 `json:"temperature"`
	TopP             float32  `json:"top_p"`
}

// Completion is the OpenAI text_completion response object.
type Completion struct {
	Id                string                `json:"id"`
	Object            string                `json:"object"`
	Created           int64                 `json:"created"`
	Model             string                `json:"model"`
	SystemFingerprint string                `json:"system_fingerprint"`
	Choices           []CompleteChunkChoice `json:"choices"`
	// NOTE(review): omitempty has no effect on a struct-typed field in
	// encoding/json, so Usage is always emitted.
	Usage Usage `json:"usage,omitempty"`
}

// CompletionChunk is one server-sent-event chunk of a streaming text
// completion.
type CompletionChunk struct {
	Id                string                `json:"id"`
	Object            string                `json:"object"`
	Created           int64                 `json:"created"`
	Choices           []CompleteChunkChoice `json:"choices"`
	Model             string                `json:"model"`
	SystemFingerprint string                `json:"system_fingerprint"`
}

// Model is one entry in the OpenAI model-list response.
type Model struct {
	Id      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	OwnedBy string `json:"owned_by"`
}

// ListCompletion is the OpenAI list-models response envelope.
type ListCompletion struct {
	Object string  `json:"object"`
	Data   []Model `json:"data"`
}
  123. func NewError(code int, message string) ErrorResponse {
  124. var etype string
  125. switch code {
  126. case http.StatusBadRequest:
  127. etype = "invalid_request_error"
  128. case http.StatusNotFound:
  129. etype = "not_found_error"
  130. default:
  131. etype = "api_error"
  132. }
  133. return ErrorResponse{Error{Type: etype, Message: message}}
  134. }
  135. func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
  136. return ChatCompletion{
  137. Id: id,
  138. Object: "chat.completion",
  139. Created: r.CreatedAt.Unix(),
  140. Model: r.Model,
  141. SystemFingerprint: "fp_ollama",
  142. Choices: []Choice{{
  143. Index: 0,
  144. Message: Message{Role: r.Message.Role, Content: r.Message.Content},
  145. FinishReason: func(reason string) *string {
  146. if len(reason) > 0 {
  147. return &reason
  148. }
  149. return nil
  150. }(r.DoneReason),
  151. }},
  152. Usage: Usage{
  153. // TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
  154. PromptTokens: r.PromptEvalCount,
  155. CompletionTokens: r.EvalCount,
  156. TotalTokens: r.PromptEvalCount + r.EvalCount,
  157. },
  158. }
  159. }
  160. func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
  161. return ChatCompletionChunk{
  162. Id: id,
  163. Object: "chat.completion.chunk",
  164. Created: time.Now().Unix(),
  165. Model: r.Model,
  166. SystemFingerprint: "fp_ollama",
  167. Choices: []ChunkChoice{{
  168. Index: 0,
  169. Delta: Message{Role: "assistant", Content: r.Message.Content},
  170. FinishReason: func(reason string) *string {
  171. if len(reason) > 0 {
  172. return &reason
  173. }
  174. return nil
  175. }(r.DoneReason),
  176. }},
  177. }
  178. }
  179. func toCompletion(id string, r api.GenerateResponse) Completion {
  180. return Completion{
  181. Id: id,
  182. Object: "text_completion",
  183. Created: r.CreatedAt.Unix(),
  184. Model: r.Model,
  185. SystemFingerprint: "fp_ollama",
  186. Choices: []CompleteChunkChoice{{
  187. Text: r.Response,
  188. Index: 0,
  189. FinishReason: func(reason string) *string {
  190. if len(reason) > 0 {
  191. return &reason
  192. }
  193. return nil
  194. }(r.DoneReason),
  195. }},
  196. Usage: Usage{
  197. // TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
  198. PromptTokens: r.PromptEvalCount,
  199. CompletionTokens: r.EvalCount,
  200. TotalTokens: r.PromptEvalCount + r.EvalCount,
  201. },
  202. }
  203. }
  204. func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
  205. return CompletionChunk{
  206. Id: id,
  207. Object: "text_completion",
  208. Created: time.Now().Unix(),
  209. Model: r.Model,
  210. SystemFingerprint: "fp_ollama",
  211. Choices: []CompleteChunkChoice{{
  212. Text: r.Response,
  213. Index: 0,
  214. FinishReason: func(reason string) *string {
  215. if len(reason) > 0 {
  216. return &reason
  217. }
  218. return nil
  219. }(r.DoneReason),
  220. }},
  221. }
  222. }
  223. func toListCompletion(r api.ListResponse) ListCompletion {
  224. var data []Model
  225. for _, m := range r.Models {
  226. data = append(data, Model{
  227. Id: m.Name,
  228. Object: "model",
  229. Created: m.ModifiedAt.Unix(),
  230. OwnedBy: model.ParseName(m.Name).Namespace,
  231. })
  232. }
  233. return ListCompletion{
  234. Object: "list",
  235. Data: data,
  236. }
  237. }
  238. func toModel(r api.ShowResponse, m string) Model {
  239. return Model{
  240. Id: m,
  241. Object: "model",
  242. Created: r.ModifiedAt.Unix(),
  243. OwnedBy: model.ParseName(m).Namespace,
  244. }
  245. }
// fromChatRequest converts an OpenAI chat completion request into an Ollama
// api.ChatRequest: it flattens string/part-list message content (decoding
// inline base64 images), maps sampling parameters into the options map, and
// translates response_format. Returns an error for any malformed message.
func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
	var messages []api.Message
	for _, msg := range r.Messages {
		switch content := msg.Content.(type) {
		case string:
			// Plain string content maps one-to-one.
			messages = append(messages, api.Message{Role: msg.Role, Content: content})
		case []any:
			// Structured content: a list of {"type": ...} parts ("text" or
			// "image_url"); all parts fold into a single api.Message.
			message := api.Message{Role: msg.Role}
			for _, c := range content {
				data, ok := c.(map[string]any)
				if !ok {
					return nil, fmt.Errorf("invalid message format")
				}
				switch data["type"] {
				case "text":
					text, ok := data["text"].(string)
					if !ok {
						return nil, fmt.Errorf("invalid message format")
					}
					// NOTE(review): a later "text" part overwrites an earlier
					// one; multiple text parts are not concatenated.
					message.Content = text
				case "image_url":
					// Accept both {"image_url": {"url": ...}} and the bare
					// {"image_url": "..."} form.
					var url string
					if urlMap, ok := data["image_url"].(map[string]any); ok {
						if url, ok = urlMap["url"].(string); !ok {
							return nil, fmt.Errorf("invalid message format")
						}
					} else {
						if url, ok = data["image_url"].(string); !ok {
							return nil, fmt.Errorf("invalid message format")
						}
					}
					// Only inline base64 data URIs (jpeg/jpg/png) are
					// supported; remote image URLs are rejected.
					types := []string{"jpeg", "jpg", "png"}
					valid := false
					for _, t := range types {
						prefix := "data:image/" + t + ";base64,"
						if strings.HasPrefix(url, prefix) {
							url = strings.TrimPrefix(url, prefix)
							valid = true
							break
						}
					}
					if !valid {
						return nil, fmt.Errorf("invalid image input")
					}
					img, err := base64.StdEncoding.DecodeString(url)
					if err != nil {
						return nil, fmt.Errorf("invalid message format")
					}
					message.Images = append(message.Images, img)
				default:
					return nil, fmt.Errorf("invalid message format")
				}
			}
			messages = append(messages, message)
		default:
			return nil, fmt.Errorf("invalid message content type: %T", content)
		}
	}

	options := make(map[string]interface{})

	// "stop" may be a single string or a list of strings; non-string list
	// entries are silently skipped here.
	switch stop := r.Stop.(type) {
	case string:
		options["stop"] = []string{stop}
	case []any:
		var stops []string
		for _, s := range stop {
			if str, ok := s.(string); ok {
				stops = append(stops, str)
			}
		}
		options["stop"] = stops
	}

	if r.MaxTokens != nil {
		options["num_predict"] = *r.MaxTokens
	}

	// NOTE(review): temperature and the penalties are multiplied by 2.0 —
	// presumably a mapping between OpenAI's and ollama's parameter ranges;
	// confirm the intended direction of this scaling. The unset default (1.0)
	// is stored without scaling.
	if r.Temperature != nil {
		options["temperature"] = *r.Temperature * 2.0
	} else {
		options["temperature"] = 1.0
	}

	if r.Seed != nil {
		options["seed"] = *r.Seed
	}

	if r.FrequencyPenalty != nil {
		options["frequency_penalty"] = *r.FrequencyPenalty * 2.0
	}

	if r.PresencePenalty != nil {
		options["presence_penalty"] = *r.PresencePenalty * 2.0
	}

	if r.TopP != nil {
		options["top_p"] = *r.TopP
	} else {
		options["top_p"] = 1.0
	}

	// Only "json_object" maps to ollama's JSON mode; other values are ignored.
	var format string
	if r.ResponseFormat != nil && r.ResponseFormat.Type == "json_object" {
		format = "json"
	}

	return &api.ChatRequest{
		Model:    r.Model,
		Messages: messages,
		Format:   format,
		Options:  options,
		Stream:   &r.Stream,
	}, nil
}
  351. func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
  352. options := make(map[string]any)
  353. switch stop := r.Stop.(type) {
  354. case string:
  355. options["stop"] = []string{stop}
  356. case []any:
  357. var stops []string
  358. for _, s := range stop {
  359. if str, ok := s.(string); ok {
  360. stops = append(stops, str)
  361. } else {
  362. return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", s)
  363. }
  364. }
  365. options["stop"] = stops
  366. }
  367. if r.MaxTokens != nil {
  368. options["num_predict"] = *r.MaxTokens
  369. }
  370. if r.Temperature != nil {
  371. options["temperature"] = *r.Temperature * 2.0
  372. } else {
  373. options["temperature"] = 1.0
  374. }
  375. if r.Seed != nil {
  376. options["seed"] = *r.Seed
  377. }
  378. options["frequency_penalty"] = r.FrequencyPenalty * 2.0
  379. options["presence_penalty"] = r.PresencePenalty * 2.0
  380. if r.TopP != 0.0 {
  381. options["top_p"] = r.TopP
  382. } else {
  383. options["top_p"] = 1.0
  384. }
  385. return api.GenerateRequest{
  386. Model: r.Model,
  387. Prompt: r.Prompt,
  388. Options: options,
  389. Stream: &r.Stream,
  390. }, nil
  391. }
// BaseWriter embeds gin's ResponseWriter and provides the shared error
// translation used by the OpenAI-compatibility response writers below.
type BaseWriter struct {
	gin.ResponseWriter
}

// ChatWriter rewrites /v1/chat/completions responses; stream selects SSE
// chunk output and id is the generated "chatcmpl-*" identifier.
type ChatWriter struct {
	stream bool
	id     string
	BaseWriter
}

// CompleteWriter rewrites /v1/completions responses; stream selects SSE
// chunk output and id is the generated "cmpl-*" identifier.
type CompleteWriter struct {
	stream bool
	id     string
	BaseWriter
}

// ListWriter rewrites the model-list response.
type ListWriter struct {
	BaseWriter
}

// RetrieveWriter rewrites a single-model show response; model is the name
// the client requested.
type RetrieveWriter struct {
	BaseWriter
	model string
}
  412. func (w *BaseWriter) writeError(code int, data []byte) (int, error) {
  413. var serr api.StatusError
  414. err := json.Unmarshal(data, &serr)
  415. if err != nil {
  416. return 0, err
  417. }
  418. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  419. err = json.NewEncoder(w.ResponseWriter).Encode(NewError(http.StatusInternalServerError, serr.Error()))
  420. if err != nil {
  421. return 0, err
  422. }
  423. return len(data), nil
  424. }
  425. func (w *ChatWriter) writeResponse(data []byte) (int, error) {
  426. var chatResponse api.ChatResponse
  427. err := json.Unmarshal(data, &chatResponse)
  428. if err != nil {
  429. return 0, err
  430. }
  431. // chat chunk
  432. if w.stream {
  433. d, err := json.Marshal(toChunk(w.id, chatResponse))
  434. if err != nil {
  435. return 0, err
  436. }
  437. w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
  438. _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
  439. if err != nil {
  440. return 0, err
  441. }
  442. if chatResponse.Done {
  443. _, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
  444. if err != nil {
  445. return 0, err
  446. }
  447. }
  448. return len(data), nil
  449. }
  450. // chat completion
  451. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  452. err = json.NewEncoder(w.ResponseWriter).Encode(toChatCompletion(w.id, chatResponse))
  453. if err != nil {
  454. return 0, err
  455. }
  456. return len(data), nil
  457. }
  458. func (w *ChatWriter) Write(data []byte) (int, error) {
  459. code := w.ResponseWriter.Status()
  460. if code != http.StatusOK {
  461. return w.writeError(code, data)
  462. }
  463. return w.writeResponse(data)
  464. }
  465. func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
  466. var generateResponse api.GenerateResponse
  467. err := json.Unmarshal(data, &generateResponse)
  468. if err != nil {
  469. return 0, err
  470. }
  471. // completion chunk
  472. if w.stream {
  473. d, err := json.Marshal(toCompleteChunk(w.id, generateResponse))
  474. if err != nil {
  475. return 0, err
  476. }
  477. w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
  478. _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
  479. if err != nil {
  480. return 0, err
  481. }
  482. if generateResponse.Done {
  483. _, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
  484. if err != nil {
  485. return 0, err
  486. }
  487. }
  488. return len(data), nil
  489. }
  490. // completion
  491. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  492. err = json.NewEncoder(w.ResponseWriter).Encode(toCompletion(w.id, generateResponse))
  493. if err != nil {
  494. return 0, err
  495. }
  496. return len(data), nil
  497. }
  498. func (w *CompleteWriter) Write(data []byte) (int, error) {
  499. code := w.ResponseWriter.Status()
  500. if code != http.StatusOK {
  501. return w.writeError(code, data)
  502. }
  503. return w.writeResponse(data)
  504. }
  505. func (w *ListWriter) writeResponse(data []byte) (int, error) {
  506. var listResponse api.ListResponse
  507. err := json.Unmarshal(data, &listResponse)
  508. if err != nil {
  509. return 0, err
  510. }
  511. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  512. err = json.NewEncoder(w.ResponseWriter).Encode(toListCompletion(listResponse))
  513. if err != nil {
  514. return 0, err
  515. }
  516. return len(data), nil
  517. }
  518. func (w *ListWriter) Write(data []byte) (int, error) {
  519. code := w.ResponseWriter.Status()
  520. if code != http.StatusOK {
  521. return w.writeError(code, data)
  522. }
  523. return w.writeResponse(data)
  524. }
  525. func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
  526. var showResponse api.ShowResponse
  527. err := json.Unmarshal(data, &showResponse)
  528. if err != nil {
  529. return 0, err
  530. }
  531. // retrieve completion
  532. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  533. err = json.NewEncoder(w.ResponseWriter).Encode(toModel(showResponse, w.model))
  534. if err != nil {
  535. return 0, err
  536. }
  537. return len(data), nil
  538. }
  539. func (w *RetrieveWriter) Write(data []byte) (int, error) {
  540. code := w.ResponseWriter.Status()
  541. if code != http.StatusOK {
  542. return w.writeError(code, data)
  543. }
  544. return w.writeResponse(data)
  545. }
  546. func ListMiddleware() gin.HandlerFunc {
  547. return func(c *gin.Context) {
  548. w := &ListWriter{
  549. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  550. }
  551. c.Writer = w
  552. c.Next()
  553. }
  554. }
  555. func RetrieveMiddleware() gin.HandlerFunc {
  556. return func(c *gin.Context) {
  557. var b bytes.Buffer
  558. if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
  559. c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
  560. return
  561. }
  562. c.Request.Body = io.NopCloser(&b)
  563. // response writer
  564. w := &RetrieveWriter{
  565. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  566. model: c.Param("model"),
  567. }
  568. c.Writer = w
  569. c.Next()
  570. }
  571. }
  572. func CompletionsMiddleware() gin.HandlerFunc {
  573. return func(c *gin.Context) {
  574. var req CompletionRequest
  575. err := c.ShouldBindJSON(&req)
  576. if err != nil {
  577. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
  578. return
  579. }
  580. var b bytes.Buffer
  581. genReq, err := fromCompleteRequest(req)
  582. if err != nil {
  583. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
  584. return
  585. }
  586. if err := json.NewEncoder(&b).Encode(genReq); err != nil {
  587. c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
  588. return
  589. }
  590. c.Request.Body = io.NopCloser(&b)
  591. w := &CompleteWriter{
  592. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  593. stream: req.Stream,
  594. id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
  595. }
  596. c.Writer = w
  597. c.Next()
  598. }
  599. }
  600. func ChatMiddleware() gin.HandlerFunc {
  601. return func(c *gin.Context) {
  602. var req ChatCompletionRequest
  603. err := c.ShouldBindJSON(&req)
  604. if err != nil {
  605. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
  606. return
  607. }
  608. if len(req.Messages) == 0 {
  609. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "[] is too short - 'messages'"))
  610. return
  611. }
  612. var b bytes.Buffer
  613. chatReq, err := fromChatRequest(req)
  614. if err != nil {
  615. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
  616. }
  617. if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
  618. c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
  619. return
  620. }
  621. c.Request.Body = io.NopCloser(&b)
  622. w := &ChatWriter{
  623. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  624. stream: req.Stream,
  625. id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
  626. }
  627. c.Writer = w
  628. c.Next()
  629. }
  630. }