|
@@ -9,6 +9,7 @@ import (
|
|
"fmt"
|
|
"fmt"
|
|
"io"
|
|
"io"
|
|
"log/slog"
|
|
"log/slog"
|
|
|
|
+ "math"
|
|
"net"
|
|
"net"
|
|
"net/http"
|
|
"net/http"
|
|
"net/netip"
|
|
"net/netip"
|
|
@@ -271,6 +272,121 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|
streamResponse(c, ch)
|
|
streamResponse(c, ch)
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+func (s *Server) EmbedHandler(c *gin.Context) {
|
|
|
|
+ var req api.EmbedRequest
|
|
|
|
+ err := c.ShouldBindJSON(&req)
|
|
|
|
+ switch {
|
|
|
|
+ case errors.Is(err, io.EOF):
|
|
|
|
+ c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
|
|
|
+ return
|
|
|
|
+ case err != nil:
|
|
|
|
+ c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ truncate := true
|
|
|
|
+
|
|
|
|
+ if req.Truncate != nil && !*req.Truncate {
|
|
|
|
+ truncate = false
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ var input []string
|
|
|
|
+
|
|
|
|
+ switch i := req.Input.(type) {
|
|
|
|
+ case string:
|
|
|
|
+ if len(i) > 0 {
|
|
|
|
+ input = append(input, i)
|
|
|
|
+ }
|
|
|
|
+ case []any:
|
|
|
|
+ for _, v := range i {
|
|
|
|
+ if _, ok := v.(string); !ok {
|
|
|
|
+ c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+ input = append(input, v.(string))
|
|
|
|
+ }
|
|
|
|
+ default:
|
|
|
|
+ c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ if len(input) == 0 {
|
|
|
|
+ c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
|
|
|
|
+ if err != nil {
|
|
|
|
+ handleScheduleError(c, req.Model, err)
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ kvData, err := getKVData(m.ModelPath, false)
|
|
|
|
+ if err != nil {
|
|
|
|
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ for i, s := range input {
|
|
|
|
+ tokens, err := r.Tokenize(c.Request.Context(), s)
|
|
|
|
+ if err != nil {
|
|
|
|
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
|
|
|
|
+ if len(tokens) > ctxLen {
|
|
|
|
+ if !truncate {
|
|
|
|
+ c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"})
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ tokens = tokens[:ctxLen]
|
|
|
|
+ s, err = r.Detokenize(c.Request.Context(), tokens)
|
|
|
|
+ if err != nil {
|
|
|
|
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ input[i] = s
|
|
|
|
+ }
|
|
|
|
+ embeddings, err := r.Embed(c.Request.Context(), input)
|
|
|
|
+
|
|
|
|
+ if err != nil {
|
|
|
|
+ slog.Error("embedding generation failed", "error", err)
|
|
|
|
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ for i, e := range embeddings {
|
|
|
|
+ embeddings[i] = normalize(e)
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ resp := api.EmbedResponse{
|
|
|
|
+ Model: req.Model,
|
|
|
|
+ Embeddings: embeddings,
|
|
|
|
+ }
|
|
|
|
+ c.JSON(http.StatusOK, resp)
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
// normalize scales vec in place to unit Euclidean (L2) length and returns the
// same slice.
//
// The sum of squares and the scaling are computed in float64. Squaring a
// float32 component can overflow to +Inf or underflow to 0 well inside the
// float32 value range, which would otherwise turn a valid vector into all
// zeros; float64 also reduces rounding error on long vectors.
//
// If the sum of squares is not strictly positive (for example an all-zero or
// empty vector), every element is multiplied by zero.
func normalize(vec []float32) []float32 {
	var sumSquares float64
	for _, v := range vec {
		x := float64(v)
		sumSquares += x * x
	}

	var inv float64
	if sumSquares > 0 {
		inv = 1.0 / math.Sqrt(sumSquares)
	}

	// Multiply in float64 so a very large inverse norm (tiny inputs) does not
	// overflow before it is applied.
	for i, v := range vec {
		vec[i] = float32(float64(v) * inv)
	}
	return vec
}
|
|
|
|
+
|
|
func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
|
func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
|
var req api.EmbeddingRequest
|
|
var req api.EmbeddingRequest
|
|
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
|
|
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
|
|
@@ -293,14 +409,24 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
|
return
|
|
return
|
|
}
|
|
}
|
|
|
|
|
|
- embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
|
|
|
|
|
|
+ embeddings, err := r.Embed(c.Request.Context(), []string{req.Prompt})
|
|
|
|
+
|
|
if err != nil {
|
|
if err != nil {
|
|
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
|
|
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
|
|
return
|
|
return
|
|
}
|
|
}
|
|
|
|
|
|
- c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: embedding})
|
|
|
|
|
|
+ embedding := make([]float64, len(embeddings[0]))
|
|
|
|
+
|
|
|
|
+ for i, v := range embeddings[0] {
|
|
|
|
+ embedding[i] = float64(v)
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ resp := api.EmbeddingResponse{
|
|
|
|
+ Embedding: embedding,
|
|
|
|
+ }
|
|
|
|
+ c.JSON(http.StatusOK, resp)
|
|
}
|
|
}
|
|
|
|
|
|
func (s *Server) PullModelHandler(c *gin.Context) {
|
|
func (s *Server) PullModelHandler(c *gin.Context) {
|
|
@@ -919,6 +1045,7 @@ func (s *Server) GenerateRoutes() http.Handler {
|
|
r.POST("/api/pull", s.PullModelHandler)
|
|
r.POST("/api/pull", s.PullModelHandler)
|
|
r.POST("/api/generate", s.GenerateHandler)
|
|
r.POST("/api/generate", s.GenerateHandler)
|
|
r.POST("/api/chat", s.ChatHandler)
|
|
r.POST("/api/chat", s.ChatHandler)
|
|
|
|
+ r.POST("/api/embed", s.EmbedHandler)
|
|
r.POST("/api/embeddings", s.EmbeddingsHandler)
|
|
r.POST("/api/embeddings", s.EmbeddingsHandler)
|
|
r.POST("/api/create", s.CreateModelHandler)
|
|
r.POST("/api/create", s.CreateModelHandler)
|
|
r.POST("/api/push", s.PushModelHandler)
|
|
r.POST("/api/push", s.PushModelHandler)
|