Selaa lähdekoodia

Load Embedding Model on Empty Input (#6325)

* load on empty input

* no load on invalid input
royjhan 8 kuukautta sitten
vanhempi
commit
8b00a415ab
2 muutettua tiedostoa jossa 9 lisäystä ja 77 poistoa
  1. 9 7
      server/routes.go
  2. 0 70
      server/routes_test.go

+ 9 - 7
server/routes.go

@@ -324,13 +324,10 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 			input = append(input, v.(string))
 		}
 	default:
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
-		return
-	}
-
-	if len(input) == 0 {
-		c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
-		return
+		if req.Input != nil {
+			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
+			return
+		}
 	}
 
 	r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
@@ -341,6 +338,11 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 
 	checkpointLoaded := time.Now()
 
+	if len(input) == 0 {
+		c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
+		return
+	}
+
 	kvData, err := getKVData(m.ModelPath, false)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})

+ 0 - 70
server/routes_test.go

@@ -272,76 +272,6 @@ func Test_Routes(t *testing.T) {
 				assert.Equal(t, "library", retrieveResp.OwnedBy)
 			},
 		},
-		{
-			Name:   "Embed Handler Empty Input",
-			Method: http.MethodPost,
-			Path:   "/api/embed",
-			Setup: func(t *testing.T, req *http.Request) {
-				embedReq := api.EmbedRequest{
-					Model: "t-bone",
-					Input: "",
-				}
-				jsonData, err := json.Marshal(embedReq)
-				require.NoError(t, err)
-				req.Body = io.NopCloser(bytes.NewReader(jsonData))
-			},
-			Expected: func(t *testing.T, resp *http.Response) {
-				contentType := resp.Header.Get("Content-Type")
-				if contentType != "application/json; charset=utf-8" {
-					t.Fatalf("expected content type application/json; charset=utf-8, got %s", contentType)
-				}
-				body, err := io.ReadAll(resp.Body)
-				if err != nil {
-					t.Fatal(err)
-				}
-
-				var embedResp api.EmbedResponse
-				err = json.Unmarshal(body, &embedResp)
-				if err != nil {
-					t.Fatal(err)
-				}
-
-				if embedResp.Model != "t-bone" {
-					t.Fatalf("expected model t-bone, got %s", embedResp.Model)
-				}
-
-				if embedResp.Embeddings == nil {
-					t.Fatalf("expected embeddings to not be nil, got %v", embedResp.Embeddings)
-				}
-
-				if len(embedResp.Embeddings) != 0 {
-					t.Fatalf("expected embeddings to be empty, got %v", embedResp.Embeddings)
-				}
-			},
-		},
-		{
-			Name:   "Embed Handler Invalid Input",
-			Method: http.MethodPost,
-			Path:   "/api/embed",
-			Setup: func(t *testing.T, req *http.Request) {
-				embedReq := api.EmbedRequest{
-					Model: "t-bone",
-					Input: 2,
-				}
-				jsonData, err := json.Marshal(embedReq)
-				require.NoError(t, err)
-				req.Body = io.NopCloser(bytes.NewReader(jsonData))
-			},
-			Expected: func(t *testing.T, resp *http.Response) {
-				contentType := resp.Header.Get("Content-Type")
-				if contentType != "application/json; charset=utf-8" {
-					t.Fatalf("expected content type application/json; charset=utf-8, got %s", contentType)
-				}
-				_, err := io.ReadAll(resp.Body)
-				if err != nil {
-					t.Fatal(err)
-				}
-
-				if resp.StatusCode != http.StatusBadRequest {
-					t.Fatalf("expected status code 400, got %d", resp.StatusCode)
-				}
-			},
-		},
 	}
 
 	t.Setenv("OLLAMA_MODELS", t.TempDir())