Browse Source

Initial OpenAI `/v1/chat/completions` API compatibility (#2376)

Jeffrey Morgan 1 year ago
parent
commit
453f572f83
3 changed files with 466 additions and 0 deletions
  1. 140 0
      docs/openai.md
  2. 322 0
      openai/openai.go
  3. 4 0
      server/routes.go

+ 140 - 0
docs/openai.md

@@ -0,0 +1,140 @@
+# OpenAI compatibility
+
+Ollama provides experimental compatibility with parts of the [OpenAI API](https://platform.openai.com/docs/api-reference) to help connect existing applications to Ollama.
+
+> **Note:** OpenAI compatibility is experimental and is subject to major adjustments including breaking changes. For fully-featured access to the Ollama API, see the Ollama [Python library](https://github.com/ollama/ollama-python), [JavaScript library](https://github.com/ollama/ollama-js) and [REST API](https://github.com/jmorganca/ollama/blob/main/docs/api.md).
+
+## Usage
+
+### OpenAI Python library
+
+```python
+from openai import OpenAI
+
+client = OpenAI(
+    base_url='http://localhost:11434/v1/',
+
+    # required but ignored
+    api_key='ollama',
+)
+
+chat_completion = client.chat.completions.create(
+    messages=[
+        {
+            'role': 'user',
+            'content': 'Say this is a test',
+        }
+    ],
+    model='llama2',
+)
+```
+
+### OpenAI JavaScript library
+
+```javascript
+import OpenAI from 'openai'
+
+const openai = new OpenAI({
+  baseURL: 'http://localhost:11434/v1/',
+
+  // required but ignored
+  apiKey: 'ollama',
+})
+
+const chatCompletion = await openai.chat.completions.create({
+  messages: [{ role: 'user', content: 'Say this is a test' }],
+  model: 'llama2',
+})
+```
+
+### `curl`
+
+```
+curl http://localhost:11434/v1/chat/completions \
+    -H "Content-Type: application/json" \
+    -d '{
+        "model": "llama2",
+        "messages": [
+            {
+                "role": "system",
+                "content": "You are a helpful assistant."
+            },
+            {
+                "role": "user",
+                "content": "Hello!"
+            }
+        ]
+    }'
+```
+
+## Endpoints
+
+### `/v1/chat/completions`
+
+#### Supported features
+
+- [x] Chat completions
+- [x] Streaming
+- [x] JSON mode
+- [x] Reproducible outputs
+- [ ] Vision
+- [ ] Function calling
+- [ ] Logprobs
+
+#### Supported request fields
+
+- [x] `model`
+- [x] `messages`
+  - [x] Text `content`
+  - [ ] Array of `content` parts
+- [x] `frequency_penalty`
+- [x] `presence_penalty`
+- [x] `response_format`
+- [x] `seed`
+- [x] `stop`
+- [x] `stream`
+- [x] `temperature`
+- [x] `top_p`
+- [x] `max_tokens`
+- [ ] `logit_bias`
+- [ ] `tools`
+- [ ] `tool_choice`
+- [ ] `user`
+
+#### Notes
+
+- Setting `seed` will always set `temperature` to `0`
+- `finish_reason` will always be `stop`
+- `usage.prompt_tokens` will be 0 for completions where prompt evaluation is cached
+
+## Models
+
+Before using a model, pull it locally `ollama pull`:
+
+```shell
+ollama pull llama2
+```
+
+### Default model names
+
+For tooling that relies on default OpenAI model names such as `gpt-3.5-turbo`, use `ollama cp` to copy an existing model name to a temporary name:
+
+```
+ollama cp llama2 gpt-3.5-turbo
+```
+
+Afterwards, this new model name can be specified the `model` field:
+
+```shell
+curl http://localhost:11434/v1/chat/completions \
+    -H "Content-Type: application/json" \
+    -d '{
+        "model": "gpt-3.5-turbo",
+        "messages": [
+            {
+                "role": "user",
+                "content": "Hello!"
+            }
+        ]
+    }'
+```

+ 322 - 0
openai/openai.go

@@ -0,0 +1,322 @@
+// openai package provides middleware for partial compatibility with the OpenAI REST API
+package openai
+
+import (
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"io"
+	"math/rand"
+	"net/http"
+	"time"
+
+	"github.com/gin-gonic/gin"
+	"github.com/jmorganca/ollama/api"
+)
+
+type Error struct {
+	Message string      `json:"message"`
+	Type    string      `json:"type"`
+	Param   interface{} `json:"param"`
+	Code    *string     `json:"code"`
+}
+
+type ErrorResponse struct {
+	Error Error `json:"error"`
+}
+
+type Message struct {
+	Role    string `json:"role"`
+	Content string `json:"content"`
+}
+
+type Choice struct {
+	Index        int     `json:"index"`
+	Message      Message `json:"message"`
+	FinishReason *string `json:"finish_reason"`
+}
+
+type ChunkChoice struct {
+	Index        int     `json:"index"`
+	Delta        Message `json:"delta"`
+	FinishReason *string `json:"finish_reason"`
+}
+
+type Usage struct {
+	PromptTokens     int `json:"prompt_tokens"`
+	CompletionTokens int `json:"completion_tokens"`
+	TotalTokens      int `json:"total_tokens"`
+}
+
+type ResponseFormat struct {
+	Type string `json:"type"`
+}
+
+type ChatCompletionRequest struct {
+	Model            string          `json:"model"`
+	Messages         []Message       `json:"messages"`
+	Stream           bool            `json:"stream"`
+	MaxTokens        *int            `json:"max_tokens"`
+	Seed             *int            `json:"seed"`
+	Stop             any             `json:"stop"`
+	Temperature      *float64        `json:"temperature"`
+	FrequencyPenalty *float64        `json:"frequency_penalty"`
+	PresencePenalty  *float64        `json:"presence_penalty_penalty"`
+	TopP             *float64        `json:"top_p"`
+	ResponseFormat   *ResponseFormat `json:"response_format"`
+}
+
+type ChatCompletion struct {
+	Id                string   `json:"id"`
+	Object            string   `json:"object"`
+	Created           int64    `json:"created"`
+	Model             string   `json:"model"`
+	SystemFingerprint string   `json:"system_fingerprint"`
+	Choices           []Choice `json:"choices"`
+	Usage             Usage    `json:"usage,omitempty"`
+}
+
+type ChatCompletionChunk struct {
+	Id                string        `json:"id"`
+	Object            string        `json:"object"`
+	Created           int64         `json:"created"`
+	Model             string        `json:"model"`
+	SystemFingerprint string        `json:"system_fingerprint"`
+	Choices           []ChunkChoice `json:"choices"`
+}
+
+func NewError(code int, message string) ErrorResponse {
+	var etype string
+	switch code {
+	case http.StatusBadRequest:
+		etype = "invalid_request_error"
+	case http.StatusNotFound:
+		etype = "not_found_error"
+	default:
+		etype = "api_error"
+	}
+
+	return ErrorResponse{Error{Type: etype, Message: message}}
+}
+
+func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
+	return ChatCompletion{
+		Id:                id,
+		Object:            "chat.completion",
+		Created:           r.CreatedAt.Unix(),
+		Model:             r.Model,
+		SystemFingerprint: "fp_ollama",
+		Choices: []Choice{{
+			Index:   0,
+			Message: Message{Role: r.Message.Role, Content: r.Message.Content},
+			FinishReason: func(done bool) *string {
+				if done {
+					reason := "stop"
+					return &reason
+				}
+				return nil
+			}(r.Done),
+		}},
+		Usage: Usage{
+			// TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
+			PromptTokens:     r.PromptEvalCount,
+			CompletionTokens: r.EvalCount,
+			TotalTokens:      r.PromptEvalCount + r.EvalCount,
+		},
+	}
+}
+
+func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
+	return ChatCompletionChunk{
+		Id:                id,
+		Object:            "chat.completion.chunk",
+		Created:           time.Now().Unix(),
+		Model:             r.Model,
+		SystemFingerprint: "fp_ollama",
+		Choices: []ChunkChoice{
+			{
+				Index: 0,
+				Delta: Message{Role: "assistant", Content: r.Message.Content},
+				FinishReason: func(done bool) *string {
+					if done {
+						reason := "stop"
+						return &reason
+					}
+					return nil
+				}(r.Done),
+			},
+		},
+	}
+}
+
+func fromRequest(r ChatCompletionRequest) api.ChatRequest {
+	var messages []api.Message
+	for _, msg := range r.Messages {
+		messages = append(messages, api.Message{Role: msg.Role, Content: msg.Content})
+	}
+
+	options := make(map[string]interface{})
+
+	switch stop := r.Stop.(type) {
+	case string:
+		options["stop"] = []string{stop}
+	case []interface{}:
+		var stops []string
+		for _, s := range stop {
+			if str, ok := s.(string); ok {
+				stops = append(stops, str)
+			}
+		}
+		options["stop"] = stops
+	}
+
+	if r.MaxTokens != nil {
+		options["num_predict"] = *r.MaxTokens
+	}
+
+	if r.Temperature != nil {
+		options["temperature"] = *r.Temperature * 2.0
+	} else {
+		options["temperature"] = 1.0
+	}
+
+	if r.Seed != nil {
+		options["seed"] = *r.Seed
+
+		// temperature=0 is required for reproducible outputs
+		options["temperature"] = 0.0
+	}
+
+	if r.FrequencyPenalty != nil {
+		options["frequency_penalty"] = *r.FrequencyPenalty * 2.0
+	}
+
+	if r.PresencePenalty != nil {
+		options["presence_penalty"] = *r.PresencePenalty * 2.0
+	}
+
+	if r.TopP != nil {
+		options["top_p"] = *r.TopP
+	} else {
+		options["top_p"] = 1.0
+	}
+
+	var format string
+	if r.ResponseFormat != nil && r.ResponseFormat.Type == "json_object" {
+		format = "json"
+	}
+
+	return api.ChatRequest{
+		Model:    r.Model,
+		Messages: messages,
+		Format:   format,
+		Options:  options,
+		Stream:   &r.Stream,
+	}
+}
+
+type writer struct {
+	stream bool
+	id     string
+	gin.ResponseWriter
+}
+
+func (w *writer) writeError(code int, data []byte) (int, error) {
+	var serr api.StatusError
+	err := json.Unmarshal(data, &serr)
+	if err != nil {
+		return 0, err
+	}
+
+	w.ResponseWriter.Header().Set("Content-Type", "application/json")
+	err = json.NewEncoder(w.ResponseWriter).Encode(NewError(http.StatusInternalServerError, serr.Error()))
+	if err != nil {
+		return 0, err
+	}
+
+	return len(data), nil
+}
+
+func (w *writer) writeResponse(data []byte) (int, error) {
+	var chatResponse api.ChatResponse
+	err := json.Unmarshal(data, &chatResponse)
+	if err != nil {
+		return 0, err
+	}
+
+	// chat chunk
+	if w.stream {
+		d, err := json.Marshal(toChunk(w.id, chatResponse))
+		if err != nil {
+			return 0, err
+
+		}
+
+		w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
+		_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
+		if err != nil {
+			return 0, err
+		}
+
+		if chatResponse.Done {
+			_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
+			if err != nil {
+				return 0, err
+			}
+		}
+
+		return len(data), nil
+	}
+
+	// chat completion
+	w.ResponseWriter.Header().Set("Content-Type", "application/json")
+	err = json.NewEncoder(w.ResponseWriter).Encode(toChatCompletion(w.id, chatResponse))
+	if err != nil {
+		return 0, err
+	}
+
+	return len(data), nil
+}
+
+func (w *writer) Write(data []byte) (int, error) {
+	code := w.ResponseWriter.Status()
+	if code != http.StatusOK {
+		return w.writeError(code, data)
+	}
+
+	return w.writeResponse(data)
+}
+
+func Middleware() gin.HandlerFunc {
+	return func(c *gin.Context) {
+		var req ChatCompletionRequest
+		err := c.ShouldBindJSON(&req)
+		if err != nil {
+			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
+			return
+		}
+
+		if len(req.Messages) == 0 {
+			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "[] is too short - 'messages'"))
+			return
+		}
+
+		var b bytes.Buffer
+		if err := json.NewEncoder(&b).Encode(fromRequest(req)); err != nil {
+			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
+			return
+		}
+
+		c.Request.Body = io.NopCloser(&b)
+
+		w := &writer{
+			ResponseWriter: c.Writer,
+			stream:         req.Stream,
+			id:             fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
+		}
+
+		c.Writer = w
+
+		c.Next()
+	}
+}

+ 4 - 0
server/routes.go

@@ -26,6 +26,7 @@ import (
 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/gpu"
 	"github.com/jmorganca/ollama/llm"
+	"github.com/jmorganca/ollama/openai"
 	"github.com/jmorganca/ollama/parser"
 	"github.com/jmorganca/ollama/version"
 )
@@ -935,6 +936,9 @@ func (s *Server) GenerateRoutes() http.Handler {
 	r.POST("/api/blobs/:digest", CreateBlobHandler)
 	r.HEAD("/api/blobs/:digest", HeadBlobHandler)
 
+	// Compatibility endpoints
+	r.POST("/v1/chat/completions", openai.Middleware(), ChatHandler)
+
 	for _, method := range []string{http.MethodGet, http.MethodHead} {
 		r.Handle(method, "/", func(c *gin.Context) {
 			c.String(http.StatusOK, "Ollama is running")