Roy Han hai 9 meses
pai
achega
424f3f81a9
Modificáronse 1 ficheiros con 29 adicións e 25 borrados
  1. 29 25
      server/routes.go

+ 29 - 25
server/routes.go

@@ -259,38 +259,42 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		return
 		return
 	}
 	}
 
 
+	truncate := true
+
+	if req.Truncate != nil && !*req.Truncate {
+		truncate = false
+	}
+
 	if req.Truncate == nil {
 	if req.Truncate == nil {
 		truncate := true
 		truncate := true
 		req.Truncate = &truncate
 		req.Truncate = &truncate
 	}
 	}
 
 
-	reqEmbed := []string{}
+	var input []string
 
 
-	switch embeddings := req.Input.(type) {
+	switch i := req.Input.(type) {
 	case string:
 	case string:
-		if embeddings == "" {
-			c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
-			return
+		if len(i) > 0 {
+			input = append(input, i)
 		}
 		}
-		reqEmbed = []string{embeddings}
 	case []any:
 	case []any:
-		if len(embeddings) == 0 {
-			c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
-			return
-		}
-
-		for _, v := range embeddings {
+		for _, v := range i {
 			if _, ok := v.(string); !ok {
 			if _, ok := v.(string); !ok {
 				c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
 				c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
 				return
 				return
 			}
 			}
-			reqEmbed = append(reqEmbed, v.(string))
+			input = append(input, v.(string))
 		}
 		}
 	default:
 	default:
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
 		return
 		return
 	}
 	}
 
 
+	if len(input) == 0 {
+		c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
+		return
+	}
+
 	r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
 	r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
 	if err != nil {
 	if err != nil {
 		handleScheduleError(c, req.Model, err)
 		handleScheduleError(c, req.Model, err)
@@ -303,8 +307,8 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		return
 		return
 	}
 	}
 
 
-	reqEmbedArray := make([]string, len(reqEmbed))
-	for i, s := range reqEmbed {
+	reqEmbedArray := make([]string, len(input))
+	for i, s := range input {
 		tokens, err := r.Tokenize(c.Request.Context(), s)
 		tokens, err := r.Tokenize(c.Request.Context(), s)
 		if err != nil {
 		if err != nil {
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -313,17 +317,17 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 
 
 		ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
 		ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
 		if len(tokens) > ctxLen {
 		if len(tokens) > ctxLen {
-			if *req.Truncate {
-				tokens = tokens[:ctxLen]
-				s, err = r.Detokenize(c.Request.Context(), tokens)
-				if err != nil {
-					c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-					return
-				}
-			} else {
+			if !truncate {
 				c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"})
 				c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"})
 				return
 				return
 			}
 			}
+
+			tokens = tokens[:ctxLen]
+			s, err = r.Detokenize(c.Request.Context(), tokens)
+			if err != nil {
+				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+				return
+			}
 		}
 		}
 
 
 		reqEmbedArray[i] = s
 		reqEmbedArray[i] = s
@@ -331,7 +335,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 	embeddings, err := r.Embed(c.Request.Context(), reqEmbedArray)
 	embeddings, err := r.Embed(c.Request.Context(), reqEmbedArray)
 
 
 	if err != nil {
 	if err != nil {
-		slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
+		slog.Error("embedding generation failed", "error", err)
 		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
 		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
 		return
 		return
 	}
 	}
@@ -1030,7 +1034,7 @@ func (s *Server) GenerateRoutes() http.Handler {
 	r.POST("/api/generate", s.GenerateHandler)
 	r.POST("/api/generate", s.GenerateHandler)
 	r.POST("/api/chat", s.ChatHandler)
 	r.POST("/api/chat", s.ChatHandler)
 	r.POST("/api/embed", s.EmbedHandler)
 	r.POST("/api/embed", s.EmbedHandler)
-	r.POST("/api/embeddings", s.EmbeddingsHandler) // legacy
+	r.POST("/api/embeddings", s.EmbeddingsHandler)
 	r.POST("/api/create", s.CreateModelHandler)
 	r.POST("/api/create", s.CreateModelHandler)
 	r.POST("/api/push", s.PushModelHandler)
 	r.POST("/api/push", s.PushModelHandler)
 	r.POST("/api/copy", s.CopyModelHandler)
 	r.POST("/api/copy", s.CopyModelHandler)