Sfoglia il codice sorgente

OpenAI: /v1/models and /v1/models/{model} compatibility (#5007)

* OpenAI v1 models

* Refactor Writers

* Add Test

Co-Authored-By: Attila Kerekes

* Credit Co-Author

Co-Authored-By: Attila Kerekes <439392+keriati@users.noreply.github.com>

* Empty List Testing

* Use Namespace for Ownedby

* Update Test

* Add back envconfig

* v1/models docs

* Use ModelName Parser

* Test Names

* Remove Docs

* Clean Up

* Test name

Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>

* Add Middleware for Chat and List

* Testing Cleanup

* Test with Fatal

* Add functionality to chat test

* OpenAI: /v1/models/{model} compatibility (#5028)

* Retrieve Model

* OpenAI Delete Model

* Retrieve Middleware

* Remove Delete from Branch

* Update Test

* Middleware Test File

* Function name

* Cleanup

* Test Update

* Test Update

---------

Co-authored-by: Attila Kerekes <439392+keriati@users.noreply.github.com>
Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>
royjhan 10 mesi fa
parent
commit
996bb1b85e
6 ha cambiato i file con 386 aggiunte e 13 eliminazioni
  1. 7 0
      api/types.go
  2. 1 0
      docs/openai.md
  3. 149 12
      openai/openai.go
  4. 170 0
      openai/openai_test.go
  5. 3 1
      server/routes.go
  6. 56 0
      server/routes_test.go

+ 7 - 0
api/types.go

@@ -345,6 +345,13 @@ type ProcessModelResponse struct {
 	SizeVRAM  int64        `json:"size_vram"`
 }
 
+type RetrieveModelResponse struct {
+	Id      string `json:"id"`
+	Object  string `json:"object"`
+	Created int64  `json:"created"`
+	OwnedBy string `json:"owned_by"`
+}
+
 type TokenResponse struct {
 	Token string `json:"token"`
 }

+ 1 - 0
docs/openai.md

@@ -65,6 +65,7 @@ curl http://localhost:11434/v1/chat/completions \
             }
         ]
     }'
