// Package openai provides middleware for partial compatibility with the OpenAI REST API.
package openai

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"math/rand"
	"net/http"
	"time"

	"github.com/gin-gonic/gin"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/types/model"
)

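// Error and ErrorResponse mirror the error envelope returned by the OpenAI
// REST API: a single "error" object carrying a message, type, optional
// parameter, and optional code.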
type Error struct {
	Message string      `json:"message"`
	Type    string      `json:"type"`
	Param   interface{} `json:"param"`
	Code    *string     `json:"code"`
}

type ErrorResponse struct {
	Error Error `json:"error"`
}

type Message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

type Choice struct {
	Index        int     `json:"index"`
	Message      Message `json:"message"`
	FinishReason *string `json:"finish_reason"`
}

type ChunkChoice struct {
	Index        int     `json:"index"`
	Delta        Message `json:"delta"`
	FinishReason *string `json:"finish_reason"`
}

type CompleteChunkChoice struct {
	Text         string  `json:"text"`
	Index        int     `json:"index"`
	FinishReason *string `json:"finish_reason"`
}

type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

type ResponseFormat struct {
	Type string `json:"type"`
}

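// ChatCompletionRequest is the subset of the OpenAI /v1/chat/completions
// request body supported by this middleware.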
type ChatCompletionRequest struct {
	Model            string          `json:"model"`
	Messages         []Message       `json:"messages"`
	Stream           bool            `json:"stream"`
	MaxTokens        *int            `json:"max_tokens"`
	Seed             *int            `json:"seed"`
	Stop             any             `json:"stop"`
	Temperature      *float64        `json:"temperature"`
	FrequencyPenalty *float64        `json:"frequency_penalty"`
	PresencePenalty  *float64        `json:"presence_penalty"`
	TopP             *float64        `json:"top_p"`
	ResponseFormat   *ResponseFormat `json:"response_format"`
}

type ChatCompletion struct {
	Id                string   `json:"id"`
	Object            string   `json:"object"`
	Created           int64    `json:"created"`
	Model             string   `json:"model"`
	SystemFingerprint string   `json:"system_fingerprint"`
	Choices           []Choice `json:"choices"`
	Usage             Usage    `json:"usage,omitempty"`
}

type ChatCompletionChunk struct {
	Id                string        `json:"id"`
	Object            string        `json:"object"`
	Created           int64         `json:"created"`
	Model             string        `json:"model"`
	SystemFingerprint string        `json:"system_fingerprint"`
	Choices           []ChunkChoice `json:"choices"`
}

// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
type CompletionRequest struct {
	Model            string   `json:"model"`
	Prompt           string   `json:"prompt"`
	FrequencyPenalty float32  `json:"frequency_penalty"`
	MaxTokens        *int     `json:"max_tokens"`
	PresencePenalty  float32  `json:"presence_penalty"`
	Seed             *int     `json:"seed"`
	Stop             any      `json:"stop"`
	Stream           bool     `json:"stream"`
	Temperature      *float32 `json:"temperature"`
	TopP             float32  `json:"top_p"`
}

type Completion struct {
	Id                string                `json:"id"`
	Object            string                `json:"object"`
	Created           int64                 `json:"created"`
	Model             string                `json:"model"`
	SystemFingerprint string                `json:"system_fingerprint"`
	Choices           []CompleteChunkChoice `json:"choices"`
	Usage             Usage                 `json:"usage,omitempty"`
}

type CompletionChunk struct {
	Id                string                `json:"id"`
	Object            string                `json:"object"`
	Created           int64                 `json:"created"`
	Choices           []CompleteChunkChoice `json:"choices"`
	Model             string                `json:"model"`
	SystemFingerprint string                `json:"system_fingerprint"`
}

type Model struct {
	Id      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	OwnedBy string `json:"owned_by"`
}

type ListCompletion struct {
	Object string  `json:"object"`
	Data   []Model `json:"data"`
}

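// NewError wraps a message in an OpenAI-style error envelope, mapping the
// HTTP status code to the closest OpenAI error type.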
func NewError(code int, message string) ErrorResponse {
	var etype string
	switch code {
	case http.StatusBadRequest:
		etype = "invalid_request_error"
	case http.StatusNotFound:
		etype = "not_found_error"
	default:
		etype = "api_error"
	}

	return ErrorResponse{Error{Type: etype, Message: message}}
}

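// toChatCompletion converts a native api.ChatResponse into a non-streaming
// OpenAI chat.completion object.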
func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
	return ChatCompletion{
		Id:                id,
		Object:            "chat.completion",
		Created:           r.CreatedAt.Unix(),
		Model:             r.Model,
		SystemFingerprint: "fp_ollama",
		Choices: []Choice{{
			Index:   0,
			Message: Message{Role: r.Message.Role, Content: r.Message.Content},
			FinishReason: func(reason string) *string {
				if len(reason) > 0 {
					return &reason
				}
				return nil
			}(r.DoneReason),
		}},
		Usage: Usage{
			// TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
			PromptTokens:     r.PromptEvalCount,
			CompletionTokens: r.EvalCount,
			TotalTokens:      r.PromptEvalCount + r.EvalCount,
		},
	}
}

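// toChunk converts a native api.ChatResponse into a streaming
// chat.completion.chunk delta.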
func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
	return ChatCompletionChunk{
		Id:                id,
		Object:            "chat.completion.chunk",
		Created:           time.Now().Unix(),
		Model:             r.Model,
		SystemFingerprint: "fp_ollama",
		Choices: []ChunkChoice{{
			Index: 0,
			Delta: Message{Role: "assistant", Content: r.Message.Content},
			FinishReason: func(reason string) *string {
				if len(reason) > 0 {
					return &reason
				}
				return nil
			}(r.DoneReason),
		}},
	}
}

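// toCompletion converts a native api.GenerateResponse into a non-streaming
// text_completion object.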
func toCompletion(id string, r api.GenerateResponse) Completion {
	return Completion{
		Id:                id,
		Object:            "text_completion",
		Created:           r.CreatedAt.Unix(),
		Model:             r.Model,
		SystemFingerprint: "fp_ollama",
		Choices: []CompleteChunkChoice{{
			Text:  r.Response,
			Index: 0,
			FinishReason: func(reason string) *string {
				if len(reason) > 0 {
					return &reason
				}
				return nil
			}(r.DoneReason),
		}},
		Usage: Usage{
			// TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
			PromptTokens:     r.PromptEvalCount,
			CompletionTokens: r.EvalCount,
			TotalTokens:      r.PromptEvalCount + r.EvalCount,
		},
	}
}

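// toCompleteChunk converts a native api.GenerateResponse into a streaming
// text_completion chunk.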
func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
	return CompletionChunk{
		Id:                id,
		Object:            "text_completion",
		Created:           time.Now().Unix(),
		Model:             r.Model,
		SystemFingerprint: "fp_ollama",
		Choices: []CompleteChunkChoice{{
			Text:  r.Response,
			Index: 0,
			FinishReason: func(reason string) *string {
				if len(reason) > 0 {
					return &reason
				}
				return nil
			}(r.DoneReason),
		}},
	}
}

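// toListCompletion converts the native model list into the OpenAI "list"
// envelope used by /v1/models.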
func toListCompletion(r api.ListResponse) ListCompletion {
	var data []Model
	for _, m := range r.Models {
		data = append(data, Model{
			Id:      m.Name,
			Object:  "model",
			Created: m.ModifiedAt.Unix(),
			OwnedBy: model.ParseName(m.Name).Namespace,
		})
	}

	return ListCompletion{
		Object: "list",
		Data:   data,
	}
}

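// toModel converts a show response into an OpenAI model object, using the
// model name's namespace as the owner.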
func toModel(r api.ShowResponse, m string) Model {
	return Model{
		Id:      m,
		Object:  "model",
		Created: r.ModifiedAt.Unix(),
		OwnedBy: model.ParseName(m).Namespace,
	}
}

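// fromChatRequest maps an OpenAI-style chat request onto a native
// api.ChatRequest, translating option names and rescaling values
// (temperature and penalties are doubled) where the two APIs differ.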
func fromChatRequest(r ChatCompletionRequest) api.ChatRequest {
	var messages []api.Message
	for _, msg := range r.Messages {
		messages = append(messages, api.Message{Role: msg.Role, Content: msg.Content})
	}

	options := make(map[string]interface{})

	switch stop := r.Stop.(type) {
	case string:
		options["stop"] = []string{stop}
	case []any:
		var stops []string
		for _, s := range stop {
			if str, ok := s.(string); ok {
				stops = append(stops, str)
			}
		}
		options["stop"] = stops
	}

	if r.MaxTokens != nil {
		options["num_predict"] = *r.MaxTokens
	}

	if r.Temperature != nil {
		options["temperature"] = *r.Temperature * 2.0
	} else {
		options["temperature"] = 1.0
	}

	if r.Seed != nil {
		options["seed"] = *r.Seed
	}

	if r.FrequencyPenalty != nil {
		options["frequency_penalty"] = *r.FrequencyPenalty * 2.0
	}

	if r.PresencePenalty != nil {
		options["presence_penalty"] = *r.PresencePenalty * 2.0
	}

	if r.TopP != nil {
		options["top_p"] = *r.TopP
	} else {
		options["top_p"] = 1.0
	}

	var format string
	if r.ResponseFormat != nil && r.ResponseFormat.Type == "json_object" {
		format = "json"
	}

	return api.ChatRequest{
		Model:    r.Model,
		Messages: messages,
		Format:   format,
		Options:  options,
		Stream:   &r.Stream,
	}
}

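// fromCompleteRequest maps an OpenAI-style completion request onto a native
// api.GenerateRequest; unlike fromChatRequest, it rejects non-string "stop"
// entries rather than silently dropping them.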
func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
	options := make(map[string]any)

	switch stop := r.Stop.(type) {
	case string:
		options["stop"] = []string{stop}
	case []any:
		var stops []string
		for _, s := range stop {
			if str, ok := s.(string); ok {
				stops = append(stops, str)
			} else {
				return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", s)
			}
		}
		options["stop"] = stops
	}

	if r.MaxTokens != nil {
		options["num_predict"] = *r.MaxTokens
	}

	if r.Temperature != nil {
		options["temperature"] = *r.Temperature * 2.0
	} else {
		options["temperature"] = 1.0
	}

	if r.Seed != nil {
		options["seed"] = *r.Seed
	}

	options["frequency_penalty"] = r.FrequencyPenalty * 2.0
	options["presence_penalty"] = r.PresencePenalty * 2.0

	if r.TopP != 0.0 {
		options["top_p"] = r.TopP
	} else {
		options["top_p"] = 1.0
	}

	return api.GenerateRequest{
		Model:   r.Model,
		Prompt:  r.Prompt,
		Options: options,
		Stream:  &r.Stream,
	}, nil
}

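// BaseWriter and the writers embedding it wrap gin's ResponseWriter so the
// middleware below can intercept native Ollama responses and rewrite them
// into their OpenAI-compatible equivalents before they reach the client.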
type BaseWriter struct {
	gin.ResponseWriter
}

type ChatWriter struct {
	stream bool
	id     string
	BaseWriter
}

type CompleteWriter struct {
	stream bool
	id     string
	BaseWriter
}

type ListWriter struct {
	BaseWriter
}

type RetrieveWriter struct {
	BaseWriter
	model string
}

func (w *BaseWriter) writeError(code int, data []byte) (int, error) {
	var serr api.StatusError
	err := json.Unmarshal(data, &serr)
	if err != nil {
		return 0, err
	}

	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w.ResponseWriter).Encode(NewError(http.StatusInternalServerError, serr.Error()))
	if err != nil {
		return 0, err
	}

	return len(data), nil
}

func (w *ChatWriter) writeResponse(data []byte) (int, error) {
	var chatResponse api.ChatResponse
	err := json.Unmarshal(data, &chatResponse)
	if err != nil {
		return 0, err
	}

	// chat chunk
	if w.stream {
		d, err := json.Marshal(toChunk(w.id, chatResponse))
		if err != nil {
			return 0, err
		}

		w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
		_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
		if err != nil {
			return 0, err
		}

		if chatResponse.Done {
			_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
			if err != nil {
				return 0, err
			}
		}

		return len(data), nil
	}

	// chat completion
	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w.ResponseWriter).Encode(toChatCompletion(w.id, chatResponse))
	if err != nil {
		return 0, err
	}

	return len(data), nil
}

func (w *ChatWriter) Write(data []byte) (int, error) {
	code := w.ResponseWriter.Status()
	if code != http.StatusOK {
		return w.writeError(code, data)
	}

	return w.writeResponse(data)
}

func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
	var generateResponse api.GenerateResponse
	err := json.Unmarshal(data, &generateResponse)
	if err != nil {
		return 0, err
	}

	// completion chunk
	if w.stream {
		d, err := json.Marshal(toCompleteChunk(w.id, generateResponse))
		if err != nil {
			return 0, err
		}

		w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
		_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
		if err != nil {
			return 0, err
		}

		if generateResponse.Done {
			_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
			if err != nil {
				return 0, err
			}
		}

		return len(data), nil
	}

	// completion
	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w.ResponseWriter).Encode(toCompletion(w.id, generateResponse))
	if err != nil {
		return 0, err
	}

	return len(data), nil
}

func (w *CompleteWriter) Write(data []byte) (int, error) {
	code := w.ResponseWriter.Status()
	if code != http.StatusOK {
		return w.writeError(code, data)
	}

	return w.writeResponse(data)
}

func (w *ListWriter) writeResponse(data []byte) (int, error) {
	var listResponse api.ListResponse
	err := json.Unmarshal(data, &listResponse)
	if err != nil {
		return 0, err
	}

	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w.ResponseWriter).Encode(toListCompletion(listResponse))
	if err != nil {
		return 0, err
	}

	return len(data), nil
}

func (w *ListWriter) Write(data []byte) (int, error) {
	code := w.ResponseWriter.Status()
	if code != http.StatusOK {
		return w.writeError(code, data)
	}

	return w.writeResponse(data)
}

func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
	var showResponse api.ShowResponse
	err := json.Unmarshal(data, &showResponse)
	if err != nil {
		return 0, err
	}

	// retrieve completion
	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w.ResponseWriter).Encode(toModel(showResponse, w.model))
	if err != nil {
		return 0, err
	}

	return len(data), nil
}

func (w *RetrieveWriter) Write(data []byte) (int, error) {
	code := w.ResponseWriter.Status()
	if code != http.StatusOK {
		return w.writeError(code, data)
	}

	return w.writeResponse(data)
}

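// ListMiddleware rewrites a native list-models response into the OpenAI
// /v1/models format.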
func ListMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		w := &ListWriter{
			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
		}

		c.Writer = w

		c.Next()
	}
}

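// RetrieveMiddleware rewrites the request body into a native api.ShowRequest
// and the response into an OpenAI model object for /v1/models/:model.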
func RetrieveMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		var b bytes.Buffer
		if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
			return
		}

		c.Request.Body = io.NopCloser(&b)

		// response writer
		w := &RetrieveWriter{
			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
			model:      c.Param("model"),
		}

		c.Writer = w

		c.Next()
	}
}

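// CompletionsMiddleware translates an OpenAI /v1/completions request into a
// native generate request and rewrites the response, streaming or not.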
func CompletionsMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		var req CompletionRequest
		err := c.ShouldBindJSON(&req)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
			return
		}

		var b bytes.Buffer
		genReq, err := fromCompleteRequest(req)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
			return
		}

		if err := json.NewEncoder(&b).Encode(genReq); err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
			return
		}

		c.Request.Body = io.NopCloser(&b)

		w := &CompleteWriter{
			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
			stream:     req.Stream,
			id:         fmt.Sprintf("cmpl-%d", rand.Intn(999)),
		}

		c.Writer = w

		c.Next()
	}
}

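// ChatMiddleware translates an OpenAI /v1/chat/completions request into a
// native chat request and rewrites the response, streaming or not.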
func ChatMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		var req ChatCompletionRequest
		err := c.ShouldBindJSON(&req)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
			return
		}

		if len(req.Messages) == 0 {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "[] is too short - 'messages'"))
			return
		}

		var b bytes.Buffer
		if err := json.NewEncoder(&b).Encode(fromChatRequest(req)); err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
			return
		}

		c.Request.Body = io.NopCloser(&b)

		w := &ChatWriter{
			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
			stream:     req.Stream,
			id:         fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
		}

		c.Writer = w

		c.Next()
	}
}
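
// Example wiring (an illustrative sketch, not part of the original file):
// the route paths follow OpenAI's API, but the handler names (chatHandler,
// generateHandler, listHandler, showHandler) are hypothetical stand-ins for
// whatever native Ollama handlers the host router registers after each
// middleware.
//
//	r := gin.Default()
//	r.POST("/v1/chat/completions", ChatMiddleware(), chatHandler)
//	r.POST("/v1/completions", CompletionsMiddleware(), generateHandler)
//	r.GET("/v1/models", ListMiddleware(), listHandler)
//	r.GET("/v1/models/:model", RetrieveMiddleware(), showHandler)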