Browse Source

Merge pull request #2296 from ollama/mxyng/img-tags

append image tags to user content
Michael Yang 1 year ago
parent
commit
bfbf2f7cf7
6 changed files with 89 additions and 36 deletions
  1. 3 6
      llm/dyn_ext_server.go
  2. 1 1
      llm/llama.go
  3. 17 8
      server/images.go
  4. 26 7
      server/images_test.go
  5. 40 13
      server/routes.go
  6. 2 1
      server/routes_test.go

+ 3 - 6
llm/dyn_ext_server.go

@@ -161,13 +161,10 @@ func newDynExtServer(library, model string, adapters, projectors []string, opts
 func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
 	resp := newExtServerResp(128)
 	defer freeExtServerResp(resp)
-	var imageData []ImageData
+
 	if len(predict.Images) > 0 {
-		for cnt, i := range predict.Images {
-			imageData = append(imageData, ImageData{Data: i, ID: cnt})
-		}
+		slog.Info(fmt.Sprintf("loaded %d images", len(predict.Images)))
 	}
-	slog.Info(fmt.Sprintf("loaded %d images", len(imageData)))
 
 	request := map[string]any{
 		"prompt":            predict.Prompt,
@@ -189,7 +186,7 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu
 		"penalize_nl":       predict.Options.PenalizeNewline,
 		"seed":              predict.Options.Seed,
 		"stop":              predict.Options.Stop,
-		"image_data":        imageData,
+		"image_data":        predict.Images,
 		"cache_prompt":      true,
 	}
 

+ 1 - 1
llm/llama.go

@@ -62,7 +62,7 @@ const maxRetries = 3
 type PredictOpts struct {
 	Prompt  string
 	Format  string
-	Images  []api.ImageData
+	Images  []ImageData
 	Options api.Options
 }
 

+ 17 - 8
server/images.go

@@ -63,6 +63,7 @@ type PromptVars struct {
 	Prompt   string
 	Response string
 	First    bool
+	Images   []llm.ImageData
 }
 
 // extractParts extracts the parts of the template before and after the {{.Response}} node.
@@ -147,15 +148,13 @@ func (m *Model) PostResponseTemplate(p PromptVars) (string, error) {
 }
 
 type ChatHistory struct {
-	Prompts       []PromptVars
-	CurrentImages []api.ImageData
-	LastSystem    string
+	Prompts    []PromptVars
+	LastSystem string
 }
 
 // ChatPrompts returns a list of formatted chat prompts from a list of messages
 func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
 	// build the prompt from the list of messages
-	var currentImages []api.ImageData
 	lastSystem := m.System
 	currentVars := PromptVars{
 		First:  true,
@@ -163,6 +162,7 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
 	}
 
 	prompts := []PromptVars{}
+	var images []llm.ImageData
 
 	for _, msg := range msgs {
 		switch strings.ToLower(msg.Role) {
@@ -179,8 +179,18 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
 				prompts = append(prompts, currentVars)
 				currentVars = PromptVars{}
 			}
+
 			currentVars.Prompt = msg.Content
-			currentImages = msg.Images
+			for i := range msg.Images {
+				id := len(images) + i
+				currentVars.Prompt += fmt.Sprintf(" [img-%d]", id)
+				currentVars.Images = append(currentVars.Images, llm.ImageData{
+					ID:   id,
+					Data: msg.Images[i],
+				})
+			}
+
+			images = append(images, currentVars.Images...)
 		case "assistant":
 			currentVars.Response = msg.Content
 			prompts = append(prompts, currentVars)
@@ -196,9 +206,8 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
 	}
 
 	return &ChatHistory{
-		Prompts:       prompts,
-		CurrentImages: currentImages,
-		LastSystem:    lastSystem,
+		Prompts:    prompts,
+		LastSystem: lastSystem,
 	}, nil
 }
 

+ 26 - 7
server/images_test.go