+
 ```
 
 ## Endpoints

+ 149 - 12
openai/openai.go

@@ -12,6 +12,7 @@ import (
 
 	"github.com/gin-gonic/gin"
 	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/types/model"
 )
 
 type Error struct {
@@ -85,6 +86,18 @@ type ChatCompletionChunk struct {
 	Choices           []ChunkChoice `json:"choices"`
 }
 
+type Model struct {
+	Id      string `json:"id"`
+	Object  string `json:"object"`
+	Created int64  `json:"created"`
+	OwnedBy string `json:"owned_by"`
+}
+
+type ListCompletion struct {
+	Object string  `json:"object"`
+	Data   []Model `json:"data"`
+}
+
 func NewError(code int, message string) ErrorResponse {
 	var etype string
 	switch code {
@@ -145,7 +158,33 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
 	}
 }
 
-func fromRequest(r ChatCompletionRequest) api.ChatRequest {
+func toListCompletion(r api.ListResponse) ListCompletion {
+	var data []Model
+	for _, m := range r.Models {
+		data = append(data, Model{
+			Id:      m.Name,
+			Object:  "model",
+			Created: m.ModifiedAt.Unix(),
+			OwnedBy: model.ParseName(m.Name).Namespace,
+		})
+	}
+
+	return ListCompletion{
+		Object: "list",
+		Data:   data,
+	}
+}
+
+func toModel(r api.ShowResponse, m string) Model {
+	return Model{
+		Id:      m,
+		Object:  "model",
+		Created: r.ModifiedAt.Unix(),
+		OwnedBy: model.ParseName(m).Namespace,
+	}
+}
+
+func fromChatRequest(r ChatCompletionRequest) api.ChatRequest {
 	var messages []api.Message
 	for _, msg := range r.Messages {
 		messages = append(messages, api.Message{Role: msg.Role, Content: msg.Content})
@@ -208,13 +247,26 @@ func fromRequest(r ChatCompletionRequest) api.ChatRequest {
 	}
 }
 
-type writer struct {
+type BaseWriter struct {
+	gin.ResponseWriter
+}
+
+type ChatWriter struct {
 	stream bool
 	id     string
-	gin.ResponseWriter
+	BaseWriter
 }
 
-func (w *writer) writeError(code int, data []byte) (int, error) {
+type ListWriter struct {
+	BaseWriter
+}
+
+type RetrieveWriter struct {
+	BaseWriter
+	model string
+}
+
+func (w *BaseWriter) writeError(code int, data []byte) (int, error) {
 	var serr api.StatusError
 	err := json.Unmarshal(data, &serr)
 	if err != nil {
@@ -230,7 +282,7 @@ func (w *writer) writeError(code int, data []byte) (int, error) {
 	return len(data), nil
 }
 
-func (w *writer) writeResponse(data []byte) (int, error) {
+func (w *ChatWriter) writeResponse(data []byte) (int, error) {
 	var chatResponse api.ChatResponse
 	err := json.Unmarshal(data, &chatResponse)
 	if err != nil {
@@ -270,7 +322,7 @@ func (w *writer) writeResponse(data []byte) (int, error) {
 	return len(data), nil
 }
 
-func (w *writer) Write(data []byte) (int, error) {
+func (w *ChatWriter) Write(data []byte) (int, error) {
 	code := w.ResponseWriter.Status()
 	if code != http.StatusOK {
 		return w.writeError(code, data)
@@ -279,7 +331,92 @@ func (w *writer) Write(data []byte) (int, error) {
 	return w.writeResponse(data)
 }
 
-func Middleware() gin.HandlerFunc {
+func (w *ListWriter) writeResponse(data []byte) (int, error) {
+	var listResponse api.ListResponse
+	err := json.Unmarshal(data, &listResponse)
+	if err != nil {
+		return 0, err
+	}
+
+	w.ResponseWriter.Header().Set("Content-Type", "application/json")
+	err = json.NewEncoder(w.ResponseWriter).Encode(toListCompletion(listResponse))
+	if err != nil {
+		return 0, err
+	}
+
+	return len(data), nil
+}
+
+func (w *ListWriter) Write(data []byte) (int, error) {
+	code := w.ResponseWriter.Status()
+	if code != http.StatusOK {
+		return w.writeError(code, data)
+	}
+
+	return w.writeResponse(data)
+}
+
+func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
+	var showResponse api.ShowResponse
+	err := json.Unmarshal(data, &showResponse)
+	if err != nil {
+		return 0, err
+	}
+
+	// retrieve completion
+	w.ResponseWriter.Header().Set("Content-Type", "application/json")
+	err = json.NewEncoder(w.ResponseWriter).Encode(toModel(showResponse, w.model))
+	if err != nil {
+		return 0, err
+	}
+
+	return len(data), nil
+}
+
+func (w *RetrieveWriter) Write(data []byte) (int, error) {
+	code := w.ResponseWriter.Status()
+	if code != http.StatusOK {
+		return w.writeError(code, data)
+	}
+
+	return w.writeResponse(data)
+}
+
+func ListMiddleware() gin.HandlerFunc {
+	return func(c *gin.Context) {
+		w := &ListWriter{
+			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
+		}
+
+		c.Writer = w
+
+		c.Next()
+	}
+}
+
+func RetrieveMiddleware() gin.HandlerFunc {
+	return func(c *gin.Context) {
+		var b bytes.Buffer
+		if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
+			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
+			return
+		}
+
+		c.Request.Body = io.NopCloser(&b)
+
+		// response writer
+		w := &RetrieveWriter{
+			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
+			model:      c.Param("model"),
+		}
+
+		c.Writer = w
+
+		c.Next()
+	}
+}
+
+func ChatMiddleware() gin.HandlerFunc {
 	return func(c *gin.Context) {
 		var req ChatCompletionRequest
 		err := c.ShouldBindJSON(&req)
@@ -294,17 +431,17 @@ func Middleware() gin.HandlerFunc {
 		}
 
 		var b bytes.Buffer
-		if err := json.NewEncoder(&b).Encode(fromRequest(req)); err != nil {
+		if err := json.NewEncoder(&b).Encode(fromChatRequest(req)); err != nil {
 			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
 			return
 		}
 
 		c.Request.Body = io.NopCloser(&b)
 
-		w := &writer{
-			ResponseWriter: c.Writer,
-			stream:         req.Stream,
-			id:             fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
+		w := &ChatWriter{
+			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
+			stream:     req.Stream,
+			id:         fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
 		}
 
 		c.Writer = w

+ 170 - 0
openai/openai_test.go

@@ -0,0 +1,170 @@
+package openai
+
+import (
+	"bytes"
+	"encoding/json"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+	"time"
+
+	"github.com/gin-gonic/gin"
+	"github.com/ollama/ollama/api"
+	"github.com/stretchr/testify/assert"
+)
+
+func TestMiddleware(t *testing.T) {
+	type testCase struct {
+		Name     string
+		Method   string
+		Path     string
+		TestPath string
+		Handler  func() gin.HandlerFunc
+		Endpoint func(c *gin.Context)
+		Setup    func(t *testing.T, req *http.Request)
+		Expected func(t *testing.T, resp *httptest.ResponseRecorder)
+	}
+
+	testCases := []testCase{
+		{
+			Name:     "chat handler",
+			Method:   http.MethodPost,
+			Path:     "/api/chat",
+			TestPath: "/api/chat",
+			Handler:  ChatMiddleware,
+			Endpoint: func(c *gin.Context) {
+				var chatReq api.ChatRequest
+				if err := c.ShouldBindJSON(&chatReq); err != nil {
+					c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
+					return
+				}
+
+				userMessage := chatReq.Messages[0].Content
+				var assistantMessage string
+
+				switch userMessage {
+				case "Hello":
+					assistantMessage = "Hello!"
+				default:
+					assistantMessage = "I'm not sure how to respond to that."
+				}
+
+				c.JSON(http.StatusOK, api.ChatResponse{
+					Message: api.Message{
+						Role:    "assistant",
+						Content: assistantMessage,
+					},
+				})
+			},
+			Setup: func(t *testing.T, req *http.Request) {
+				body := ChatCompletionRequest{
+					Model:    "test-model",
+					Messages: []Message{{Role: "user", Content: "Hello"}},
+				}
+
+				bodyBytes, _ := json.Marshal(body)
+
+				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
+				req.Header.Set("Content-Type", "application/json")
+			},
+			Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
+				var chatResp ChatCompletion
+				if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
+					t.Fatal(err)
+				}
+
+				if chatResp.Object != "chat.completion" {
+					t.Fatalf("expected chat.completion, got %s", chatResp.Object)
+				}
+
+				if chatResp.Choices[0].Message.Content != "Hello!" {
+					t.Fatalf("expected Hello!, got %s", chatResp.Choices[0].Message.Content)
+				}
+			},
+		},
+		{
+			Name:     "list handler",
+			Method:   http.MethodGet,
+			Path:     "/api/tags",
+			TestPath: "/api/tags",
+			Handler:  ListMiddleware,
+			Endpoint: func(c *gin.Context) {
+				c.JSON(http.StatusOK, api.ListResponse{
+					Models: []api.ListModelResponse{
+						{
+							Name: "Test Model",
+						},
+					},
+				})
+			},
+			Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
+				var listResp ListCompletion
+				if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
+					t.Fatal(err)
+				}
+
+				if listResp.Object != "list" {
+					t.Fatalf("expected list, got %s", listResp.Object)
+				}
+
+				if len(listResp.Data) != 1 {
+					t.Fatalf("expected 1, got %d", len(listResp.Data))
+				}
+
+				if listResp.Data[0].Id != "Test Model" {
+					t.Fatalf("expected Test Model, got %s", listResp.Data[0].Id)
+				}
+			},
+		},
+		{
+			Name:     "retrieve model",
+			Method:   http.MethodGet,
+			Path:     "/api/show/:model",
+			TestPath: "/api/show/test-model",
+			Handler:  RetrieveMiddleware,
+			Endpoint: func(c *gin.Context) {
+				c.JSON(http.StatusOK, api.ShowResponse{
+					ModifiedAt: time.Date(2024, 6, 17, 13, 45, 0, 0, time.UTC),
+				})
+			},
+			Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
+				var retrieveResp Model
+				if err := json.NewDecoder(resp.Body).Decode(&retrieveResp); err != nil {
+					t.Fatal(err)
+				}
+
+				if retrieveResp.Object != "model" {
+					t.Fatalf("Expected object to be model, got %s", retrieveResp.Object)
+				}
+
+				if retrieveResp.Id != "test-model" {
+					t.Fatalf("Expected id to be test-model, got %s", retrieveResp.Id)
+				}
+			},
+		},
+	}
+
+	gin.SetMode(gin.TestMode)
+	router := gin.New()
+
+	for _, tc := range testCases {
+		t.Run(tc.Name, func(t *testing.T) {
+			router = gin.New()
+			router.Use(tc.Handler())
+			router.Handle(tc.Method, tc.Path, tc.Endpoint)
+			req, _ := http.NewRequest(tc.Method, tc.TestPath, nil)
+
+			if tc.Setup != nil {
+				tc.Setup(t, req)
+			}
+
+			resp := httptest.NewRecorder()
+			router.ServeHTTP(resp, req)
+
+			assert.Equal(t, http.StatusOK, resp.Code)
+
+			tc.Expected(t, resp)
+		})
+	}
+}

+ 3 - 1
server/routes.go

@@ -1039,7 +1039,9 @@ func (s *Server) GenerateRoutes() http.Handler {
 	r.GET("/api/ps", s.ProcessHandler)
 
 	// Compatibility endpoints
-	r.POST("/v1/chat/completions", openai.Middleware(), s.ChatHandler)
+	r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
+	r.GET("/v1/models", openai.ListMiddleware(), s.ListModelsHandler)
+	r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowModelHandler)
 
 	for _, method := range []string{http.MethodGet, http.MethodHead} {
 		r.Handle(method, "/", func(c *gin.Context) {

+ 56 - 0
server/routes_test.go

@@ -20,6 +20,7 @@ import (
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/openai"
 	"github.com/ollama/ollama/parser"
 	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/version"
@@ -105,6 +106,24 @@ func Test_Routes(t *testing.T) {
 				assert.Empty(t, len(modelList.Models))
 			},
 		},
+		{
+			Name:   "openai empty list",
+			Method: http.MethodGet,
+			Path:   "/v1/models",
+			Expected: func(t *testing.T, resp *http.Response) {
+				contentType := resp.Header.Get("Content-Type")
+				assert.Equal(t, "application/json", contentType)
+				body, err := io.ReadAll(resp.Body)
+				require.NoError(t, err)
+
+				var modelList openai.ListCompletion
+				err = json.Unmarshal(body, &modelList)
+				require.NoError(t, err)
+
+				assert.Equal(t, "list", modelList.Object)
+				assert.Empty(t, modelList.Data)
+			},
+		},
 		{
 			Name:   "Tags Handler (yes tags)",
 			Method: http.MethodGet,
@@ -128,6 +147,25 @@ func Test_Routes(t *testing.T) {
 				assert.Equal(t, "test-model:latest", modelList.Models[0].Name)
 			},
 		},
+		{
+			Name:   "openai list models with tags",
+			Method: http.MethodGet,
+			Path:   "/v1/models",
+			Expected: func(t *testing.T, resp *http.Response) {
+				contentType := resp.Header.Get("Content-Type")
+				assert.Equal(t, "application/json", contentType)
+				body, err := io.ReadAll(resp.Body)
+				require.NoError(t, err)
+
+				var modelList openai.ListCompletion
+				err = json.Unmarshal(body, &modelList)
+				require.NoError(t, err)
+
+				assert.Len(t, modelList.Data, 1)
+				assert.Equal(t, "test-model:latest", modelList.Data[0].Id)
+				assert.Equal(t, "library", modelList.Data[0].OwnedBy)
+			},
+		},
 		{
 			Name:   "Create Model Handler",
 			Method: http.MethodPost,
@@ -216,6 +254,24 @@ func Test_Routes(t *testing.T) {
 				assert.InDelta(t, 0, showResp.ModelInfo["general.parameter_count"], 1e-9, "Parameter count should be 0")
 			},
 		},
+		{
+			Name:   "openai retrieve model handler",
+			Method: http.MethodGet,
+			Path:   "/v1/models/show-model",
+			Expected: func(t *testing.T, resp *http.Response) {
+				contentType := resp.Header.Get("Content-Type")
+				assert.Equal(t, "application/json", contentType)
+				body, err := io.ReadAll(resp.Body)
+				require.NoError(t, err)
+
+				var retrieveResp api.RetrieveModelResponse
+				err = json.Unmarshal(body, &retrieveResp)
+				require.NoError(t, err)
+
+				assert.Equal(t, "show-model", retrieveResp.Id)
+				assert.Equal(t, "library", retrieveResp.OwnedBy)
+			},
+		},
 	}
 
 	t.Setenv("OLLAMA_MODELS", t.TempDir())