openai.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451
// Package openai provides middleware for partial compatibility with the OpenAI REST API.
  2. package openai
  3. import (
  4. "bytes"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "math/rand"
  9. "net/http"
  10. "time"
  11. "github.com/gin-gonic/gin"
  12. "github.com/ollama/ollama/api"
  13. "github.com/ollama/ollama/types/model"
  14. )
  15. type Error struct {
  16. Message string `json:"message"`
  17. Type string `json:"type"`
  18. Param interface{} `json:"param"`
  19. Code *string `json:"code"`
  20. }
  21. type ErrorResponse struct {
  22. Error Error `json:"error"`
  23. }
  24. type Message struct {
  25. Role string `json:"role"`
  26. Content string `json:"content"`
  27. }
  28. type Choice struct {
  29. Index int `json:"index"`
  30. Message Message `json:"message"`
  31. FinishReason *string `json:"finish_reason"`
  32. }
  33. type ChunkChoice struct {
  34. Index int `json:"index"`
  35. Delta Message `json:"delta"`
  36. FinishReason *string `json:"finish_reason"`
  37. }
  38. type Usage struct {
  39. PromptTokens int `json:"prompt_tokens"`
  40. CompletionTokens int `json:"completion_tokens"`
  41. TotalTokens int `json:"total_tokens"`
  42. }
  43. type ResponseFormat struct {
  44. Type string `json:"type"`
  45. }
  46. type ChatCompletionRequest struct {
  47. Model string `json:"model"`
  48. Messages []Message `json:"messages"`
  49. Stream bool `json:"stream"`
  50. MaxTokens *int `json:"max_tokens"`
  51. Seed *int `json:"seed"`
  52. Stop any `json:"stop"`
  53. Temperature *float64 `json:"temperature"`
  54. FrequencyPenalty *float64 `json:"frequency_penalty"`
  55. PresencePenalty *float64 `json:"presence_penalty_penalty"`
  56. TopP *float64 `json:"top_p"`
  57. ResponseFormat *ResponseFormat `json:"response_format"`
  58. }
  59. type ChatCompletion struct {
  60. Id string `json:"id"`
  61. Object string `json:"object"`
  62. Created int64 `json:"created"`
  63. Model string `json:"model"`
  64. SystemFingerprint string `json:"system_fingerprint"`
  65. Choices []Choice `json:"choices"`
  66. Usage Usage `json:"usage,omitempty"`
  67. }
  68. type ChatCompletionChunk struct {
  69. Id string `json:"id"`
  70. Object string `json:"object"`
  71. Created int64 `json:"created"`
  72. Model string `json:"model"`
  73. SystemFingerprint string `json:"system_fingerprint"`
  74. Choices []ChunkChoice `json:"choices"`
  75. }
  76. type Model struct {
  77. Id string `json:"id"`
  78. Object string `json:"object"`
  79. Created int64 `json:"created"`
  80. OwnedBy string `json:"owned_by"`
  81. }
  82. type ListCompletion struct {
  83. Object string `json:"object"`
  84. Data []Model `json:"data"`
  85. }
  86. func NewError(code int, message string) ErrorResponse {
  87. var etype string
  88. switch code {
  89. case http.StatusBadRequest:
  90. etype = "invalid_request_error"
  91. case http.StatusNotFound:
  92. etype = "not_found_error"
  93. default:
  94. etype = "api_error"
  95. }
  96. return ErrorResponse{Error{Type: etype, Message: message}}
  97. }
  98. func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
  99. return ChatCompletion{
  100. Id: id,
  101. Object: "chat.completion",
  102. Created: r.CreatedAt.Unix(),
  103. Model: r.Model,
  104. SystemFingerprint: "fp_ollama",
  105. Choices: []Choice{{
  106. Index: 0,
  107. Message: Message{Role: r.Message.Role, Content: r.Message.Content},
  108. FinishReason: func(reason string) *string {
  109. if len(reason) > 0 {
  110. return &reason
  111. }
  112. return nil
  113. }(r.DoneReason),
  114. }},
  115. Usage: Usage{
  116. // TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
  117. PromptTokens: r.PromptEvalCount,
  118. CompletionTokens: r.EvalCount,
  119. TotalTokens: r.PromptEvalCount + r.EvalCount,
  120. },
  121. }
  122. }
  123. func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
  124. return ChatCompletionChunk{
  125. Id: id,
  126. Object: "chat.completion.chunk",
  127. Created: time.Now().Unix(),
  128. Model: r.Model,
  129. SystemFingerprint: "fp_ollama",
  130. Choices: []ChunkChoice{{
  131. Index: 0,
  132. Delta: Message{Role: "assistant", Content: r.Message.Content},
  133. FinishReason: func(reason string) *string {
  134. if len(reason) > 0 {
  135. return &reason
  136. }
  137. return nil
  138. }(r.DoneReason),
  139. }},
  140. }
  141. }
  142. func toListCompletion(r api.ListResponse) ListCompletion {
  143. var data []Model
  144. for _, m := range r.Models {
  145. data = append(data, Model{
  146. Id: m.Name,
  147. Object: "model",
  148. Created: m.ModifiedAt.Unix(),
  149. OwnedBy: model.ParseName(m.Name).Namespace,
  150. })
  151. }
  152. return ListCompletion{
  153. Object: "list",
  154. Data: data,
  155. }
  156. }
  157. func toModel(r api.ShowResponse, m string) Model {
  158. return Model{
  159. Id: m,
  160. Object: "model",
  161. Created: r.ModifiedAt.Unix(),
  162. OwnedBy: model.ParseName(m).Namespace,
  163. }
  164. }
  165. func fromChatRequest(r ChatCompletionRequest) api.ChatRequest {
  166. var messages []api.Message
  167. for _, msg := range r.Messages {
  168. messages = append(messages, api.Message{Role: msg.Role, Content: msg.Content})
  169. }
  170. options := make(map[string]interface{})
  171. switch stop := r.Stop.(type) {
  172. case string:
  173. options["stop"] = []string{stop}
  174. case []interface{}:
  175. var stops []string
  176. for _, s := range stop {
  177. if str, ok := s.(string); ok {
  178. stops = append(stops, str)
  179. }
  180. }
  181. options["stop"] = stops
  182. }
  183. if r.MaxTokens != nil {
  184. options["num_predict"] = *r.MaxTokens
  185. }
  186. if r.Temperature != nil {
  187. options["temperature"] = *r.Temperature * 2.0
  188. } else {
  189. options["temperature"] = 1.0
  190. }
  191. if r.Seed != nil {
  192. options["seed"] = *r.Seed
  193. }
  194. if r.FrequencyPenalty != nil {
  195. options["frequency_penalty"] = *r.FrequencyPenalty * 2.0
  196. }
  197. if r.PresencePenalty != nil {
  198. options["presence_penalty"] = *r.PresencePenalty * 2.0
  199. }
  200. if r.TopP != nil {
  201. options["top_p"] = *r.TopP
  202. } else {
  203. options["top_p"] = 1.0
  204. }
  205. var format string
  206. if r.ResponseFormat != nil && r.ResponseFormat.Type == "json_object" {
  207. format = "json"
  208. }
  209. return api.ChatRequest{
  210. Model: r.Model,
  211. Messages: messages,
  212. Format: format,
  213. Options: options,
  214. Stream: &r.Stream,
  215. }
  216. }
  217. type BaseWriter struct {
  218. gin.ResponseWriter
  219. }
  220. type ChatWriter struct {
  221. stream bool
  222. id string
  223. BaseWriter
  224. }
  225. type ListWriter struct {
  226. BaseWriter
  227. }
  228. type RetrieveWriter struct {
  229. BaseWriter
  230. model string
  231. }
  232. func (w *BaseWriter) writeError(code int, data []byte) (int, error) {
  233. var serr api.StatusError
  234. err := json.Unmarshal(data, &serr)
  235. if err != nil {
  236. return 0, err
  237. }
  238. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  239. err = json.NewEncoder(w.ResponseWriter).Encode(NewError(http.StatusInternalServerError, serr.Error()))
  240. if err != nil {
  241. return 0, err
  242. }
  243. return len(data), nil
  244. }
  245. func (w *ChatWriter) writeResponse(data []byte) (int, error) {
  246. var chatResponse api.ChatResponse
  247. err := json.Unmarshal(data, &chatResponse)
  248. if err != nil {
  249. return 0, err
  250. }
  251. // chat chunk
  252. if w.stream {
  253. d, err := json.Marshal(toChunk(w.id, chatResponse))
  254. if err != nil {
  255. return 0, err
  256. }
  257. w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
  258. _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
  259. if err != nil {
  260. return 0, err
  261. }
  262. if chatResponse.Done {
  263. _, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
  264. if err != nil {
  265. return 0, err
  266. }
  267. }
  268. return len(data), nil
  269. }
  270. // chat completion
  271. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  272. err = json.NewEncoder(w.ResponseWriter).Encode(toChatCompletion(w.id, chatResponse))
  273. if err != nil {
  274. return 0, err
  275. }
  276. return len(data), nil
  277. }
  278. func (w *ChatWriter) Write(data []byte) (int, error) {
  279. code := w.ResponseWriter.Status()
  280. if code != http.StatusOK {
  281. return w.writeError(code, data)
  282. }
  283. return w.writeResponse(data)
  284. }
  285. func (w *ListWriter) writeResponse(data []byte) (int, error) {
  286. var listResponse api.ListResponse
  287. err := json.Unmarshal(data, &listResponse)
  288. if err != nil {
  289. return 0, err
  290. }
  291. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  292. err = json.NewEncoder(w.ResponseWriter).Encode(toListCompletion(listResponse))
  293. if err != nil {
  294. return 0, err
  295. }
  296. return len(data), nil
  297. }
  298. func (w *ListWriter) Write(data []byte) (int, error) {
  299. code := w.ResponseWriter.Status()
  300. if code != http.StatusOK {
  301. return w.writeError(code, data)
  302. }
  303. return w.writeResponse(data)
  304. }
  305. func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
  306. var showResponse api.ShowResponse
  307. err := json.Unmarshal(data, &showResponse)
  308. if err != nil {
  309. return 0, err
  310. }
  311. // retrieve completion
  312. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  313. err = json.NewEncoder(w.ResponseWriter).Encode(toModel(showResponse, w.model))
  314. if err != nil {
  315. return 0, err
  316. }
  317. return len(data), nil
  318. }
  319. func (w *RetrieveWriter) Write(data []byte) (int, error) {
  320. code := w.ResponseWriter.Status()
  321. if code != http.StatusOK {
  322. return w.writeError(code, data)
  323. }
  324. return w.writeResponse(data)
  325. }
  326. func ListMiddleware() gin.HandlerFunc {
  327. return func(c *gin.Context) {
  328. w := &ListWriter{
  329. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  330. }
  331. c.Writer = w
  332. c.Next()
  333. }
  334. }
  335. func RetrieveMiddleware() gin.HandlerFunc {
  336. return func(c *gin.Context) {
  337. var b bytes.Buffer
  338. if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
  339. c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
  340. return
  341. }
  342. c.Request.Body = io.NopCloser(&b)
  343. // response writer
  344. w := &RetrieveWriter{
  345. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  346. model: c.Param("model"),
  347. }
  348. c.Writer = w
  349. c.Next()
  350. }
  351. }
  352. func ChatMiddleware() gin.HandlerFunc {
  353. return func(c *gin.Context) {
  354. var req ChatCompletionRequest
  355. err := c.ShouldBindJSON(&req)
  356. if err != nil {
  357. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
  358. return
  359. }
  360. if len(req.Messages) == 0 {
  361. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "[] is too short - 'messages'"))
  362. return
  363. }
  364. var b bytes.Buffer
  365. if err := json.NewEncoder(&b).Encode(fromChatRequest(req)); err != nil {
  366. c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
  367. return
  368. }
  369. c.Request.Body = io.NopCloser(&b)
  370. w := &ChatWriter{
  371. BaseWriter: BaseWriter{ResponseWriter: c.Writer},
  372. stream: req.Stream,
  373. id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
  374. }
  375. c.Writer = w
  376. c.Next()
  377. }
  378. }