openai.go 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322
  1. // openai package provides middleware for partial compatibility with the OpenAI REST API
  2. package openai
  3. import (
  4. "bytes"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "math/rand"
  9. "net/http"
  10. "time"
  11. "github.com/gin-gonic/gin"
  12. "github.com/ollama/ollama/api"
  13. )
// Error is the OpenAI-style error body carried inside ErrorResponse.
type Error struct {
	Message string      `json:"message"`
	Type    string      `json:"type"`
	Param   interface{} `json:"param"` // always nil here; present for schema compatibility
	Code    *string     `json:"code"`  // always nil here; present for schema compatibility
}
// ErrorResponse is the top-level envelope OpenAI clients expect on failure.
type ErrorResponse struct {
	Error Error `json:"error"`
}
// Message is a single chat turn (role + content), used both for request
// messages and for response/delta messages.
type Message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}
// Choice is one completion alternative in a non-streaming response.
type Choice struct {
	Index   int     `json:"index"`
	Message Message `json:"message"`
	// FinishReason is nil while generating and "stop" once done.
	FinishReason *string `json:"finish_reason"`
}
// ChunkChoice is one completion alternative in a streaming chunk; Delta
// carries only the incremental content.
type ChunkChoice struct {
	Index int     `json:"index"`
	Delta Message `json:"delta"`
	// FinishReason is nil for intermediate chunks and "stop" on the final one.
	FinishReason *string `json:"finish_reason"`
}
// Usage reports token accounting in OpenAI's format.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}
// ResponseFormat selects the output format; only "json_object" is honored
// (mapped to ollama's "json" format in fromRequest).
type ResponseFormat struct {
	Type string `json:"type"`
}
  45. type ChatCompletionRequest struct {
  46. Model string `json:"model"`
  47. Messages []Message `json:"messages"`
  48. Stream bool `json:"stream"`
  49. MaxTokens *int `json:"max_tokens"`
  50. Seed *int `json:"seed"`
  51. Stop any `json:"stop"`
  52. Temperature *float64 `json:"temperature"`
  53. FrequencyPenalty *float64 `json:"frequency_penalty"`
  54. PresencePenalty *float64 `json:"presence_penalty_penalty"`
  55. TopP *float64 `json:"top_p"`
  56. ResponseFormat *ResponseFormat `json:"response_format"`
  57. }
// ChatCompletion is the non-streaming response body ("chat.completion").
type ChatCompletion struct {
	Id                string   `json:"id"`
	Object            string   `json:"object"`
	Created           int64    `json:"created"` // unix seconds
	Model             string   `json:"model"`
	SystemFingerprint string   `json:"system_fingerprint"`
	Choices           []Choice `json:"choices"`
	Usage             Usage    `json:"usage,omitempty"`
}
// ChatCompletionChunk is one streamed SSE event body ("chat.completion.chunk").
type ChatCompletionChunk struct {
	Id                string        `json:"id"`
	Object            string        `json:"object"`
	Created           int64         `json:"created"` // unix seconds
	Model             string        `json:"model"`
	SystemFingerprint string        `json:"system_fingerprint"`
	Choices           []ChunkChoice `json:"choices"`
}
  75. func NewError(code int, message string) ErrorResponse {
  76. var etype string
  77. switch code {
  78. case http.StatusBadRequest:
  79. etype = "invalid_request_error"
  80. case http.StatusNotFound:
  81. etype = "not_found_error"
  82. default:
  83. etype = "api_error"
  84. }
  85. return ErrorResponse{Error{Type: etype, Message: message}}
  86. }
  87. func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
  88. return ChatCompletion{
  89. Id: id,
  90. Object: "chat.completion",
  91. Created: r.CreatedAt.Unix(),
  92. Model: r.Model,
  93. SystemFingerprint: "fp_ollama",
  94. Choices: []Choice{{
  95. Index: 0,
  96. Message: Message{Role: r.Message.Role, Content: r.Message.Content},
  97. FinishReason: func(done bool) *string {
  98. if done {
  99. reason := "stop"
  100. return &reason
  101. }
  102. return nil
  103. }(r.Done),
  104. }},
  105. Usage: Usage{
  106. // TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
  107. PromptTokens: r.PromptEvalCount,
  108. CompletionTokens: r.EvalCount,
  109. TotalTokens: r.PromptEvalCount + r.EvalCount,
  110. },
  111. }
  112. }
  113. func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
  114. return ChatCompletionChunk{
  115. Id: id,
  116. Object: "chat.completion.chunk",
  117. Created: time.Now().Unix(),
  118. Model: r.Model,
  119. SystemFingerprint: "fp_ollama",
  120. Choices: []ChunkChoice{
  121. {
  122. Index: 0,
  123. Delta: Message{Role: "assistant", Content: r.Message.Content},
  124. FinishReason: func(done bool) *string {
  125. if done {
  126. reason := "stop"
  127. return &reason
  128. }
  129. return nil
  130. }(r.Done),
  131. },
  132. },
  133. }
  134. }
  135. func fromRequest(r ChatCompletionRequest) api.ChatRequest {
  136. var messages []api.Message
  137. for _, msg := range r.Messages {
  138. messages = append(messages, api.Message{Role: msg.Role, Content: msg.Content})
  139. }
  140. options := make(map[string]interface{})
  141. switch stop := r.Stop.(type) {
  142. case string:
  143. options["stop"] = []string{stop}
  144. case []interface{}:
  145. var stops []string
  146. for _, s := range stop {
  147. if str, ok := s.(string); ok {
  148. stops = append(stops, str)
  149. }
  150. }
  151. options["stop"] = stops
  152. }
  153. if r.MaxTokens != nil {
  154. options["num_predict"] = *r.MaxTokens
  155. }
  156. if r.Temperature != nil {
  157. options["temperature"] = *r.Temperature * 2.0
  158. } else {
  159. options["temperature"] = 1.0
  160. }
  161. if r.Seed != nil {
  162. options["seed"] = *r.Seed
  163. // temperature=0 is required for reproducible outputs
  164. options["temperature"] = 0.0
  165. }
  166. if r.FrequencyPenalty != nil {
  167. options["frequency_penalty"] = *r.FrequencyPenalty * 2.0
  168. }
  169. if r.PresencePenalty != nil {
  170. options["presence_penalty"] = *r.PresencePenalty * 2.0
  171. }
  172. if r.TopP != nil {
  173. options["top_p"] = *r.TopP
  174. } else {
  175. options["top_p"] = 1.0
  176. }
  177. var format string
  178. if r.ResponseFormat != nil && r.ResponseFormat.Type == "json_object" {
  179. format = "json"
  180. }
  181. return api.ChatRequest{
  182. Model: r.Model,
  183. Messages: messages,
  184. Format: format,
  185. Options: options,
  186. Stream: &r.Stream,
  187. }
  188. }
