  1. // openai package provides middleware for partial compatibility with the OpenAI REST API
  2. package openai
  3. import (
  4. "bytes"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "math/rand"
  9. "net/http"
  10. "time"
  11. "github.com/gin-gonic/gin"
  12. "github.com/ollama/ollama/api"
  13. "github.com/ollama/ollama/types/model"
  14. )
// Error mirrors the OpenAI API error object carried inside an
// ErrorResponse envelope.
type Error struct {
	Message string      `json:"message"`
	Type    string      `json:"type"`
	Param   interface{} `json:"param"`
	Code    *string     `json:"code"`
}

// ErrorResponse is the top-level error body, matching the OpenAI REST
// API's {"error": {...}} shape.
type ErrorResponse struct {
	Error Error `json:"error"`
}
// Message is a single chat message together with its author role.
type Message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// Choice is one completed chat choice in a non-streaming response.
type Choice struct {
	Index        int     `json:"index"`
	Message      Message `json:"message"`
	FinishReason *string `json:"finish_reason"`
}

// ChunkChoice is one streamed chat choice; Delta carries only the newly
// generated content for this chunk.
type ChunkChoice struct {
	Index        int     `json:"index"`
	Delta        Message `json:"delta"`
	FinishReason *string `json:"finish_reason"`
}

// CompleteChunkChoice is one choice of a text completion response
// (used for both streaming chunks and full responses).
type CompleteChunkChoice struct {
	Text         string  `json:"text"`
	Index        int     `json:"index"`
	FinishReason *string `json:"finish_reason"`
}

// Usage reports token accounting for a request/response pair.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

// ResponseFormat selects the output format; Type "json_object" requests
// JSON mode (see fromChatRequest).
type ResponseFormat struct {
	Type string `json:"type"`
}
  51. type ChatCompletionRequest struct {
  52. Model string `json:"model"`
  53. Messages []Message `json:"messages"`
  54. Stream bool `json:"stream"`
  55. MaxTokens *int `json:"max_tokens"`
  56. Seed *int `json:"seed"`
  57. Stop any `json:"stop"`
  58. Temperature *float64 `json:"temperature"`
  59. FrequencyPenalty *float64 `json:"frequency_penalty"`
  60. PresencePenalty *float64 `json:"presence_penalty_penalty"`
  61. TopP *float64 `json:"top_p"`
  62. ResponseFormat *ResponseFormat `json:"response_format"`
  63. }
// ChatCompletion is the non-streaming response body of
// POST /v1/chat/completions.
type ChatCompletion struct {
	Id                string   `json:"id"`
	Object            string   `json:"object"`
	Created           int64    `json:"created"`
	Model             string   `json:"model"`
	SystemFingerprint string   `json:"system_fingerprint"`
	Choices           []Choice `json:"choices"`
	Usage             Usage    `json:"usage,omitempty"`
}

// ChatCompletionChunk is a single server-sent event of a streaming chat
// completion.
type ChatCompletionChunk struct {
	Id                string        `json:"id"`
	Object            string        `json:"object"`
	Created           int64         `json:"created"`
	Model             string        `json:"model"`
	SystemFingerprint string        `json:"system_fingerprint"`
	Choices           []ChunkChoice `json:"choices"`
}
// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
// CompletionRequest is the request body of POST /v1/completions.
type CompletionRequest struct {
	Model            string   `json:"model"`
	Prompt           string   `json:"prompt"`
	FrequencyPenalty float32  `json:"frequency_penalty"`
	MaxTokens        *int     `json:"max_tokens"`
	PresencePenalty  float32  `json:"presence_penalty"`
	Seed             *int     `json:"seed"`
	Stop             any      `json:"stop"`
	Stream           bool     `json:"stream"`
	Temperature      *float32 `json:"temperature"`
	TopP             float32  `json:"top_p"`
}

// Completion is the non-streaming response body of POST /v1/completions.
type Completion struct {
	Id                string                `json:"id"`
	Object            string                `json:"object"`
	Created           int64                 `json:"created"`
	Model             string                `json:"model"`
	SystemFingerprint string                `json:"system_fingerprint"`
	Choices           []CompleteChunkChoice `json:"choices"`
	Usage             Usage                 `json:"usage,omitempty"`
}

// CompletionChunk is a single server-sent event of a streaming text
// completion.
type CompletionChunk struct {
	Id                string                `json:"id"`
	Object            string                `json:"object"`
	Created           int64                 `json:"created"`
	Choices           []CompleteChunkChoice `json:"choices"`
	Model             string                `json:"model"`
	SystemFingerprint string                `json:"system_fingerprint"`
}

// Model describes one model entry in the OpenAI models API.
type Model struct {
	Id      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	OwnedBy string `json:"owned_by"`
}