@@ -238,18 +238,37 @@ func chatHistoryEqual(a, b ChatHistory) bool {
 	if len(a.Prompts) != len(b.Prompts) {
 		return false
 	}
-	if len(a.CurrentImages) != len(b.CurrentImages) {
-		return false
-	}
 	for i, v := range a.Prompts {
-		if v != b.Prompts[i] {
+
+		if v.First != b.Prompts[i].First {
 			return false
 		}
-	}
-	for i, v := range a.CurrentImages {
-		if !bytes.Equal(v, b.CurrentImages[i]) {
+
+		if v.Response != b.Prompts[i].Response {
 			return false
 		}
+
+		if v.Prompt != b.Prompts[i].Prompt {
+			return false
+		}
+
+		if v.System != b.Prompts[i].System {
+			return false
+		}
+
+		if len(v.Images) != len(b.Prompts[i].Images) {
+			return false
+		}
+
+		for j, img := range v.Images {
+			if img.ID != b.Prompts[i].Images[j].ID {
+				return false
+			}
+
+			if !bytes.Equal(img.Data, b.Prompts[i].Images[j].Data) {
+				return false
+			}
+		}
 	}
 	return a.LastSystem == b.LastSystem
 }

+ 40 - 13
server/routes.go

@@ -244,6 +244,10 @@ func GenerateHandler(c *gin.Context) {
 			promptVars.System = model.System
 		}
 
+		for i := range req.Images {
+			promptVars.Prompt += fmt.Sprintf(" [img-%d]", i)
+		}
+
 		p, err := model.PreResponsePrompt(promptVars)
 		if err != nil {
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -308,11 +312,19 @@ func GenerateHandler(c *gin.Context) {
 			ch <- resp
 		}
 
+		var images []llm.ImageData
+		for i := range req.Images {
+			images = append(images, llm.ImageData{
+				ID:   i,
+				Data: req.Images[i],
+			})
+		}
+
 		// Start prediction
 		predictReq := llm.PredictOpts{
 			Prompt:  prompt,
 			Format:  req.Format,
-			Images:  req.Images,
+			Images:  images,
 			Options: opts,
 		}
 		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
@@ -1139,7 +1151,8 @@ func ChatHandler(c *gin.Context) {
 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
-	prompt, err := trimmedPrompt(c.Request.Context(), chat, model)
+
+	prompt, images, err := trimmedPrompt(c.Request.Context(), chat, model)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
@@ -1182,7 +1195,7 @@ func ChatHandler(c *gin.Context) {
 		predictReq := llm.PredictOpts{
 			Prompt:  prompt,
 			Format:  req.Format,
-			Images:  chat.CurrentImages,
+			Images:  images,
 			Options: opts,
 		}
 		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
@@ -1229,34 +1242,47 @@ type promptInfo struct {
 
 // trimmedPrompt builds a prompt to send to a running model. It ensures the prompt fits within the max context length,
 // while preserving the most recent system message.
-func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, error) {
+func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, []llm.ImageData, error) {
 	if len(chat.Prompts) == 0 {
-		return "", nil
+		return "", nil, nil
 	}
 
 	var promptsToAdd []promptInfo
 	var totalTokenLength int
 	var systemPromptIncluded bool
 
+	var images []llm.ImageData
 	// reverse iterate through the prompts to build the prompt string in a way that fits the max context length
 	for i := len(chat.Prompts) - 1; i >= 0; i-- {
-		promptText, err := promptString(model, chat.Prompts[i], i == len(chat.Prompts)-1)
+		prompt := chat.Prompts[i]
+		promptText, err := promptString(model, prompt, i == len(chat.Prompts)-1)
 		if err != nil {
-			return "", err
+			return "", nil, err
 		}
 
 		encodedTokens, err := loaded.runner.Encode(ctx, promptText)
 		if err != nil {
-			return "", err
+			return "", nil, err
 		}
 
 		if totalTokenLength+len(encodedTokens) > loaded.NumCtx && i != len(chat.Prompts)-1 {
 			break // reached max context length, stop adding more prompts
 		}
 
+		for j := range prompt.Images {
+			if totalTokenLength+768 > loaded.NumCtx {
+				// this decreases the token length but overestimating is fine
+				prompt.Prompt = strings.ReplaceAll(prompt.Prompt, fmt.Sprintf(" [img-%d]", prompt.Images[j].ID), "")
+				continue
+			}
+
+			totalTokenLength += 768
+			images = append(images, prompt.Images[j])
+		}
+
 		totalTokenLength += len(encodedTokens)
-		systemPromptIncluded = systemPromptIncluded || chat.Prompts[i].System != ""
-		promptsToAdd = append(promptsToAdd, promptInfo{vars: chat.Prompts[i], tokenLen: len(encodedTokens)})
+		systemPromptIncluded = systemPromptIncluded || prompt.System != ""
+		promptsToAdd = append(promptsToAdd, promptInfo{vars: prompt, tokenLen: len(encodedTokens)})
 	}
 
 	// ensure the system prompt is included, if not already
@@ -1264,7 +1290,7 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
 		var err error
 		promptsToAdd, err = includeSystemPrompt(ctx, chat.LastSystem, totalTokenLength, promptsToAdd)
 		if err != nil {
-			return "", err
+			return "", nil, err
 		}
 	}
 
@@ -1275,11 +1301,12 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
 	for i, prompt := range promptsToAdd {
 		promptText, err := promptString(model, prompt.vars, i == 0)
 		if err != nil {
-			return "", err
+			return "", nil, err
 		}
 		result = promptText + result
 	}
-	return result, nil
+
+	return result, images, nil
 }
 
 // promptString applies the model template to the prompt

+ 2 - 1
server/routes_test.go

@@ -455,7 +455,8 @@ func Test_ChatPrompt(t *testing.T) {
 					NumCtx: tt.numCtx,
 				},
 			}
-			got, err := trimmedPrompt(context.Background(), tt.chat, m)
+			// TODO: add tests for trimming images
+			got, _, err := trimmedPrompt(context.Background(), tt.chat, m)
 			if tt.wantErr != "" {
 				if err == nil {
 					t.Errorf("ChatPrompt() expected error, got nil")