// writer wraps gin's ResponseWriter and rewrites ollama's native responses
// into OpenAI-compatible JSON (or SSE chunks when streaming) as they pass
// through Write.
type writer struct {
	stream bool   // true when the client requested a streamed response
	id     string // completion id echoed in every response/chunk
	gin.ResponseWriter
}
  194. func (w *writer) writeError(code int, data []byte) (int, error) {
  195. var serr api.StatusError
  196. err := json.Unmarshal(data, &serr)
  197. if err != nil {
  198. return 0, err
  199. }
  200. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  201. err = json.NewEncoder(w.ResponseWriter).Encode(NewError(http.StatusInternalServerError, serr.Error()))
  202. if err != nil {
  203. return 0, err
  204. }
  205. return len(data), nil
  206. }
  207. func (w *writer) writeResponse(data []byte) (int, error) {
  208. var chatResponse api.ChatResponse
  209. err := json.Unmarshal(data, &chatResponse)
  210. if err != nil {
  211. return 0, err
  212. }
  213. // chat chunk
  214. if w.stream {
  215. d, err := json.Marshal(toChunk(w.id, chatResponse))
  216. if err != nil {
  217. return 0, err
  218. }
  219. w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
  220. _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
  221. if err != nil {
  222. return 0, err
  223. }
  224. if chatResponse.Done {
  225. _, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
  226. if err != nil {
  227. return 0, err
  228. }
  229. }
  230. return len(data), nil
  231. }
  232. // chat completion
  233. w.ResponseWriter.Header().Set("Content-Type", "application/json")
  234. err = json.NewEncoder(w.ResponseWriter).Encode(toChatCompletion(w.id, chatResponse))
  235. if err != nil {
  236. return 0, err
  237. }
  238. return len(data), nil
  239. }
  240. func (w *writer) Write(data []byte) (int, error) {
  241. code := w.ResponseWriter.Status()
  242. if code != http.StatusOK {
  243. return w.writeError(code, data)
  244. }
  245. return w.writeResponse(data)
  246. }
  247. func Middleware() gin.HandlerFunc {
  248. return func(c *gin.Context) {
  249. var req ChatCompletionRequest
  250. err := c.ShouldBindJSON(&req)
  251. if err != nil {
  252. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
  253. return
  254. }
  255. if len(req.Messages) == 0 {
  256. c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "[] is too short - 'messages'"))
  257. return
  258. }
  259. var b bytes.Buffer
  260. if err := json.NewEncoder(&b).Encode(fromRequest(req)); err != nil {
  261. c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
  262. return
  263. }
  264. c.Request.Body = io.NopCloser(&b)
  265. w := &writer{
  266. ResponseWriter: c.Writer,
  267. stream: req.Stream,
  268. id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
  269. }
  270. c.Writer = w
  271. c.Next()
  272. }
  273. }