Roy Han преди 9 месеца
родител
ревизия
dbe9527305
променен е 1 файл, в който са добавени 20 реда и са изтрити 44 реда
  1. 20 44
      server/routes.go

+ 20 - 44
server/routes.go

@@ -8,7 +8,6 @@ import (
 	"errors"
 	"fmt"
 	"io"
-	"io/fs"
 	"log/slog"
 	"math"
 	"net"
@@ -260,33 +259,11 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		return
 	}
 
-	if req.Model == "" {
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
-		return
-	}
-
 	if req.Truncate == nil {
 		truncate := true
 		req.Truncate = &truncate
 	}
 
-	model, err := GetModel(req.Model)
-	if err != nil {
-		var pErr *fs.PathError
-		if errors.As(err, &pErr) {
-			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
-			return
-		}
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
-
-	opts, err := modelOptions(model, req.Options)
-	if err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
-
 	reqEmbed := []string{}
 
 	switch embeddings := req.Input.(type) {
@@ -314,41 +291,40 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		return
 	}
 
-	r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
+	r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
 	if err != nil {
 		handleScheduleError(c, req.Model, err)
 		return
 	}
 
-	kvData, err := getKVData(model.ModelPath, false)
+	kvData, err := getKVData(m.ModelPath, false)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
 
-	checkFit := func(s string, truncate bool) (string, error) {
-		tokens, err := r.Tokenize(c.Request.Context(), s)
-		if err != nil {
-			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-			return "", err
-		}
+	reqEmbedArray := make([]string, len(reqEmbed))
+	for i, v := range reqEmbed {
+		s, err := func(v string, truncate bool) (string, error) {
+			tokens, err := r.Tokenize(c.Request.Context(), v)
+			if err != nil {
+				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+				return "", err
+			}
 
-		ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
-		if len(tokens) > ctxLen {
-			if truncate {
-				tokens = tokens[:ctxLen]
-				return r.Detokenize(c.Request.Context(), tokens)
-			} else {
-				return "", fmt.Errorf("input length exceeds maximum context length")
+			ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
+			if len(tokens) > ctxLen {
+				if truncate {
+					tokens = tokens[:ctxLen]
+					return r.Detokenize(c.Request.Context(), tokens)
+				} else {
+					return "", fmt.Errorf("input length exceeds maximum context length")
+				}
 			}
-		}
 
-		return s, nil
-	}
+			return v, nil
+		}(v, *req.Truncate)
 
-	reqEmbedArray := make([]string, len(reqEmbed))
-	for i, v := range reqEmbed {
-		s, err := checkFit(v, *req.Truncate)
 		if err != nil {
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 			return