|
@@ -10,6 +10,7 @@ import (
|
|
"io"
|
|
"io"
|
|
"io/fs"
|
|
"io/fs"
|
|
"log/slog"
|
|
"log/slog"
|
|
|
|
+ "math"
|
|
"net"
|
|
"net"
|
|
"net/http"
|
|
"net/http"
|
|
"net/netip"
|
|
"net/netip"
|
|
@@ -21,7 +22,6 @@ import (
|
|
"syscall"
|
|
"syscall"
|
|
"time"
|
|
"time"
|
|
|
|
|
|
- "github.com/chewxy/math32"
|
|
|
|
"github.com/gin-contrib/cors"
|
|
"github.com/gin-contrib/cors"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/gin-gonic/gin"
|
|
|
|
|
|
@@ -287,23 +287,27 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|
return
|
|
return
|
|
}
|
|
}
|
|
|
|
|
|
- switch reqEmbed := req.Input.(type) {
|
|
|
|
|
|
+ reqEmbed := []string{}
|
|
|
|
+
|
|
|
|
+ switch embeddings := req.Input.(type) {
|
|
case string:
|
|
case string:
|
|
- if reqEmbed == "" {
|
|
|
|
|
|
+ if embeddings == "" {
|
|
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
|
|
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
|
|
return
|
|
return
|
|
}
|
|
}
|
|
|
|
+ reqEmbed = []string{embeddings}
|
|
case []any:
|
|
case []any:
|
|
- if reqEmbed == nil {
|
|
|
|
|
|
+ if len(embeddings) == 0 {
|
|
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
|
|
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
|
|
return
|
|
return
|
|
}
|
|
}
|
|
|
|
|
|
- for _, v := range reqEmbed {
|
|
|
|
|
|
+ for _, v := range embeddings {
|
|
if _, ok := v.(string); !ok {
|
|
if _, ok := v.(string); !ok {
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
|
return
|
|
return
|
|
}
|
|
}
|
|
|
|
+ reqEmbed = append(reqEmbed, v.(string))
|
|
}
|
|
}
|
|
default:
|
|
default:
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
|
@@ -335,30 +339,16 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|
return s, nil
|
|
return s, nil
|
|
}
|
|
}
|
|
|
|
|
|
- embeddings := [][]float32{}
|
|
|
|
-
|
|
|
|
- switch reqEmbed := req.Input.(type) {
|
|
|
|
- case string:
|
|
|
|
- reqEmbed, err = checkFit(reqEmbed, *req.Truncate)
|
|
|
|
|
|
+ reqEmbedArray := make([]string, len(reqEmbed))
|
|
|
|
+ for i, v := range reqEmbed {
|
|
|
|
+ s, err := checkFit(v, *req.Truncate)
|
|
if err != nil {
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
return
|
|
}
|
|
}
|
|
- embeddings, err = r.Embed(c.Request.Context(), []string{reqEmbed})
|
|
|
|
- case []any:
|
|
|
|
- reqEmbedArray := make([]string, len(reqEmbed))
|
|
|
|
- for i, v := range reqEmbed {
|
|
|
|
- s, err := checkFit(v.(string), *req.Truncate)
|
|
|
|
- if err != nil {
|
|
|
|
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
|
- return
|
|
|
|
- }
|
|
|
|
- reqEmbedArray[i] = s
|
|
|
|
- }
|
|
|
|
- embeddings, err = r.Embed(c.Request.Context(), reqEmbedArray)
|
|
|
|
- default:
|
|
|
|
- c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
|
|
|
|
|
+ reqEmbedArray[i] = s
|
|
}
|
|
}
|
|
|
|
+ embeddings, err := r.Embed(c.Request.Context(), reqEmbedArray)
|
|
|
|
|
|
if err != nil {
|
|
if err != nil {
|
|
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
|
|
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
|
|
@@ -385,7 +375,7 @@ func normalize(vec []float32) []float32 {
|
|
|
|
|
|
norm := float32(0.0)
|
|
norm := float32(0.0)
|
|
if sum > 0 {
|
|
if sum > 0 {
|
|
- norm = float32(1.0 / math32.Sqrt(sum))
|
|
|
|
|
|
+ norm = float32(1.0 / math.Sqrt(float64(sum)))
|
|
}
|
|
}
|
|
|
|
|
|
for i := range vec {
|
|
for i := range vec {
|