浏览代码

input handling and handler testing

Roy Han 10 月之前
父节点
当前提交
922b8f2584
共有 2 个文件被更改,包括 76 次插入19 次删除
  1. 27 18
      server/routes.go
  2. 49 1
      server/routes_test.go

+ 27 - 18
server/routes.go

@@ -392,6 +392,29 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		sessionDuration = req.KeepAlive.Duration
 	}
 
+	switch reqEmbed := req.Input.(type) {
+	case string:
+		if reqEmbed == "" {
+			c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
+			return
+		}
+	case []any:
+		if reqEmbed == nil {
+			c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
+			return
+		}
+
+		for _, v := range reqEmbed {
+			if _, ok := v.(string); !ok {
+				c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
+				return
+			}
+		}
+	default:
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
+		return
+	}
+
 	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
 	var runner *runnerRef
 	select {
@@ -424,10 +447,6 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 
 	switch reqEmbed := req.Input.(type) {
 	case string:
-		if reqEmbed == "" {
-			c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float32{}})
-			return
-		}
 		reqEmbed, err = checkFit(reqEmbed, *req.Truncate)
 		if err != nil {
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -435,24 +454,14 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		}
 		embeddings, err = runner.llama.Embed(c.Request.Context(), []string{reqEmbed})
 	case []any:
-		if reqEmbed == nil {
-			c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float32{}})
-			return
-		}
-
 		reqEmbedArray := make([]string, len(reqEmbed))
 		for i, v := range reqEmbed {
-			if s, ok := v.(string); ok {
-				s, err = checkFit(s, *req.Truncate)
-				if err != nil {
-					c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-					return
-				}
-				reqEmbedArray[i] = s
-			} else {
-				c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
+			s, err := checkFit(v.(string), *req.Truncate)
+			if err != nil {
+				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 				return
 			}
+			reqEmbedArray[i] = s
 		}
 		embeddings, err = runner.llama.Embed(c.Request.Context(), reqEmbedArray)
 	default:

+ 49 - 1
server/routes_test.go

@@ -273,6 +273,54 @@ func Test_Routes(t *testing.T) {
 				assert.Equal(t, "library", retrieveResp.OwnedBy)
 			},
 		},
+		{
+			Name:   "Embed Handler Empty Input",
+			Method: http.MethodPost,
+			Path:   "/api/embed",
+			Setup: func(t *testing.T, req *http.Request) {
+				embedReq := api.EmbedRequest{
+					Model: "t-bone",
+					Input: "",
+				}
+				jsonData, err := json.Marshal(embedReq)
+				require.NoError(t, err)
+				req.Body = io.NopCloser(bytes.NewReader(jsonData))
+			},
+			Expected: func(t *testing.T, resp *http.Response) {
+				contentType := resp.Header.Get("Content-Type")
+				assert.Equal(t, "application/json; charset=utf-8", contentType)
+				body, err := io.ReadAll(resp.Body)
+				require.NoError(t, err)
+
+				var embedResp api.EmbedResponse
+				err = json.Unmarshal(body, &embedResp)
+				require.NoError(t, err)
+
+				assert.Equal(t, "t-bone", embedResp.Model)
+				assert.Nil(t, embedResp.Embeddings)
+			},
+		},
+		{
+			Name:   "Embed Handler Invalid Input",
+			Method: http.MethodPost,
+			Path:   "/api/embed",
+			Setup: func(t *testing.T, req *http.Request) {
+				embedReq := api.EmbedRequest{
+					Model: "t-bone",
+					Input: 2,
+				}
+				jsonData, err := json.Marshal(embedReq)
+				require.NoError(t, err)
+				req.Body = io.NopCloser(bytes.NewReader(jsonData))
+			},
+			Expected: func(t *testing.T, resp *http.Response) {
+				contentType := resp.Header.Get("Content-Type")
+				assert.Equal(t, "application/json; charset=utf-8", contentType)
+				_, err := io.ReadAll(resp.Body)
+				require.NoError(t, err)
+				assert.Equal(t, 400, resp.StatusCode)
+			},
+		},
 	}
 
 	t.Setenv("OLLAMA_MODELS", t.TempDir())
@@ -454,5 +502,5 @@ func TestNormalize(t *testing.T) {
 				t.Errorf("Vector %v is not normalized", tc.input)
 			}
 		})
-  }
+	}
 }