Browse Source

refactoring

Roy Han 9 months ago
parent
commit
8f6d0242b6
2 changed files with 25 additions and 36 deletions
  1. llm/ext_server/server.cpp (+10 -11)
  2. server/routes.go (+15 -25)

+ 10 - 11
llm/ext_server/server.cpp

@@ -3202,19 +3202,18 @@ int main(int argc, char **argv) {
                     // get the result
                     // get the result
                     task_result result = llama.queue_results.recv(id_task);
                     task_result result = llama.queue_results.recv(id_task);
                     llama.queue_results.remove_waiting_task_id(id_task);
                     llama.queue_results.remove_waiting_task_id(id_task);
-                    if (!result.error) {
-                        responses = result.result_json.value("results", std::vector<json>{result.result_json});
-                        json embeddings = json::array();
-                        for (auto & elem : responses) {
-                            embeddings.push_back(elem.at("embedding"));
-                        }
-                        // send the result
-                        json result = json{{"embedding", embeddings}};
-                        return res.set_content(result.dump(), "application/json; charset=utf-8");
-                    } else {
-                        // return error
+                    if (result.error) {
                         return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
                         return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
                     }
                     }
+
+                    responses = result.result_json.value("results", std::vector<json>{result.result_json});
+                    json embeddings = json::array();
+                    for (auto & elem : responses) {
+                        embeddings.push_back(elem.at("embedding"));
+                    }
+                    // send the result
+                    json embedding_res = json{{"embedding", embeddings}};
+                    return res.set_content(embedding_res.dump(), "application/json; charset=utf-8");
                 }
                 }
             });
             });
 
 

+ 15 - 25
server/routes.go

@@ -10,6 +10,7 @@ import (
 	"io"
 	"io"
 	"io/fs"
 	"io/fs"
 	"log/slog"
 	"log/slog"
+	"math"
 	"net"
 	"net"
 	"net/http"
 	"net/http"
 	"net/netip"
 	"net/netip"
@@ -21,7 +22,6 @@ import (
 	"syscall"
 	"syscall"
 	"time"
 	"time"
 
 
-	"github.com/chewxy/math32"
 	"github.com/gin-contrib/cors"
 	"github.com/gin-contrib/cors"
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
 
 
@@ -287,23 +287,27 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		return
 		return
 	}
 	}
 
 
-	switch reqEmbed := req.Input.(type) {
+	reqEmbed := []string{}
+
+	switch embeddings := req.Input.(type) {
 	case string:
 	case string:
-		if reqEmbed == "" {
+		if embeddings == "" {
 			c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
 			c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
 			return
 			return
 		}
 		}
+		reqEmbed = []string{embeddings}
 	case []any:
 	case []any:
-		if reqEmbed == nil {
+		if len(embeddings) == 0 {
 			c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
 			c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
 			return
 			return
 		}
 		}
 
 
-		for _, v := range reqEmbed {
+		for _, v := range embeddings {
 			if _, ok := v.(string); !ok {
 			if _, ok := v.(string); !ok {
 				c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
 				c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
 				return
 				return
 			}
 			}
+			reqEmbed = append(reqEmbed, v.(string))
 		}
 		}
 	default:
 	default:
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
@@ -335,30 +339,16 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		return s, nil
 		return s, nil
 	}
 	}
 
 
-	embeddings := [][]float32{}
-
-	switch reqEmbed := req.Input.(type) {
-	case string:
-		reqEmbed, err = checkFit(reqEmbed, *req.Truncate)
+	reqEmbedArray := make([]string, len(reqEmbed))
+	for i, v := range reqEmbed {
+		s, err := checkFit(v, *req.Truncate)
 		if err != nil {
 		if err != nil {
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 			return
 			return
 		}
 		}
-		embeddings, err = r.Embed(c.Request.Context(), []string{reqEmbed})
-	case []any:
-		reqEmbedArray := make([]string, len(reqEmbed))
-		for i, v := range reqEmbed {
-			s, err := checkFit(v.(string), *req.Truncate)
-			if err != nil {
-				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-				return
-			}
-			reqEmbedArray[i] = s
-		}
-		embeddings, err = r.Embed(c.Request.Context(), reqEmbedArray)
-	default:
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
+		reqEmbedArray[i] = s
 	}
 	}
+	embeddings, err := r.Embed(c.Request.Context(), reqEmbedArray)
 
 
 	if err != nil {
 	if err != nil {
 		slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
 		slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
@@ -385,7 +375,7 @@ func normalize(vec []float32) []float32 {
 
 
 	norm := float32(0.0)
 	norm := float32(0.0)
 	if sum > 0 {
 	if sum > 0 {
-		norm = float32(1.0 / math32.Sqrt(sum))
+		norm = float32(1.0 / math.Sqrt(float64(sum)))
 	}
 	}
 
 
 	for i := range vec {
 	for i := range vec {