Browse Source

OpenAI: v1/completions compatibility (#5209)

* OpenAI v1 models

* Refactor Writers

* Add Test

Co-Authored-By: Attila Kerekes

* Credit Co-Author

Co-Authored-By: Attila Kerekes <439392+keriati@users.noreply.github.com>

* Empty List Testing

* Use Namespace for Ownedby

* Update Test

* Add back envconfig

* v1/models docs

* Use ModelName Parser

* Test Names

* Remove Docs

* Clean Up

* Test name

Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>

* Add Middleware for Chat and List

* Completions Endpoint

* Testing Cleanup

* Test with Fatal

* Add functionality to chat test

* Rename function

* float types

* type cleanup

* cleaning

* more cleaning

* Extra test cases

* merge conflicts

* merge conflicts

* merge conflicts

* merge conflicts

* cleaning

* cleaning

---------

Co-authored-by: Attila Kerekes <439392+keriati@users.noreply.github.com>
Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>
royjhan 10 months ago
parent
commit
d626b99b54
3 changed files with 353 additions and 3 deletions
  1. 222 1
      openai/openai.go
  2. 130 2
      openai/openai_test.go
  3. 1 0
      server/routes.go

+ 222 - 1
openai/openai.go

@@ -43,6 +43,12 @@ type ChunkChoice struct {
 	FinishReason *string `json:"finish_reason"`
 }
 
+type CompleteChunkChoice struct {
+	Text         string  `json:"text"`
+	Index        int     `json:"index"`
+	FinishReason *string `json:"finish_reason"`
+}
+
 type Usage struct {
 	PromptTokens     int `json:"prompt_tokens"`
 	CompletionTokens int `json:"completion_tokens"`
@@ -86,6 +92,39 @@ type ChatCompletionChunk struct {
 	Choices           []ChunkChoice `json:"choices"`
 }
 
+// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
+type CompletionRequest struct {
+	Model            string   `json:"model"`
+	Prompt           string   `json:"prompt"`
+	FrequencyPenalty float32  `json:"frequency_penalty"`
+	MaxTokens        *int     `json:"max_tokens"`
+	PresencePenalty  float32  `json:"presence_penalty"`
+	Seed             *int     `json:"seed"`
+	Stop             any      `json:"stop"`
+	Stream           bool     `json:"stream"`
+	Temperature      *float32 `json:"temperature"`
+	TopP             float32  `json:"top_p"`
+}
+
+type Completion struct {
+	Id                string                `json:"id"`
+	Object            string                `json:"object"`
+	Created           int64                 `json:"created"`
+	Model             string                `json:"model"`
+	SystemFingerprint string                `json:"system_fingerprint"`
+	Choices           []CompleteChunkChoice `json:"choices"`
+	Usage             Usage                 `json:"usage,omitempty"`
+}
+
+type CompletionChunk struct {
+	Id                string                `json:"id"`
+	Object            string                `json:"object"`
+	Created           int64                 `json:"created"`
+	Choices           []CompleteChunkChoice `json:"choices"`
+	Model             string                `json:"model"`
+	SystemFingerprint string                `json:"system_fingerprint"`
+}
+
 type Model struct {
 	Id      string `json:"id"`
 	Object  string `json:"object"`
@@ -158,6 +197,52 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
 	}
 }
 
+func toCompletion(id string, r api.GenerateResponse) Completion {
+	return Completion{
+		Id:                id,
+		Object:            "text_completion",
+		Created:           r.CreatedAt.Unix(),
+		Model:             r.Model,
+		SystemFingerprint: "fp_ollama",
+		Choices: []CompleteChunkChoice{{
+			Text:  r.Response,
+			Index: 0,
+			FinishReason: func(reason string) *string {
+				if len(reason) > 0 {
+					return &reason
+				}
+				return nil
+			}(r.DoneReason),
+		}},
+		Usage: Usage{
+			// TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
+			PromptTokens:     r.PromptEvalCount,
+			CompletionTokens: r.EvalCount,
+			TotalTokens:      r.PromptEvalCount + r.EvalCount,
+		},
+	}
+}
+
+func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
+	return CompletionChunk{
+		Id:                id,
+		Object:            "text_completion",
+		Created:           time.Now().Unix(),
+		Model:             r.Model,
+		SystemFingerprint: "fp_ollama",
+		Choices: []CompleteChunkChoice{{
+			Text:  r.Response,
+			Index: 0,
+			FinishReason: func(reason string) *string {
+				if len(reason) > 0 {
+					return &reason
+				}
+				return nil
+			}(r.DoneReason),
+		}},
+	}
+}
+
 func toListCompletion(r api.ListResponse) ListCompletion {
 	var data []Model
 	for _, m := range r.Models {
@@ -195,7 +280,7 @@ func fromChatRequest(r ChatCompletionRequest) api.ChatRequest {
 	switch stop := r.Stop.(type) {
 	case string:
 		options["stop"] = []string{stop}
-	case []interface{}:
+	case []any:
 		var stops []string
 		for _, s := range stop {
 			if str, ok := s.(string); ok {
@@ -247,6 +332,52 @@ func fromChatRequest(r ChatCompletionRequest) api.ChatRequest {
 	}
 }
 
+func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
+	options := make(map[string]any)
+
+	switch stop := r.Stop.(type) {
+	case string:
+		options["stop"] = []string{stop}
+	case []string:
+		options["stop"] = stop
+	default:
+		if r.Stop != nil {
+			return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", r.Stop)
+		}
+	}
+
+	if r.MaxTokens != nil {
+		options["num_predict"] = *r.MaxTokens
+	}
+
+	if r.Temperature != nil {
+		options["temperature"] = *r.Temperature * 2.0
+	} else {
+		options["temperature"] = 1.0
+	}
+
+	if r.Seed != nil {
+		options["seed"] = *r.Seed
+	}
+
+	options["frequency_penalty"] = r.FrequencyPenalty * 2.0
+
+	options["presence_penalty"] = r.PresencePenalty * 2.0
+
+	if r.TopP != 0.0 {
+		options["top_p"] = r.TopP
+	} else {
+		options["top_p"] = 1.0
+	}
+
+	return api.GenerateRequest{
+		Model:   r.Model,
+		Prompt:  r.Prompt,
+		Options: options,
+		Stream:  &r.Stream,
+	}, nil
+}
+
 type BaseWriter struct {
 	gin.ResponseWriter
 }
@@ -257,6 +388,12 @@ type ChatWriter struct {
 	BaseWriter
 }
 
+type CompleteWriter struct {
+	stream bool
+	id     string
+	BaseWriter
+}
+
 type ListWriter struct {
 	BaseWriter
 }
@@ -331,6 +468,55 @@ func (w *ChatWriter) Write(data []byte) (int, error) {
 	return w.writeResponse(data)
 }
 
+func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
+	var generateResponse api.GenerateResponse
+	err := json.Unmarshal(data, &generateResponse)
+	if err != nil {
+		return 0, err
+	}
+
+	// completion chunk
+	if w.stream {
+		d, err := json.Marshal(toCompleteChunk(w.id, generateResponse))
+		if err != nil {
+			return 0, err
+		}
+
+		w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
+		_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
+		if err != nil {
+			return 0, err
+		}
+
+		if generateResponse.Done {
+			_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
+			if err != nil {
+				return 0, err
+			}
+		}
+
+		return len(data), nil
+	}
+
+	// completion
+	w.ResponseWriter.Header().Set("Content-Type", "application/json")
+	err = json.NewEncoder(w.ResponseWriter).Encode(toCompletion(w.id, generateResponse))
+	if err != nil {
+		return 0, err
+	}
+
+	return len(data), nil
+}
+
+func (w *CompleteWriter) Write(data []byte) (int, error) {
+	code := w.ResponseWriter.Status()
+	if code != http.StatusOK {
+		return w.writeError(code, data)
+	}
+
+	return w.writeResponse(data)
+}
+
 func (w *ListWriter) writeResponse(data []byte) (int, error) {
 	var listResponse api.ListResponse
 	err := json.Unmarshal(data, &listResponse)
@@ -416,6 +602,41 @@ func RetrieveMiddleware() gin.HandlerFunc {
 	}
 }
 
+func CompletionsMiddleware() gin.HandlerFunc {
+	return func(c *gin.Context) {
+		var req CompletionRequest
+		err := c.ShouldBindJSON(&req)
+		if err != nil {
+			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
+			return
+		}
+
+		var b bytes.Buffer
+		genReq, err := fromCompleteRequest(req)
+		if err != nil {
+			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
+			return
+		}
+
+		if err := json.NewEncoder(&b).Encode(genReq); err != nil {
+			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
+			return
+		}
+
+		c.Request.Body = io.NopCloser(&b)
+
+		w := &CompleteWriter{
+			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
+			stream:     req.Stream,
+			id:         fmt.Sprintf("cmpl-%d", rand.Intn(999)),
+		}
+
+		c.Writer = w
+
+		c.Next()
+	}
+}
+
 func ChatMiddleware() gin.HandlerFunc {
 	return func(c *gin.Context) {
 		var req ChatCompletionRequest

+ 130 - 2
openai/openai_test.go

@@ -3,9 +3,11 @@ package openai
 import (
 	"bytes"
 	"encoding/json"
+	"fmt"
 	"io"
 	"net/http"
 	"net/http/httptest"
+	"strings"
 	"testing"
 	"time"
 
@@ -69,6 +71,8 @@ func TestMiddleware(t *testing.T) {
 				req.Header.Set("Content-Type", "application/json")
 			},
 			Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
+				assert.Equal(t, http.StatusOK, resp.Code)
+
 				var chatResp ChatCompletion
 				if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
 					t.Fatal(err)
@@ -83,6 +87,130 @@ func TestMiddleware(t *testing.T) {
 				}
 			},
 		},
+		{
+			Name:     "completions handler",
+			Method:   http.MethodPost,
+			Path:     "/api/generate",
+			TestPath: "/api/generate",
+			Handler:  CompletionsMiddleware,
+			Endpoint: func(c *gin.Context) {
+				c.JSON(http.StatusOK, api.GenerateResponse{
+					Response: "Hello!",
+				})
+			},
+			Setup: func(t *testing.T, req *http.Request) {
+				body := CompletionRequest{
+					Model:  "test-model",
+					Prompt: "Hello",
+				}
+
+				bodyBytes, _ := json.Marshal(body)
+
+				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
+				req.Header.Set("Content-Type", "application/json")
+			},
+			Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
+				assert.Equal(t, http.StatusOK, resp.Code)
+				var completionResp Completion
+				if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
+					t.Fatal(err)
+				}
+
+				if completionResp.Object != "text_completion" {
+					t.Fatalf("expected text_completion, got %s", completionResp.Object)
+				}
+
+				if completionResp.Choices[0].Text != "Hello!" {
+					t.Fatalf("expected Hello!, got %s", completionResp.Choices[0].Text)
+				}
+			},
+		},
+		{
+			Name:     "completions handler with params",
+			Method:   http.MethodPost,
+			Path:     "/api/generate",
+			TestPath: "/api/generate",
+			Handler:  CompletionsMiddleware,
+			Endpoint: func(c *gin.Context) {
+				var generateReq api.GenerateRequest
+				if err := c.ShouldBindJSON(&generateReq); err != nil {
+					c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
+					return
+				}
+
+				temperature := generateReq.Options["temperature"].(float64)
+				var assistantMessage string
+
+				switch temperature {
+				case 1.6:
+					assistantMessage = "Received temperature of 1.6"
+				default:
+					assistantMessage = fmt.Sprintf("Received temperature of %f", temperature)
+				}
+
+				c.JSON(http.StatusOK, api.GenerateResponse{
+					Response: assistantMessage,
+				})
+			},
+			Setup: func(t *testing.T, req *http.Request) {
+				temp := float32(0.8)
+				body := CompletionRequest{
+					Model:       "test-model",
+					Prompt:      "Hello",
+					Temperature: &temp,
+				}
+
+				bodyBytes, _ := json.Marshal(body)
+
+				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
+				req.Header.Set("Content-Type", "application/json")
+			},
+			Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
+				assert.Equal(t, http.StatusOK, resp.Code)
+				var completionResp Completion
+				if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
+					t.Fatal(err)
+				}
+
+				if completionResp.Object != "text_completion" {
+					t.Fatalf("expected text_completion, got %s", completionResp.Object)
+				}
+
+				if completionResp.Choices[0].Text != "Received temperature of 1.6" {
+					t.Fatalf("expected Received temperature of 1.6, got %s", completionResp.Choices[0].Text)
+				}
+			},
+		},
+		{
+			Name:     "completions handler with error",
+			Method:   http.MethodPost,
+			Path:     "/api/generate",
+			TestPath: "/api/generate",
+			Handler:  CompletionsMiddleware,
+			Endpoint: func(c *gin.Context) {
+				c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
+			},
+			Setup: func(t *testing.T, req *http.Request) {
+				body := CompletionRequest{
+					Model:  "test-model",
+					Prompt: "Hello",
+				}
+
+				bodyBytes, _ := json.Marshal(body)
+
+				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
+				req.Header.Set("Content-Type", "application/json")
+			},
+			Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
+				if resp.Code != http.StatusBadRequest {
+					t.Fatalf("expected 400, got %d", resp.Code)
+				}
+
+				if !strings.Contains(resp.Body.String(), `"invalid request"`) {
+					t.Fatalf("error was not forwarded")
+				}
+			},
+		},
 		{
 			Name:     "list handler",
 			Method:   http.MethodGet,
@@ -99,6 +227,8 @@ func TestMiddleware(t *testing.T) {
 				})
 			},
 			Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
+				assert.Equal(t, http.StatusOK, resp.Code)
+
 				var listResp ListCompletion
 				if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
 					t.Fatal(err)
@@ -162,8 +292,6 @@ func TestMiddleware(t *testing.T) {
 			resp := httptest.NewRecorder()
 			router.ServeHTTP(resp, req)
 
-			assert.Equal(t, http.StatusOK, resp.Code)
-
 			tc.Expected(t, resp)
 		})
 	}

+ 1 - 0
server/routes.go

@@ -1054,6 +1054,7 @@ func (s *Server) GenerateRoutes() http.Handler {
 
 	// Compatibility endpoints
 	r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
+	r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
 	r.GET("/v1/models", openai.ListMiddleware(), s.ListModelsHandler)
 	r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowModelHandler)