瀏覽代碼

image embeddings

Roy Han 9 月之前
父節點
當前提交
eb7cc2d1ce
共有 5 個文件被更改,包括 56 次插入和 26 次刪除
  1. 4 0
      api/types.go
  2. 11 1
      llm/ext_server/server.cpp
  3. 5 4
      llm/server.go
  4. 35 20
      server/routes.go
  5. 1 1
      server/sched_test.go

+ 4 - 0
api/types.go

@@ -187,6 +187,10 @@ type EmbedRequest struct {
 
 	Truncate *bool `json:"truncate,omitempty"`
 
+	// Images is an optional list of base64-encoded images accompanying this
+	// request, for multimodal models.
+	Images []ImageData `json:"images,omitempty"`
+
 	// Options lists model-specific options.
 	Options map[string]interface{} `json:"options"`
 }

+ 11 - 1
llm/ext_server/server.cpp

@@ -3192,12 +3192,22 @@ int main(int argc, char **argv) {
                     prompt = prompt[0];
                 }
 
+                json image_data;
+                if (body.count("image_data") != 0)
+                {
+                    image_data = body["image_data"];
+                }
+                else {
+                    image_data = "";
+                }
+                // TODO: prompt needs to represent the image data
+
                 // create and queue the task
                 json responses;
                 {
                     const int id_task = llama.queue_tasks.get_new_id();
                     llama.queue_results.add_waiting_task_id(id_task);
-                    llama.request_completion(id_task, {{"prompt", prompt}}, true, -1);
+                    llama.request_completion(id_task, { {"prompt", prompt}, {"image_data", image_data} }, true, -1);
 
                     // get the result
                     task_result result = llama.queue_results.recv(id_task);

+ 5 - 4
llm/server.go

@@ -33,7 +33,7 @@ type LlamaServer interface {
 	Ping(ctx context.Context) error
 	WaitUntilRunning(ctx context.Context) error
 	Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
-	Embed(ctx context.Context, input []string) ([][]float32, error)
+	Embed(ctx context.Context, input []string, images []ImageData) ([][]float32, error)
 	Tokenize(ctx context.Context, content string) ([]int, error)
 	Detokenize(ctx context.Context, tokens []int) (string, error)
 	Close() error
@@ -860,14 +860,15 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 }
 
 type EmbedRequest struct {
-	Content []string `json:"content"`
+	Content []string    `json:"content"`
+	Images  []ImageData `json:"image_data"`
 }
 
 type EmbedResponse struct {
 	Embedding [][]float32 `json:"embedding"`
 }
 
-func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float32, error) {
+func (s *llmServer) Embed(ctx context.Context, input []string, images []ImageData) ([][]float32, error) {
 	if err := s.sem.Acquire(ctx, 1); err != nil {
 		slog.Error("Failed to acquire semaphore", "error", err)
 		return nil, err
@@ -882,7 +883,7 @@ func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float32, err
 		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
 	}
 
-	data, err := json.Marshal(EmbedRequest{Content: input})
+	data, err := json.Marshal(EmbedRequest{Content: input, Images: images})
 	if err != nil {
 		return nil, fmt.Errorf("error marshaling embed data: %w", err)
 	}

+ 35 - 20
server/routes.go

@@ -265,29 +265,38 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		truncate = false
 	}
 
+	inputCheck := true
+
+	if req.Images != nil {
+		inputCheck = false
+	}
+
 	var input []string
 
-	switch i := req.Input.(type) {
-	case string:
-		if len(i) > 0 {
-			input = append(input, i)
-		}
-	case []any:
-		for _, v := range i {
-			if _, ok := v.(string); !ok {
-				c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
-				return
+	if inputCheck {
+
+		switch i := req.Input.(type) {
+		case string:
+			if len(i) > 0 {
+				input = append(input, i)
 			}
-			input = append(input, v.(string))
+		case []any:
+			for _, v := range i {
+				if _, ok := v.(string); !ok {
+					c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
+					return
+				}
+				input = append(input, v.(string))
+			}
+		default:
+			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
+			return
 		}
-	default:
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
-		return
-	}
 
-	if len(input) == 0 {
-		c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
-		return
+		if len(input) == 0 {
+			c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
+			return
+		}
 	}
 
 	r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
@@ -326,7 +335,13 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 
 		input[i] = s
 	}
-	embeddings, err := r.Embed(c.Request.Context(), input)
+
+	images := make([]llm.ImageData, len(req.Images))
+	for i := range req.Images {
+		images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
+	}
+
+	embeddings, err := r.Embed(c.Request.Context(), input, images)
 
 	if err != nil {
 		slog.Error("embedding generation failed", "error", err)
@@ -384,7 +399,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 		return
 	}
 
-	embeddings, err := r.Embed(c.Request.Context(), []string{req.Prompt})
+	embeddings, err := r.Embed(c.Request.Context(), []string{req.Prompt}, nil)
 
 	if err != nil {
 		slog.Info(fmt.Sprintf("embedding generation failed: %v", err))

+ 1 - 1
server/sched_test.go

@@ -660,7 +660,7 @@ func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitRes
 func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
 	return s.completionResp
 }
-func (s *mockLlm) Embed(ctx context.Context, input []string) ([][]float32, error) {
+func (s *mockLlm) Embed(ctx context.Context, input []string, images []llm.ImageData) ([][]float32, error) {
 	return s.embedResp, s.embedRespErr
 }
 func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {