Roy Han 10 ay önce
ebeveyn
işleme
1daac52651
1 değiştirilmiş dosya ile 29 ekleme ve 0 silme
  1. 29 0
      server/routes.go

+ 29 - 0
server/routes.go

@@ -394,6 +394,21 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		return
 	}
 
+	truncate := func(s string) (string, error) {
+		tokens, err := runner.llama.Tokenize(c.Request.Context(), s)
+		if err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return "", err
+		}
+
+		if len(tokens) > opts.NumCtx {
+			tokens = tokens[len(tokens)-opts.NumCtx:]
+			return runner.llama.Detokenize(c.Request.Context(), tokens)
+		}
+
+		return s, nil
+	}
+
 	embeddings := [][]float64{}
 
 	switch reqEmbed := req.Input.(type) {
@@ -402,6 +417,13 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 			c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float64{}})
 			return
 		}
+		if *req.Truncate {
+			reqEmbed, err = truncate(reqEmbed)
+			if err != nil {
+				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+				return
+			}
+		}
 		embeddings, err = runner.llama.Embed(c.Request.Context(), []string{reqEmbed})
 	case []any:
 		if reqEmbed == nil {
@@ -412,6 +434,13 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		reqEmbedArray := make([]string, len(reqEmbed))
 		for i, v := range reqEmbed {
 			if s, ok := v.(string); ok {
+				if *req.Truncate {
+					s, err = truncate(s)
+					if err != nil {
+						c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+						return
+					}
+				}
 				reqEmbedArray[i] = s
 			} else {
 				c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})