Roy Han 10 months ago
parent
commit
c406fa7a4c
6 changed files with 141 additions and 40 deletions
  1. api/client.go (+10 -1)
  2. api/types.go (+21 -0)
  3. llm/ext_server/server.cpp (+12 -4)
  4. llm/server.go (+5 -5)
  5. server/routes.go (+91 -28)
  6. server/sched_test.go (+2 -2)

+ 10 - 1
api/client.go

@@ -347,7 +347,16 @@ func (c *Client) Heartbeat(ctx context.Context) error {
 	return nil
 }
 
-// Embeddings generates embeddings from a model.
+// Embed generates embeddings from a model.
+func (c *Client) Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) {
+	var resp EmbedResponse
+	if err := c.do(ctx, http.MethodPost, "/api/embed", req, &resp); err != nil {
+		return nil, err
+	}
+	return &resp, nil
+}
+
+// Embeddings generates embeddings from a model. (Legacy)
 func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) {
 	var resp EmbeddingResponse
 	if err := c.do(ctx, http.MethodPost, "/api/embeddings", req, &resp); err != nil {

+ 21 - 0
api/types.go

@@ -204,6 +204,22 @@ func (b *TriState) MarshalJSON() ([]byte, error) {
 	return json.Marshal(v)
 }
 
+// EmbedRequest is the request passed to [Client.Embed].
+type EmbedRequest struct {
+	// Model is the model name.
+	Model string `json:"model"`
+
+	// Input is the input to embed.
+	Input any `json:"input,omitempty"`
+
+	// KeepAlive controls how long the model will stay loaded in memory following
+	// this request.
+	KeepAlive *Duration `json:"keep_alive,omitempty"`
+
+	// Options lists model-specific options.
+	Options map[string]interface{} `json:"options"`
+}
+
 // EmbeddingRequest is the request passed to [Client.Embeddings].
 type EmbeddingRequest struct {
 	// Model is the model name.
@@ -223,6 +239,11 @@ type EmbeddingRequest struct {
 	Options map[string]interface{} `json:"options"`
 }
 
+// EmbedResponse is the response from [Client.Embed].
+type EmbedResponse struct {
+	Embeddings [][]float64 `json:"embeddings,omitempty"`
+}
+
 // EmbeddingResponse is the response from [Client.Embeddings].
 type EmbeddingResponse struct {
 	Embedding      []float64   `json:"embedding,omitempty"`
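
For illustration, how the new request type marshals (a sketch against this commit's structs; note that Options carries no omitempty tag, so an unset map serializes as null):

package main

import (
	"encoding/json"
	"fmt"

	"github.com/ollama/ollama/api"
)

func main() {
	// Input is declared as `any`, so both shapes below are valid.
	single, _ := json.Marshal(api.EmbedRequest{Model: "m", Input: "hello"})
	fmt.Println(string(single)) // {"model":"m","input":"hello","options":null}

	batch, _ := json.Marshal(api.EmbedRequest{Model: "m", Input: []string{"a", "b"}})
	fmt.Println(string(batch)) // {"model":"m","input":["a","b"],"options":null}
}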

+ 12 - 4
llm/ext_server/server.cpp

@@ -3156,14 +3156,22 @@ int main(int argc, char **argv) {
             {
                 res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
                 const json body = json::parse(req.body);
-                json prompt;
-                if (body.count("content") != 0)
+                json input;
+                if (body.count("input") != 0)
+                {
+                    input = body["input"];
+                }
+                else if (body.count("content") != 0)
                 {
-                    prompt = body["content"];
+                    input = body["content"];
                 }
                 else
                 {
-                    prompt = "";
+                    input = "";
                 }
 
                 // create and queue the task
@@ -3171,7 +3179,7 @@ int main(int argc, char **argv) {
                 {
                     const int id_task = llama.queue_tasks.get_new_id();
                     llama.queue_results.add_waiting_task_id(id_task);
-                    llama.request_completion(id_task, {{"prompt", prompt}}, true, -1);
+                    llama.request_completion(id_task, {{"prompt", input}}, true, -1);
 
                     // get the result
                     task_result result = llama.queue_results.recv(id_task);
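
A sketch of the two request bodies the patched route now accepts, with "input" taking precedence over the legacy "content" key (the /embedding path and the port are assumptions about a locally running runner):

package main

import (
	"fmt"
	"io"
	"net/http"
	"strings"
)

func main() {
	for _, body := range []string{
		`{"input": "why is the sky blue?"}`,   // new key, preferred
		`{"content": "why is the sky blue?"}`, // legacy key, fallback
	} {
		resp, err := http.Post("http://127.0.0.1:8080/embedding", "application/json", strings.NewReader(body))
		if err != nil {
			fmt.Println(err)
			continue
		}
		b, _ := io.ReadAll(resp.Body)
		resp.Body.Close()
		fmt.Println(string(b)) // {"embedding": [...]}
	}
}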

+ 5 - 5
llm/server.go

@@ -33,7 +33,7 @@ type LlamaServer interface {
 	Ping(ctx context.Context) error
 	WaitUntilRunning(ctx context.Context) error
 	Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
-	Embedding(ctx context.Context, prompt []string) ([][]float64, error)
+	Embedding(ctx context.Context, prompt string) ([]float64, error)
 	Tokenize(ctx context.Context, content string) ([]int, error)
 	Detokenize(ctx context.Context, tokens []int) (string, error)
 	Close() error
@@ -842,14 +842,14 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 }
 
 type EmbeddingRequest struct {
-	Content []string `json:"content"`
+	Content string `json:"content"`
 }
 
 type EmbeddingResponse struct {
-	Embedding [][]float64 `json:"embedding"`
+	Embedding []float64 `json:"embedding"`
 }
 
-func (s *llmServer) Embedding(ctx context.Context, prompt []string) ([][]float64, error) {
+func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
 	if err := s.sem.Acquire(ctx, 1); err != nil {
 		slog.Error("Failed to acquire semaphore", "error", err)
 		return nil, err
@@ -864,7 +864,7 @@ func (s *llmServer) Embedding(ctx context.Context, prompt []string) ([][]float64
 		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
 	}
 
 	data, err := json.Marshal(EmbeddingRequest{Content: prompt})
 	if err != nil {
 		return nil, fmt.Errorf("error marshaling embed data: %w", err)
 	}
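
A small round-trip of the runner-facing types above (local copies of the structs, for illustration only):

package main

import (
	"encoding/json"
	"fmt"
)

type EmbeddingRequest struct {
	Content string `json:"content"`
}

type EmbeddingResponse struct {
	Embedding []float64 `json:"embedding"`
}

func main() {
	data, _ := json.Marshal(EmbeddingRequest{Content: "hello"})
	fmt.Println(string(data)) // {"content":"hello"}

	var resp EmbeddingResponse
	_ = json.Unmarshal([]byte(`{"embedding":[0.1,0.2,0.3]}`), &resp)
	fmt.Println(len(resp.Embedding)) // 3
}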

+ 91 - 28
server/routes.go

@@ -339,8 +339,8 @@ func getDefaultSessionDuration() time.Duration {
 	return defaultSessionDuration
 }
 
-func (s *Server) EmbeddingsHandler(c *gin.Context) {
-	var req api.EmbeddingRequest
+func (s *Server) EmbedHandler(c *gin.Context) {
+	var req api.EmbedRequest
 	err := c.ShouldBindJSON(&req)
 	switch {
 	case errors.Is(err, io.EOF):
@@ -389,39 +389,101 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 		return
 	}
 
-	switch {
-	// single embedding
-	case len(req.Prompt) > 0:
-		slog.Info("embedding request", "prompt", req.Prompt)
-		embeddings, err := runner.llama.Embedding(c.Request.Context(), []string{req.Prompt})
-		if err != nil {
-			slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
-			c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
-			return
-		}
-
-		resp := api.EmbeddingResponse{Embedding: embeddings[0]}
-		c.JSON(http.StatusOK, resp)
-	// batch embeddings
-	case len(req.PromptBatch) > 0:
-		embeddings, err := runner.llama.Embedding(c.Request.Context(), req.PromptBatch)
-		if err != nil {
-			slog.Info(fmt.Sprintf("batch embedding generation failed: %v", err))
-			c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
-			return
-		}
-
-		resp := api.EmbeddingResponse{EmbeddingBatch: embeddings}
-		c.JSON(http.StatusOK, resp)
-
-	// empty prompt loads the model
-	default:
-		if req.PromptBatch != nil {
-			c.JSON(http.StatusOK, api.EmbeddingResponse{EmbeddingBatch: [][]float64{}})
-		} else {
-			c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: []float64{}})
-		}
-	}
+	// req.Input is decoded from JSON into an `any`, so a list of strings
+	// arrives as []any (never []string); normalize it to []string first.
+	var input []string
+	switch v := req.Input.(type) {
+	case string:
+		if v != "" {
+			input = append(input, v)
+		}
+	case []any:
+		for _, item := range v {
+			s, ok := item.(string)
+			if !ok {
+				c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
+				return
+			}
+			input = append(input, s)
+		}
+	default:
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
+		return
+	}
+
+	// an empty input still loads the model and returns an empty list
+	embeddings := make([][]float64, 0, len(input))
+	for _, prompt := range input {
+		embedding, err := runner.llama.Embedding(c.Request.Context(), prompt)
+		if err != nil {
+			slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
+			c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
+			return
+		}
+		embeddings = append(embeddings, embedding)
+	}
+
+	resp := api.EmbedResponse{Embeddings: embeddings}
+	c.JSON(http.StatusOK, resp)
+}
+
+func (s *Server) EmbeddingsHandler(c *gin.Context) {
+	var req api.EmbeddingRequest
+	err := c.ShouldBindJSON(&req)
+	switch {
+	case errors.Is(err, io.EOF):
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
+		return
+	case err != nil:
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
+	if req.Model == "" {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
+		return
+	}
+
+	model, err := GetModel(req.Model)
+	if err != nil {
+		var pErr *fs.PathError
+		if errors.As(err, &pErr) {
+			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
+			return
 		}
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	opts, err := modelOptions(model, req.Options)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	var sessionDuration time.Duration
+	if req.KeepAlive == nil {
+		sessionDuration = getDefaultSessionDuration()
+	} else {
+		sessionDuration = req.KeepAlive.Duration
 	}
+
+	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
+	var runner *runnerRef
+	select {
+	case runner = <-rCh:
+	case err = <-eCh:
+		handleErrorResponse(c, err)
+		return
+	}
+
+	// an empty request loads the model
+	if req.Prompt == "" {
+		c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: []float64{}})
+		return
+	}
+	embedding, err := runner.llama.Embedding(c.Request.Context(), req.Prompt)
+	if err != nil {
+		slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
+		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
+		return
+	}
+
+	resp := api.EmbeddingResponse{
+		Embedding: embedding,
+	}
+	c.JSON(http.StatusOK, resp)
 }
 
 func (s *Server) PullModelHandler(c *gin.Context) {
@@ -1005,7 +1067,8 @@ func (s *Server) GenerateRoutes() http.Handler {
 	r.POST("/api/pull", s.PullModelHandler)
 	r.POST("/api/generate", s.GenerateHandler)
 	r.POST("/api/chat", s.ChatHandler)
-	r.POST("/api/embeddings", s.EmbeddingsHandler)
+	r.POST("/api/embed", s.EmbedHandler)
+	r.POST("/api/embeddings", s.EmbeddingsHandler) // legacy
 	r.POST("/api/create", s.CreateModelHandler)
 	r.POST("/api/push", s.PushModelHandler)
 	r.POST("/api/copy", s.CopyModelHandler)

+ 2 - 2
server/sched_test.go

@@ -608,7 +608,7 @@ type mockLlm struct {
 	pingResp           error
 	waitResp           error
 	completionResp     error
-	embeddingResp      [][]float64
+	embeddingResp      []float64
 	embeddingRespErr   error
 	tokenizeResp       []int
 	tokenizeRespErr    error
@@ -626,7 +626,7 @@ func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitRes
 func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
 	return s.completionResp
 }
-func (s *mockLlm) Embedding(ctx context.Context, prompt []string) ([][]float64, error) {
+func (s *mockLlm) Embedding(ctx context.Context, prompt string) ([]float64, error) {
 	return s.embeddingResp, s.embeddingRespErr
 }
 func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
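
Since mockLlm must track the interface change above, a compile-time assertion (a hypothetical addition to sched_test.go, not part of this commit) would catch future signature drift within package server:

// Statically assert that *mockLlm still satisfies llm.LlamaServer.
var _ llm.LlamaServer = (*mockLlm)(nil)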