|
@@ -8,6 +8,7 @@ import (
|
|
"errors"
|
|
"errors"
|
|
"fmt"
|
|
"fmt"
|
|
"io"
|
|
"io"
|
|
|
|
+ "io/fs"
|
|
"log/slog"
|
|
"log/slog"
|
|
"net"
|
|
"net"
|
|
"net/http"
|
|
"net/http"
|
|
@@ -309,17 +310,14 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|
return
|
|
return
|
|
}
|
|
}
|
|
|
|
|
|
- rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
|
|
|
|
- var runner *runnerRef
|
|
|
|
- select {
|
|
|
|
- case runner = <-rCh:
|
|
|
|
- case err = <-eCh:
|
|
|
|
- handleErrorResponse(c, err)
|
|
|
|
|
|
+ r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
|
|
|
|
+ if err != nil {
|
|
|
|
+ handleScheduleError(c, req.Model, err)
|
|
return
|
|
return
|
|
}
|
|
}
|
|
|
|
|
|
checkFit := func(s string, truncate bool) (string, error) {
|
|
checkFit := func(s string, truncate bool) (string, error) {
|
|
- tokens, err := runner.llama.Tokenize(c.Request.Context(), s)
|
|
|
|
|
|
+ tokens, err := r.Tokenize(c.Request.Context(), s)
|
|
if err != nil {
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return "", err
|
|
return "", err
|
|
@@ -328,7 +326,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|
if len(tokens) > opts.NumCtx {
|
|
if len(tokens) > opts.NumCtx {
|
|
if truncate {
|
|
if truncate {
|
|
tokens = tokens[:opts.NumCtx]
|
|
tokens = tokens[:opts.NumCtx]
|
|
- return runner.llama.Detokenize(c.Request.Context(), tokens)
|
|
|
|
|
|
+ return r.Detokenize(c.Request.Context(), tokens)
|
|
} else {
|
|
} else {
|
|
return "", fmt.Errorf("input length exceeds maximum context length")
|
|
return "", fmt.Errorf("input length exceeds maximum context length")
|
|
}
|
|
}
|
|
@@ -346,7 +344,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
return
|
|
}
|
|
}
|
|
- embeddings, err = runner.llama.Embed(c.Request.Context(), []string{reqEmbed})
|
|
|
|
|
|
+ embeddings, err = r.Embed(c.Request.Context(), []string{reqEmbed})
|
|
case []any:
|
|
case []any:
|
|
reqEmbedArray := make([]string, len(reqEmbed))
|
|
reqEmbedArray := make([]string, len(reqEmbed))
|
|
for i, v := range reqEmbed {
|
|
for i, v := range reqEmbed {
|
|
@@ -357,7 +355,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|
}
|
|
}
|
|
reqEmbedArray[i] = s
|
|
reqEmbedArray[i] = s
|
|
}
|
|
}
|
|
- embeddings, err = runner.llama.Embed(c.Request.Context(), reqEmbedArray)
|
|
|
|
|
|
+ embeddings, err = r.Embed(c.Request.Context(), reqEmbedArray)
|
|
default:
|
|
default:
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
|
}
|
|
}
|
|
@@ -418,7 +416,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
|
return
|
|
return
|
|
}
|
|
}
|
|
|
|
|
|
- embedding, err := runner.llama.Embed(c.Request.Context(), []string{req.Prompt})
|
|
|
|
|
|
+ embedding, err := r.Embed(c.Request.Context(), []string{req.Prompt})
|
|
|
|
|
|
if err != nil {
|
|
if err != nil {
|
|
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
|
|
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
|