Explorar o código

Initial Batch Embedding

Roy Han hai 10 meses
pai
achega
c22d54895a
Modificáronse 4 ficheiros con 93 adicións e 10 borrados
  1. 3 2
      api/types.go
  2. 49 2
      llm/server.go
  3. 38 4
      server/routes.go
  4. 3 2
      server/sched_test.go

+ 3 - 2
api/types.go

@@ -210,7 +210,8 @@ type EmbeddingRequest struct {
 	Model string `json:"model"`
 	Model string `json:"model"`
 
 
 	// Prompt is the textual prompt to embed.
 	// Prompt is the textual prompt to embed.
-	Prompt string `json:"prompt"`
+	// Prompt string `json:"prompt"`
+	Prompt interface{} `json:"prompt"`
 
 
 	// KeepAlive controls how long the model will stay loaded in memory following
 	// KeepAlive controls how long the model will stay loaded in memory following
 	// this request.
 	// this request.
@@ -222,7 +223,7 @@ type EmbeddingRequest struct {
 
 
 // EmbeddingResponse is the response from [Client.Embeddings].
 // EmbeddingResponse is the response from [Client.Embeddings].
 type EmbeddingResponse struct {
 type EmbeddingResponse struct {
-	Embedding []float64 `json:"embedding"`
+	Embedding [][]float64 `json:"embedding"`
 }
 }
 
 
 // CreateRequest is the request passed to [Client.Create].
 // CreateRequest is the request passed to [Client.Create].

+ 49 - 2
llm/server.go

@@ -19,6 +19,7 @@ import (
 	"runtime"
 	"runtime"
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
+	"sync"
 	"time"
 	"time"
 
 
 	"golang.org/x/sync/semaphore"
 	"golang.org/x/sync/semaphore"
@@ -33,7 +34,7 @@ type LlamaServer interface {
 	Ping(ctx context.Context) error
 	Ping(ctx context.Context) error
 	WaitUntilRunning(ctx context.Context) error
 	WaitUntilRunning(ctx context.Context) error
 	Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
 	Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
-	Embedding(ctx context.Context, prompt string) ([]float64, error)
+	Embedding(ctx context.Context, prompts interface{}) ([][]float64, error)
 	Tokenize(ctx context.Context, content string) ([]int, error)
 	Tokenize(ctx context.Context, content string) ([]int, error)
 	Detokenize(ctx context.Context, tokens []int) (string, error)
 	Detokenize(ctx context.Context, tokens []int) (string, error)
 	Close() error
 	Close() error
@@ -849,7 +850,7 @@ type EmbeddingResponse struct {
 	Embedding []float64 `json:"embedding"`
 	Embedding []float64 `json:"embedding"`
 }
 }
 
 
-func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
+func (s *llmServer) Embedding(ctx context.Context, prompts interface{}) ([][]float64, error) {
 	if err := s.sem.Acquire(ctx, 1); err != nil {
 	if err := s.sem.Acquire(ctx, 1); err != nil {
 		slog.Error("Failed to acquire semaphore", "error", err)
 		slog.Error("Failed to acquire semaphore", "error", err)
 		return nil, err
 		return nil, err
@@ -864,6 +865,52 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
 		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
 		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
 	}
 	}
 
 
+	switch prompts := prompts.(type) {
+	case string:
+		// single prompt
+		embedding, err := s.EmbeddingSingle(ctx, prompts)
+		if err != nil {
+			return nil, err
+		}
+		return [][]float64{embedding}, nil
+	case []string:
+		// multiple prompts
+		errCh := make(chan error, 1)
+		successCh := make(chan [][]float64, 1)
+		num_prompts := len(prompts)
+		embeddings := make([][]float64, num_prompts)
+		var wg sync.WaitGroup
+		wg.Add(num_prompts)
+		for i, p := range prompts {
+			go func(i int, p string) {
+				defer wg.Done()
+				slog.Info("embedding", "prompt", p)
+				embedding, err := s.EmbeddingSingle(ctx, p)
+				if err != nil {
+					errCh <- err
+					return
+				}
+				embeddings[i] = embedding
+			}(i, p)
+		}
+
+		go func() {
+			wg.Wait()
+			successCh <- embeddings
+		}()
+
+		select {
+		case err := <-errCh:
+			return nil, err
+		case embeddings := <-successCh:
+			return embeddings, nil
+		}
+	default:
+		return nil, fmt.Errorf("unsupported prompt type: %T", prompts)
+	}
+}
+
+func (s *llmServer) EmbeddingSingle(ctx context.Context, prompt string) ([]float64, error) {
 	data, err := json.Marshal(TokenizeRequest{Content: prompt})
 	data, err := json.Marshal(TokenizeRequest{Content: prompt})
 	if err != nil {
 	if err != nil {
 		return nil, fmt.Errorf("error marshaling embed data: %w", err)
 		return nil, fmt.Errorf("error marshaling embed data: %w", err)

+ 38 - 4
server/routes.go

@@ -356,6 +356,27 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 		return
 		return
 	}
 	}
 
 
+	// if we want to stick with the one prompt format, we can use custom unmarshalling
+	// otherwise just have separate fields
+
+	switch req.Prompt.(type) {
+	case string:
+	case []interface{}:
+		prompts := make([]string, len(req.Prompt.([]interface{})))
+		for i, p := range req.Prompt.([]interface{}) {
+			if str, ok := p.(string); ok {
+				prompts[i] = str
+			} else {
+				c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "prompt must be a string or list of strings"})
+				return
+			}
+		}
+		req.Prompt = prompts
+	default:
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "prompt must be a string or list of strings"})
+		return
+	}
+
 	model, err := GetModel(req.Model)
 	model, err := GetModel(req.Model)
 	if err != nil {
 	if err != nil {
 		var pErr *fs.PathError
 		var pErr *fs.PathError
@@ -389,13 +410,26 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 		return
 		return
 	}
 	}
 
 