// ListCompletion is the response body of GET /v1/models.
type ListCompletion struct {
	Object string  `json:"object"`
	Data   []Model `json:"data"`
}
  121. func NewError(code int, message string) ErrorResponse {
  122. var etype string
  123. switch code {
  124. case http.StatusBadRequest:
  125. etype = "invalid_request_error"
  126. case http.StatusNotFound:
  127. etype = "not_found_error"
  128. default:
  129. etype = "api_error"
  130. }
  131. return ErrorResponse{Error{Type: etype, Message: message}}
  132. }
  133. func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
  134. return ChatCompletion{
  135. Id: id,
  136. Object: "chat.completion",
  137. Created: r.CreatedAt.Unix(),
  138. Model: r.Model,
  139. SystemFingerprint: "fp_ollama",
  140. Choices: []Choice{{
  141. Index: 0,
  142. Message: Message{Role: r.Message.Role, Content: r.Message.Content},
  143. FinishReason: func(reason string) *string {
  144. if len(reason) > 0 {
  145. return &reason
  146. }
  147. return nil
  148. }(r.DoneReason),
  149. }},
  150. Usage: Usage{
  151. // TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
  152. PromptTokens: r.PromptEvalCount,
  153. CompletionTokens: r.EvalCount,
  154. TotalTokens: r.PromptEvalCount + r.EvalCount,
  155. },
  156. }
  157. }
  158. func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
  159. return ChatCompletionChunk{
  160. Id: id,
  161. Object: "chat.completion.chunk",
  162. Created: time.Now().Unix(),
  163. Model: r.Model,
  164. SystemFingerprint: "fp_ollama",
  165. Choices: []ChunkChoice{{
  166. Index: 0,
  167. Delta: Message{Role: "assistant", Content: r.Message.Content},
  168. FinishReason: func(reason string) *string {
  169. if len(reason) > 0 {
  170. return &reason
  171. }
  172. return nil
  173. }(r.DoneReason),
  174. }},
  175. }
  176. }
  177. func toCompletion(id string, r api.GenerateResponse) Completion {
  178. return Completion{
  179. Id: id,
  180. Object: "text_completion",
  181. Created: r.CreatedAt.Unix(),
  182. Model: r.Model,
  183. SystemFingerprint: "fp_ollama",
  184. Choices: []CompleteChunkChoice{{
  185. Text: r.Response,
  186. Index: 0,
  187. FinishReason: func(reason string) *string {
  188. if len(reason) > 0 {
  189. return &reason
  190. }
  191. return nil
  192. }(r.DoneReason),
  193. }},
  194. Usage: Usage{
  195. // TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
  196. PromptTokens: r.PromptEvalCount,
  197. CompletionTokens: r.EvalCount,
  198. TotalTokens: r.PromptEvalCount + r.EvalCount,
  199. },
  200. }
  201. }
  202. func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
  203. return CompletionChunk{
  204. Id: id,
  205. Object: "text_completion",
  206. Created: time.Now().Unix(),
  207. Model: r.Model,
  208. SystemFingerprint: "fp_ollama",
  209. Choices: []CompleteChunkChoice{{
  210. Text: r.Response,
  211. Index: 0,
  212. FinishReason: func(reason string) *string {
  213. if len(reason) > 0 {
  214. return &reason
  215. }
  216. return nil
  217. }(r.DoneReason),
  218. }},
  219. }
  220. }
  221. func toListCompletion(r api.ListResponse) ListCompletion {
  222. var data []Model
  223. for _, m := range r.Models {
  224. data = append(data, Model{
  225. Id: m.Name,
  226. Object: "model",
  227. Created: m.ModifiedAt.Unix(),
  228. OwnedBy: model.ParseName(m.Name).Namespace,
  229. })
  230. }
  231. return ListCompletion{
  232. Object: "list",
  233. Data: data,
  234. }
  235. }
  236. func toModel(r api.ShowResponse, m string) Model {
  237. return Model{
  238. Id: m,
  239. Object: "model",
  240. Created: r.ModifiedAt.Unix(),
  241. OwnedBy: model.ParseName(m).Namespace,
  242. }
  243. }
  244. func fromChatRequest(r ChatCompletionRequest) api.ChatRequest {
  245. var messages []api.Message
  246. for _, msg := range r.Messages {
  247. messages = append(messages, api.Message{Role: msg.Role, Content: msg.Content})
  248. }
  249. options := make(map[string]interface{})
  250. switch stop := r.Stop.(type) {
  251. case string:
  252. options["stop"] = []string{stop}
  253. case []any:
  254. var stops []string
  255. for _, s := range stop {
  256. if str, ok := s.(string); ok {
  257. stops = append(stops, str)
  258. }
  259. }
  260. options["stop"] = stops
  261. }
  262. if r.MaxTokens != nil {
  263. options["num_predict"] = *r.MaxTokens
  264. }
  265. if r.Temperature != nil {
  266. options["temperature"] = *r.Temperature * 2.0
  267. } else {
  268. options["temperature"] = 1.0
  269. }
  270. if r.Seed != nil {
  271. options["seed"] = *r.Seed
  272. }
  273. if r.FrequencyPenalty != nil {
  274. options["frequency_penalty"] = *r.FrequencyPenalty * 2.0
  275. }
  276. if r.PresencePenalty != nil {
  277. options["presence_penalty"] = *r.PresencePenalty * 2.0
  278. }
  279. if r.TopP != nil {
  280. options["top_p"] = *r.TopP
  281. } else {
  282. options["top_p"] = 1.0
  283. }
  284. var format string
  285. if r.ResponseFormat != nil && r.ResponseFormat.Type == "json_object" {
  286. format = "json"
  287. }
  288. return api.ChatRequest{
  289. Model: r.Model,
  290. Messages: messages,
  291. Format: format,
  292. Options: options,
  293. Stream: &r.Stream,
  294. }
  295. }
  296. func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
  297. options := make(map[string]any)
  298. switch stop := r.Stop.(type) {
  299. case string:
  300. options["stop"] = []string{stop}
  301. case []string:
  302. options["stop"] = stop
  303. default:
  304. if r.Stop != nil {
  305. return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", r.Stop)
  306. }
  307. }
  308. if r.MaxTokens != nil {
  309. options["num_predict"] = *r.MaxTokens
  310. }
  311. if r.Temperature != nil {
  312. options["temperature"] = *r.Temperature * 2.0
  313. } else {
  314. options["temperature"] = 1.0
  315. }
  316. if r.Seed != nil {
  317. options["seed"] = *r.Seed
  318. }
  319. options["frequency_penalty"] = r.FrequencyPenalty * 2.0
  320. options["presence_penalty"] = r.PresencePenalty * 2.0
  321. if r.TopP != 0.0 {
  322. options["top_p"] = r.TopP
  323. } else {
  324. options["top_p"] = 1.0
  325. }
  326. return api.GenerateRequest{
  327. Model: r.Model,
  328. Prompt: r.Prompt,
  329. Options: options,
  330. Stream: &r.Stream,
  331. }, nil
  332. }
