Browse Source

merge conflicts

Roy Han 9 months ago
parent
commit
b686ac144c
1 changed file with 9 additions and 11 deletions
  1. 9 11
      server/routes.go

+ 9 - 11
server/routes.go

@@ -8,6 +8,7 @@ import (
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
+	"io/fs"
 	"log/slog"
 	"log/slog"
 	"net"
 	"net"
 	"net/http"
 	"net/http"
@@ -309,17 +310,14 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		return
 		return
 	}
 	}
 
 
-	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
-	var runner *runnerRef
-	select {
-	case runner = <-rCh:
-	case err = <-eCh:
-		handleErrorResponse(c, err)
+	r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
+	if err != nil {
+		handleScheduleError(c, req.Model, err)
 		return
 		return
 	}
 	}
 
 
 	checkFit := func(s string, truncate bool) (string, error) {
 	checkFit := func(s string, truncate bool) (string, error) {
-		tokens, err := runner.llama.Tokenize(c.Request.Context(), s)
+		tokens, err := r.Tokenize(c.Request.Context(), s)
 		if err != nil {
 		if err != nil {
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 			return "", err
 			return "", err
@@ -328,7 +326,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		if len(tokens) > opts.NumCtx {
 		if len(tokens) > opts.NumCtx {
 			if truncate {
 			if truncate {
 				tokens = tokens[:opts.NumCtx]
 				tokens = tokens[:opts.NumCtx]
-				return runner.llama.Detokenize(c.Request.Context(), tokens)
+				return r.Detokenize(c.Request.Context(), tokens)
 			} else {
 			} else {
 				return "", fmt.Errorf("input length exceeds maximum context length")
 				return "", fmt.Errorf("input length exceeds maximum context length")
 			}
 			}
@@ -346,7 +344,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 			return
 			return
 		}
 		}
-		embeddings, err = runner.llama.Embed(c.Request.Context(), []string{reqEmbed})
+		embeddings, err = r.Embed(c.Request.Context(), []string{reqEmbed})
 	case []any:
 	case []any:
 		reqEmbedArray := make([]string, len(reqEmbed))
 		reqEmbedArray := make([]string, len(reqEmbed))
 		for i, v := range reqEmbed {
 		for i, v := range reqEmbed {
@@ -357,7 +355,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 			}
 			}
 			reqEmbedArray[i] = s
 			reqEmbedArray[i] = s
 		}
 		}
-		embeddings, err = runner.llama.Embed(c.Request.Context(), reqEmbedArray)
+		embeddings, err = r.Embed(c.Request.Context(), reqEmbedArray)
 	default:
 	default:
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
 	}
 	}
@@ -418,7 +416,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 		return
 		return
 	}
 	}
 
 
-	embedding, err := runner.llama.Embed(c.Request.Context(), []string{req.Prompt})
+	embedding, err := r.Embed(c.Request.Context(), []string{req.Prompt})
 
 
 	if err != nil {
 	if err != nil {
 		slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
 		slog.Info(fmt.Sprintf("embedding generation failed: %v", err))