Michael Yang 1 سال پیش
والد
کامیت
8450bf66e6
3فایلهای تغییر یافته به همراه41 افزوده شده و 20 حذف شده
  1. 1 1
      llm/llama.go
  2. 18 9
      server/images.go
  3. 22 10
      server/routes.go

+ 1 - 1
llm/llama.go

@@ -62,7 +62,7 @@ const maxRetries = 3
 type PredictOpts struct {
 	Prompt  string
 	Format  string
-	Images  []api.ImageData
+	Images  map[int]api.ImageData
 	Options api.Options
 }
 

+ 18 - 9
server/images.go

@@ -58,11 +58,17 @@ type Message struct {
 	Content string `json:"content"`
 }
 
+type ImageData struct {
+	Rank int
+	api.ImageData
+}
+
 type PromptVars struct {
 	System   string
 	Prompt   string
 	Response string
 	First    bool
+	Images   []ImageData
 }
 
 // extractParts extracts the parts of the template before and after the {{.Response}} node.
@@ -147,15 +153,13 @@ func (m *Model) PostResponseTemplate(p PromptVars) (string, error) {
 }
 
 type ChatHistory struct {
-	Prompts       []PromptVars
-	CurrentImages []api.ImageData
-	LastSystem    string
+	Prompts    []PromptVars
+	LastSystem string
 }
 
 // ChatPrompts returns a list of formatted chat prompts from a list of messages
 func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
 	// build the prompt from the list of messages
-	var currentImages []api.ImageData
 	lastSystem := m.System
 	currentVars := PromptVars{
 		First:  true,
@@ -163,6 +167,7 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
 	}
 
 	prompts := []PromptVars{}
+	var images []ImageData
 
 	for _, msg := range msgs {
 		switch strings.ToLower(msg.Role) {
@@ -182,10 +187,15 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
 
 			currentVars.Prompt = msg.Content
 			for i := range msg.Images {
-				currentVars.Prompt += fmt.Sprintf(" [img-%d]", len(currentImages)+i)
+				currentVars.Prompt += fmt.Sprintf(" [img-%d]", len(images)+i)
+				currentVars.Images = append(currentVars.Images, ImageData{
+					Rank:      len(images) + i,
+					ImageData: msg.Images[i],
+				})
+
 			}
 
-			currentImages = append(currentImages, msg.Images...)
+			images = append(images, currentVars.Images...)
 		case "assistant":
 			currentVars.Response = msg.Content
 			prompts = append(prompts, currentVars)
@@ -201,9 +211,8 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
 	}
 
 	return &ChatHistory{
-		Prompts:       prompts,
-		CurrentImages: currentImages,
-		LastSystem:    lastSystem,
+		Prompts:    prompts,
+		LastSystem: lastSystem,
 	}, nil
 }
 

+ 22 - 10
server/routes.go

@@ -312,11 +312,16 @@ func GenerateHandler(c *gin.Context) {
 			ch <- resp
 		}
 
+		images := make(map[int]api.ImageData)
+		for i := range req.Images {
+			images[i] = req.Images[i]
+		}
+
 		// Start prediction
 		predictReq := llm.PredictOpts{
 			Prompt:  prompt,
 			Format:  req.Format,
-			Images:  req.Images,
+			Images:  images,
 			Options: opts,
 		}
 		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
@@ -1143,7 +1148,8 @@ func ChatHandler(c *gin.Context) {
 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
-	prompt, err := trimmedPrompt(c.Request.Context(), chat, model)
+
+	prompt, images, err := trimmedPrompt(c.Request.Context(), chat, model)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
@@ -1186,7 +1192,7 @@ func ChatHandler(c *gin.Context) {
 		predictReq := llm.PredictOpts{
 			Prompt:  prompt,
 			Format:  req.Format,
-			Images:  chat.CurrentImages,
+			Images:  images,
 			Options: opts,
 		}
 		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
@@ -1233,25 +1239,27 @@ type promptInfo struct {
 
 // trimmedPrompt builds a prompt to send to a running model. It ensures the prompt fits within the max context length,
 // while preserving the most recent system message.
-func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, error) {
+func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, map[int]api.ImageData, error) {
 	if len(chat.Prompts) == 0 {
-		return "", nil
+		return "", nil, nil
 	}
 
 	var promptsToAdd []promptInfo
 	var totalTokenLength int
 	var systemPromptIncluded bool
 
+	images := make(map[int]api.ImageData)
+
 	// reverse iterate through the prompts to build the prompt string in a way that fits the max context length
 	for i := len(chat.Prompts) - 1; i >= 0; i-- {
 		promptText, err := promptString(model, chat.Prompts[i], i == len(chat.Prompts)-1)
 		if err != nil {
-			return "", err
+			return "", nil, err
 		}
 
 		encodedTokens, err := loaded.runner.Encode(ctx, promptText)
 		if err != nil {
-			return "", err
+			return "", nil, err
 		}
 
 		if totalTokenLength+len(encodedTokens) > loaded.NumCtx && i != len(chat.Prompts)-1 {
@@ -1261,6 +1269,10 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
 		totalTokenLength += len(encodedTokens)
 		systemPromptIncluded = systemPromptIncluded || chat.Prompts[i].System != ""
 		promptsToAdd = append(promptsToAdd, promptInfo{vars: chat.Prompts[i], tokenLen: len(encodedTokens)})
+
+		for _, image := range chat.Prompts[i].Images {
+			images[image.Rank] = image.ImageData
+		}
 	}
 
 	// ensure the system prompt is included, if not already
@@ -1268,7 +1280,7 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
 		var err error
 		promptsToAdd, err = includeSystemPrompt(ctx, chat.LastSystem, totalTokenLength, promptsToAdd)
 		if err != nil {
-			return "", err
+			return "", nil, err
 		}
 	}
 
@@ -1279,11 +1291,11 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
 	for i, prompt := range promptsToAdd {
 		promptText, err := promptString(model, prompt.vars, i == 0)
 		if err != nil {
-			return "", err
+			return "", nil, err
 		}
 		result = promptText + result
 	}
-	return result, nil
+	return result, images, nil
 }
 
 // promptString applies the model template to the prompt