// BaseWriter wraps gin's ResponseWriter and provides shared error
// translation for the concrete writers below.
type BaseWriter struct {
	gin.ResponseWriter
}

// ChatWriter rewrites native chat responses into OpenAI chat
// completions (streamed or not, per stream).
type ChatWriter struct {
	stream bool
	id     string
	BaseWriter
}

// CompleteWriter rewrites native generate responses into OpenAI text
// completions (streamed or not, per stream).
type CompleteWriter struct {
	stream bool
	id     string
	BaseWriter
}

// ListWriter rewrites native list responses into OpenAI model lists.
type ListWriter struct {
	BaseWriter
}

// RetrieveWriter rewrites native show responses into OpenAI Model
// objects; model holds the requested model name.
type RetrieveWriter struct {
	BaseWriter
	model string
}
  353. func (w *BaseWriter) writeError(code int, data []byte) (int, error) {
  354. var serr api.StatusError
  355. err := json.Unmarshal(data, &serr)
  356. if err != nil {
  357. return 0, err
  358. }
  359. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  360. err = json.NewEncoder(w.ResponseWriter).Encode(NewError(http.StatusInternalServerError, serr.Error()))
  361. if err != nil {
  362. return 0, err
  363. }
  364. return len(data), nil
  365. }
  366. func (w *ChatWriter) writeResponse(data []byte) (int, error) {
  367. var chatResponse api.ChatResponse
  368. err := json.Unmarshal(data, &chatResponse)
  369. if err != nil {
  370. return 0, err
  371. }
  372. // chat chunk
  373. if w.stream {
  374. d, err := json.Marshal(toChunk(w.id, chatResponse))
  375. if err != nil {
  376. return 0, err
  377. }
  378. w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
  379. _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
  380. if err != nil {
  381. return 0, err
  382. }
  383. if chatResponse.Done {
  384. _, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
  385. if err != nil {
  386. return 0, err
  387. }
  388. }
  389. return len(data), nil
  390. }
  391. // chat completion
  392. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  393. err = json.NewEncoder(w.ResponseWriter).Encode(toChatCompletion(w.id, chatResponse))
  394. if err != nil {
  395. return 0, err
  396. }
  397. return len(data), nil
  398. }
  399. func (w *ChatWriter) Write(data []byte) (int, error) {
  400. code := w.ResponseWriter.Status()
  401. if code != http.StatusOK {
  402. return w.writeError(code, data)
  403. }
  404. return w.writeResponse(data)
  405. }
  406. func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
  407. var generateResponse api.GenerateResponse
  408. err := json.Unmarshal(data, &generateResponse)
  409. if err != nil {
  410. return 0, err
  411. }
  412. // completion chunk
  413. if w.stream {
  414. d, err := json.Marshal(toCompleteChunk(w.id, generateResponse))
  415. if err != nil {
  416. return 0, err
  417. }
  418. w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
  419. _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
  420. if err != nil {
  421. return 0, err
  422. }
  423. if generateResponse.Done {
  424. _, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
  425. if err != nil {
  426. return 0, err
  427. }
  428. }
  429. return len(data), nil
  430. }
  431. // completion
  432. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  433. err = json.NewEncoder(w.ResponseWriter).Encode(toCompletion(w.id, generateResponse))
  434. if err != nil {
  435. return 0, err
  436. }
  437. return len(data), nil
  438. }
  439. func (w *CompleteWriter) Write(data []byte) (int, error) {
  440. code := w.ResponseWriter.Status()
  441. if code != http.StatusOK {
  442. return w.writeError(code, data)
  443. }
  444. return w.writeResponse(data)
  445. }
  446. func (w *ListWriter) writeResponse(data []byte) (int, error) {
  447. var listResponse api.ListResponse
  448. err := json.Unmarshal(data, &listResponse)
  449. if err != nil {
  450. return 0, err
  451. }
  452. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  453. err = json.NewEncoder(w.ResponseWriter).Encode(toListCompletion(listResponse))
  454. if err != nil {
  455. return 0, err
  456. }
  457. return len(data), nil
  458. }
  459. func (w *ListWriter) Write(data []byte) (int, error) {
  460. code := w.ResponseWriter.Status()
  461. if code != http.StatusOK {
  462. return w.writeError(code, data)
  463. }
  464. return w.writeResponse(data)
  465. }
  466. func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
  467. var showResponse api.ShowResponse
  468. err := json.Unmarshal(data, &showResponse)
  469. if err != nil {
  470. return 0, err
  471. }
  472. // retrieve completion
  473. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  474. err = json.NewEncoder(w.ResponseWriter).Encode(toModel(showResponse, w.model))
  475. if err != nil {
  476. return 0, err
  477. }
  478. return len(data), nil
  479. }
  480. func (w *RetrieveWriter) Write(data []byte) (int, error) {
  481. code := w.ResponseWriter.Status()
  482. if code != http.StatusOK {
  483. return w.writeError(code, data)
  484. }
  485. return w.writeResponse(data)
  486. }
  487. func ListMiddleware() gin.HandlerFunc {
  488. return func(c *gin.Context) {
  489. w := &ListWriter{
  490. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  491. }
  492. c.Writer = w
  493. c.Next()
  494. }
  495. }
  496. func RetrieveMiddleware() gin.HandlerFunc {
  497. return func(c *gin.Context) {
  498. var b bytes.Buffer
  499. if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
  500. c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
  501. return
  502. }
  503. c.Request.Body = io.NopCloser(&b)
  504. // response writer
  505. w := &RetrieveWriter{
  506. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  507. model: c.Param("model"),
  508. }
  509. c.Writer = w
  510. c.Next()
  511. }
  512. }
  513. func CompletionsMiddleware() gin.HandlerFunc {
  514. return func(c *gin.Context) {
  515. var req CompletionRequest
  516. err := c.ShouldBindJSON(&req)
  517. if err != nil {
  518. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
  519. return
  520. }
  521. var b bytes.Buffer
  522. genReq, err := fromCompleteRequest(req)
  523. if err != nil {
  524. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
  525. return
  526. }
  527. if err := json.NewEncoder(&b).Encode(genReq); err != nil {
  528. c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
  529. return
  530. }
  531. c.Request.Body = io.NopCloser(&b)
  532. w := &CompleteWriter{
  533. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  534. stream: req.Stream,
  535. id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
  536. }
  537. c.Writer = w
  538. c.Next()
  539. }
  540. }
  541. func ChatMiddleware() gin.HandlerFunc {
  542. return func(c *gin.Context) {
  543. var req ChatCompletionRequest
  544. err := c.ShouldBindJSON(&req)
  545. if err != nil {
  546. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
  547. return
  548. }
  549. if len(req.Messages) == 0 {
  550. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "[] is too short - 'messages'"))
  551. return
  552. }
  553. var b bytes.Buffer
  554. if err := json.NewEncoder(&b).Encode(fromChatRequest(req)); err != nil {
  555. c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
  556. return
  557. }
  558. c.Request.Body = io.NopCloser(&b)
  559. w := &ChatWriter{
  560. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  561. stream: req.Stream,
  562. id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
  563. }
  564. c.Writer = w
  565. c.Next()
  566. }
  567. }