-	// an empty request loads the model
-	if req.Prompt == "" {
-		c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: []float64{}})
+	var embedding [][]float64
+
+	switch prompt := req.Prompt.(type) {
+	case string:
+		if prompt == "" {
+			c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: [][]float64{}})
+			return
+		}
+		embedding, err = runner.llama.Embedding(c.Request.Context(), prompt)
+	case []string:
+		if len(prompt) == 0 {
+			c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: [][]float64{}})
+			return
+		}
+		embedding, err = runner.llama.Embedding(c.Request.Context(), prompt)
+	default:
+		c.AbortWithStatus(http.StatusInternalServerError)
 		return
 		return
 	}
 	}
 
 
-	embedding, err := runner.llama.Embedding(c.Request.Context(), req.Prompt)
 	if err != nil {
 	if err != nil {
 		slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
 		slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
 		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
 		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})

+ 3 - 2
server/sched_test.go

@@ -608,7 +608,7 @@ type mockLlm struct {
 	pingResp           error
 	pingResp           error
 	waitResp           error
 	waitResp           error
 	completionResp     error
 	completionResp     error
-	embeddingResp      []float64
+	embeddingResp      [][]float64
 	embeddingRespErr   error
 	embeddingRespErr   error
 	tokenizeResp       []int
 	tokenizeResp       []int
 	tokenizeRespErr    error
 	tokenizeRespErr    error
@@ -626,7 +626,8 @@ func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitRes
 func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
 func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
 	return s.completionResp
 	return s.completionResp
 }
 }
-func (s *mockLlm) Embedding(ctx context.Context, prompt string) ([]float64, error) {
+
+func (s *mockLlm) Embedding(ctx context.Context, prompts interface{}) ([][]float64, error) {
 	return s.embeddingResp, s.embeddingRespErr
 	return s.embeddingResp, s.embeddingRespErr
 }
 }
 func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
 func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {