|
@@ -259,38 +259,42 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|
return
|
|
return
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ truncate := true
|
|
|
|
+
|
|
|
|
+ if req.Truncate != nil && !*req.Truncate {
|
|
|
|
+ truncate = false
|
|
|
|
+ }
|
|
|
|
+
|
|
if req.Truncate == nil {
|
|
if req.Truncate == nil {
|
|
truncate := true
|
|
truncate := true
|
|
req.Truncate = &truncate
|
|
req.Truncate = &truncate
|
|
}
|
|
}
|
|
|
|
|
|
- reqEmbed := []string{}
|
|
|
|
|
|
+ var input []string
|
|
|
|
|
|
- switch embeddings := req.Input.(type) {
|
|
|
|
|
|
+ switch i := req.Input.(type) {
|
|
case string:
|
|
case string:
|
|
- if embeddings == "" {
|
|
|
|
- c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
|
|
|
|
- return
|
|
|
|
|
|
+ if len(i) > 0 {
|
|
|
|
+ input = append(input, i)
|
|
}
|
|
}
|
|
- reqEmbed = []string{embeddings}
|
|
|
|
case []any:
|
|
case []any:
|
|
- if len(embeddings) == 0 {
|
|
|
|
- c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
|
|
|
|
- return
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- for _, v := range embeddings {
|
|
|
|
|
|
+ for _, v := range i {
|
|
if _, ok := v.(string); !ok {
|
|
if _, ok := v.(string); !ok {
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
|
return
|
|
return
|
|
}
|
|
}
|
|
- reqEmbed = append(reqEmbed, v.(string))
|
|
|
|
|
|
+ input = append(input, v.(string))
|
|
}
|
|
}
|
|
default:
|
|
default:
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
|
return
|
|
return
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ if len(input) == 0 {
|
|
|
|
+ c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+
|
|
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
|
|
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
|
|
if err != nil {
|
|
if err != nil {
|
|
handleScheduleError(c, req.Model, err)
|
|
handleScheduleError(c, req.Model, err)
|
|
@@ -303,8 +307,8 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|
return
|
|
return
|
|
}
|
|
}
|
|
|
|
|
|
- reqEmbedArray := make([]string, len(reqEmbed))
|
|
|
|
- for i, s := range reqEmbed {
|
|
|
|
|
|
+ reqEmbedArray := make([]string, len(input))
|
|
|
|
+ for i, s := range input {
|
|
tokens, err := r.Tokenize(c.Request.Context(), s)
|
|
tokens, err := r.Tokenize(c.Request.Context(), s)
|
|
if err != nil {
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
@@ -313,17 +317,17 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|
|
|
|
|
ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
|
|
ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
|
|
if len(tokens) > ctxLen {
|
|
if len(tokens) > ctxLen {
|
|
- if *req.Truncate {
|
|
|
|
- tokens = tokens[:ctxLen]
|
|
|
|
- s, err = r.Detokenize(c.Request.Context(), tokens)
|
|
|
|
- if err != nil {
|
|
|
|
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
|
- return
|
|
|
|
- }
|
|
|
|
- } else {
|
|
|
|
|
|
+ if !truncate {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"})
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"})
|
|
return
|
|
return
|
|
}
|
|
}
|
|
|
|
+
|
|
|
|
+ tokens = tokens[:ctxLen]
|
|
|
|
+ s, err = r.Detokenize(c.Request.Context(), tokens)
|
|
|
|
+ if err != nil {
|
|
|
|
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
|
+ return
|
|
|
|
+ }
|
|
}
|
|
}
|
|
|
|
|
|
reqEmbedArray[i] = s
|
|
reqEmbedArray[i] = s
|
|
@@ -331,7 +335,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|
embeddings, err := r.Embed(c.Request.Context(), reqEmbedArray)
|
|
embeddings, err := r.Embed(c.Request.Context(), reqEmbedArray)
|
|
|
|
|
|
if err != nil {
|
|
if err != nil {
|
|
- slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
|
|
|
|
|
|
+ slog.Error("embedding generation failed", "error", err)
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
|
|
return
|
|
return
|
|
}
|
|
}
|
|
@@ -1030,7 +1034,7 @@ func (s *Server) GenerateRoutes() http.Handler {
|
|
r.POST("/api/generate", s.GenerateHandler)
|
|
r.POST("/api/generate", s.GenerateHandler)
|
|
r.POST("/api/chat", s.ChatHandler)
|
|
r.POST("/api/chat", s.ChatHandler)
|
|
r.POST("/api/embed", s.EmbedHandler)
|
|
r.POST("/api/embed", s.EmbedHandler)
|
|
- r.POST("/api/embeddings", s.EmbeddingsHandler) // legacy
|
|
|
|
|
|
+ r.POST("/api/embeddings", s.EmbeddingsHandler)
|
|
r.POST("/api/create", s.CreateModelHandler)
|
|
r.POST("/api/create", s.CreateModelHandler)
|
|
r.POST("/api/push", s.PushModelHandler)
|
|
r.POST("/api/push", s.PushModelHandler)
|
|
r.POST("/api/copy", s.CopyModelHandler)
|
|
r.POST("/api/copy", s.CopyModelHandler)
|