Browse Source

OpenAI: /v1/embeddings compatibility (#5285)

* OpenAI v1 models

* Empty List Testing

* Add back envconfig

* v1/models docs

* Remove Docs

* OpenAI batch embed compatibility

* merge conflicts

* integrate with api/embed

* ep

* merge conflicts

* request tests

* rm resp test

* merge conflict

* merge conflict

* test fixes

* test fn renaming

* input validation for empty string

---------

Co-authored-by: jmorganca <jmorganca@gmail.com>
royjhan 9 months ago
parent
commit
987dbab0b0
3 changed files with 184 additions and 0 deletions
  1. 111 0
      openai/openai.go
  2. 72 0
      openai/openai_test.go
  3. 1 0
      server/routes.go

+ 111 - 0
openai/openai.go

@@ -61,6 +61,11 @@ type ResponseFormat struct {
 	Type string `json:"type"`
 }
 
+type EmbedRequest struct {
+	Input any    `json:"input"`
+	Model string `json:"model"`
+}
+
 type ChatCompletionRequest struct {
 	Model            string          `json:"model"`
 	Messages         []Message       `json:"messages"`
@@ -134,11 +139,23 @@ type Model struct {
 	OwnedBy string `json:"owned_by"`
 }
 
+type Embedding struct {
+	Object    string    `json:"object"`
+	Embedding []float32 `json:"embedding"`
+	Index     int       `json:"index"`
+}
+
 type ListCompletion struct {
 	Object string  `json:"object"`
 	Data   []Model `json:"data"`
 }
 
+type EmbeddingList struct {
+	Object string      `json:"object"`
+	Data   []Embedding `json:"data"`
+	Model  string      `json:"model"`
+}
+
 func NewError(code int, message string) ErrorResponse {
 	var etype string
 	switch code {
@@ -262,6 +279,27 @@ func toListCompletion(r api.ListResponse) ListCompletion {
 	}
 }
 
+func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
+	if r.Embeddings != nil {
+		var data []Embedding
+		for i, e := range r.Embeddings {
+			data = append(data, Embedding{
+				Object:    "embedding",
+				Embedding: e,
+				Index:     i,
+			})
+		}
+
+		return EmbeddingList{
+			Object: "list",
+			Data:   data,
+			Model:  model,
+		}
+	}
+
+	return EmbeddingList{}
+}
+
 func toModel(r api.ShowResponse, m string) Model {
 	return Model{
 		Id:      m,
@@ -465,6 +503,11 @@ type RetrieveWriter struct {
 	model string
 }
 
+type EmbedWriter struct {
+	BaseWriter
+	model string
+}
+
 func (w *BaseWriter) writeError(code int, data []byte) (int, error) {
 	var serr api.StatusError
 	err := json.Unmarshal(data, &serr)
@@ -630,6 +673,33 @@ func (w *RetrieveWriter) Write(data []byte) (int, error) {
 	return w.writeResponse(data)
 }
 
+func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
+	var embedResponse api.EmbedResponse
+	err := json.Unmarshal(data, &embedResponse)
+
+	if err != nil {
+		return 0, err
+	}
+
+	w.ResponseWriter.Header().Set("Content-Type", "application/json")
+	err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse))
+
+	if err != nil {
+		return 0, err
+	}
+
+	return len(data), nil
+}
+
+func (w *EmbedWriter) Write(data []byte) (int, error) {
+	code := w.ResponseWriter.Status()
+	if code != http.StatusOK {
+		return w.writeError(code, data)
+	}
+
+	return w.writeResponse(data)
+}
+
 func ListMiddleware() gin.HandlerFunc {
 	return func(c *gin.Context) {
 		w := &ListWriter{
@@ -693,6 +763,47 @@ func CompletionsMiddleware() gin.HandlerFunc {
 			id:         fmt.Sprintf("cmpl-%d", rand.Intn(999)),
 		}
 
+		c.Writer = w
+		c.Next()
+	}
+}
+
+func EmbeddingsMiddleware() gin.HandlerFunc {
+	return func(c *gin.Context) {
+		var req EmbedRequest
+		err := c.ShouldBindJSON(&req)
+		if err != nil {
+			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
+			return
+		}
+
+		if req.Input == "" {
+			req.Input = []string{""}
+		}
+
+		if req.Input == nil {
+			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
+			return
+		}
+
+		if v, ok := req.Input.([]any); ok && len(v) == 0 {
+			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
+			return
+		}
+
+		var b bytes.Buffer
+		if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input}); err != nil {
+			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
+			return
+		}
+
+		c.Request.Body = io.NopCloser(&b)
+
+		w := &EmbedWriter{
+			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
+			model:      req.Model,
+		}
+
 		c.Writer = w
 
 		c.Next()

+ 72 - 0
openai/openai_test.go

@@ -161,6 +161,78 @@ func TestMiddlewareRequests(t *testing.T) {
 				}
 			},
 		},
+		{
+			Name:    "embed handler single input",
+			Method:  http.MethodPost,
+			Path:    "/api/embed",
+			Handler: EmbeddingsMiddleware,
+			Setup: func(t *testing.T, req *http.Request) {
+				body := EmbedRequest{
+					Input: "Hello",
+					Model: "test-model",
+				}
+
+				bodyBytes, _ := json.Marshal(body)
+
+				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
+				req.Header.Set("Content-Type", "application/json")
+			},
+			Expected: func(t *testing.T, req *http.Request) {
+				var embedReq api.EmbedRequest
+				if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil {
+					t.Fatal(err)
+				}
+
+				if embedReq.Input != "Hello" {
+					t.Fatalf("expected 'Hello', got %s", embedReq.Input)
+				}
+
+				if embedReq.Model != "test-model" {
+					t.Fatalf("expected 'test-model', got %s", embedReq.Model)
+				}
+			},
+		},
+		{
+			Name:    "embed handler batch input",
+			Method:  http.MethodPost,
+			Path:    "/api/embed",
+			Handler: EmbeddingsMiddleware,
+			Setup: func(t *testing.T, req *http.Request) {
+				body := EmbedRequest{
+					Input: []string{"Hello", "World"},
+					Model: "test-model",
+				}
+
+				bodyBytes, _ := json.Marshal(body)
+
+				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
+				req.Header.Set("Content-Type", "application/json")
+			},
+			Expected: func(t *testing.T, req *http.Request) {
+				var embedReq api.EmbedRequest
+				if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil {
+					t.Fatal(err)
+				}
+
+				input, ok := embedReq.Input.([]any)
+
+				if !ok {
+					t.Fatalf("expected input to be a list")
+				}
+
+				if input[0].(string) != "Hello" {
+					t.Fatalf("expected 'Hello', got %s", input[0])
+				}
+
+				if input[1].(string) != "World" {
+					t.Fatalf("expected 'World', got %s", input[1])
+				}
+
+				if embedReq.Model != "test-model" {
+					t.Fatalf("expected 'test-model', got %s", embedReq.Model)
+				}
+			},
+		},
 	}
 
 	gin.SetMode(gin.TestMode)

+ 1 - 0
server/routes.go

@@ -1064,6 +1064,7 @@ func (s *Server) GenerateRoutes() http.Handler {
 	// Compatibility endpoints
 	r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
 	r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
+	r.POST("/v1/embeddings", openai.EmbeddingsMiddleware(), s.EmbedHandler)
 	r.GET("/v1/models", openai.ListMiddleware(), s.ListModelsHandler)
 	r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowModelHandler)