|
@@ -8,7 +8,6 @@ import (
|
|
"errors"
|
|
"errors"
|
|
"fmt"
|
|
"fmt"
|
|
"io"
|
|
"io"
|
|
- "io/fs"
|
|
|
|
"log/slog"
|
|
"log/slog"
|
|
"math"
|
|
"math"
|
|
"net"
|
|
"net"
|
|
@@ -260,33 +259,11 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|
return
|
|
return
|
|
}
|
|
}
|
|
|
|
|
|
- if req.Model == "" {
|
|
|
|
- c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
|
|
|
- return
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
if req.Truncate == nil {
|
|
if req.Truncate == nil {
|
|
truncate := true
|
|
truncate := true
|
|
req.Truncate = &truncate
|
|
req.Truncate = &truncate
|
|
}
|
|
}
|
|
|
|
|
|
- model, err := GetModel(req.Model)
|
|
|
|
- if err != nil {
|
|
|
|
- var pErr *fs.PathError
|
|
|
|
- if errors.As(err, &pErr) {
|
|
|
|
- c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
|
|
|
|
- return
|
|
|
|
- }
|
|
|
|
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
|
- return
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- opts, err := modelOptions(model, req.Options)
|
|
|
|
- if err != nil {
|
|
|
|
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
|
- return
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
reqEmbed := []string{}
|
|
reqEmbed := []string{}
|
|
|
|
|
|
switch embeddings := req.Input.(type) {
|
|
switch embeddings := req.Input.(type) {
|
|
@@ -314,41 +291,40 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|
return
|
|
return
|
|
}
|
|
}
|
|
|
|
|
|
- r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
|
|
|
|
|
|
+ r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
|
|
if err != nil {
|
|
if err != nil {
|
|
handleScheduleError(c, req.Model, err)
|
|
handleScheduleError(c, req.Model, err)
|
|
return
|
|
return
|
|
}
|
|
}
|
|
|
|
|
|
- kvData, err := getKVData(model.ModelPath, false)
|
|
|
|
|
|
+ kvData, err := getKVData(m.ModelPath, false)
|
|
if err != nil {
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
return
|
|
}
|
|
}
|
|
|
|
|
|
- checkFit := func(s string, truncate bool) (string, error) {
|
|
|
|
- tokens, err := r.Tokenize(c.Request.Context(), s)
|
|
|
|
- if err != nil {
|
|
|
|
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
|
- return "", err
|
|
|
|
- }
|
|
|
|
|
|
+ reqEmbedArray := make([]string, len(reqEmbed))
|
|
|
|
+ for i, v := range reqEmbed {
|
|
|
|
+ s, err := func(v string, truncate bool) (string, error) {
|
|
|
|
+ tokens, err := r.Tokenize(c.Request.Context(), v)
|
|
|
|
+ if err != nil {
|
|
|
|
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
|
+ return "", err
|
|
|
|
+ }
|
|
|
|
|
|
- ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
|
|
|
|
- if len(tokens) > ctxLen {
|
|
|
|
- if truncate {
|
|
|
|
- tokens = tokens[:ctxLen]
|
|
|
|
- return r.Detokenize(c.Request.Context(), tokens)
|
|
|
|
- } else {
|
|
|
|
- return "", fmt.Errorf("input length exceeds maximum context length")
|
|
|
|
|
|
+ ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
|
|
|
|
+ if len(tokens) > ctxLen {
|
|
|
|
+ if truncate {
|
|
|
|
+ tokens = tokens[:ctxLen]
|
|
|
|
+ return r.Detokenize(c.Request.Context(), tokens)
|
|
|
|
+ } else {
|
|
|
|
+ return "", fmt.Errorf("input length exceeds maximum context length")
|
|
|
|
+ }
|
|
}
|
|
}
|
|
- }
|
|
|
|
|
|
|
|
- return s, nil
|
|
|
|
- }
|
|
|
|
|
|
+ return v, nil
|
|
|
|
+ }(v, *req.Truncate)
|
|
|
|
|
|
- reqEmbedArray := make([]string, len(reqEmbed))
|
|
|
|
- for i, v := range reqEmbed {
|
|
|
|
- s, err := checkFit(v, *req.Truncate)
|
|
|
|
if err != nil {
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
return
|