Browse Source

use llm.ImageData for chat

Michael Yang 1 year ago
parent
commit
d046bee790
2 changed files with 10 additions and 25 deletions
  1. 5 10
      server/images.go
  2. 5 15
      server/routes.go

+ 5 - 10
server/images.go

@@ -58,17 +58,12 @@ type Message struct {
 	Content string `json:"content"`
 	Content string `json:"content"`
 }
 }
 
 
-type ImageData struct {
-	Rank int
-	api.ImageData
-}
-
 type PromptVars struct {
 type PromptVars struct {
 	System   string
 	System   string
 	Prompt   string
 	Prompt   string
 	Response string
 	Response string
 	First    bool
 	First    bool
-	Images   []ImageData
+	Images   []llm.ImageData
 }
 }
 
 
 // extractParts extracts the parts of the template before and after the {{.Response}} node.
 // extractParts extracts the parts of the template before and after the {{.Response}} node.
@@ -167,7 +162,7 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
 	}
 	}
 
 
 	prompts := []PromptVars{}
 	prompts := []PromptVars{}
-	var images []ImageData
+	var images []llm.ImageData
 
 
 	for _, msg := range msgs {
 	for _, msg := range msgs {
 		switch strings.ToLower(msg.Role) {
 		switch strings.ToLower(msg.Role) {
@@ -188,9 +183,9 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
 			currentVars.Prompt = msg.Content
 			currentVars.Prompt = msg.Content
 			for i := range msg.Images {
 			for i := range msg.Images {
 				currentVars.Prompt += fmt.Sprintf(" [img-%d]", len(images)+i)
 				currentVars.Prompt += fmt.Sprintf(" [img-%d]", len(images)+i)
-				currentVars.Images = append(currentVars.Images, ImageData{
-					Rank:      len(images) + i,
-					ImageData: msg.Images[i],
+				currentVars.Images = append(currentVars.Images, llm.ImageData{
+					ID:   i,
+					Data: msg.Images[i],
 				})
 				})
 
 
 			}
 			}

+ 5 - 15
server/routes.go

@@ -1191,19 +1191,11 @@ func ChatHandler(c *gin.Context) {
 			ch <- resp
 			ch <- resp
 		}
 		}
 
 
-		var imageData []llm.ImageData
-		for k, v := range images {
-			imageData = append(imageData, llm.ImageData{
-				ID:   k,
-				Data: v,
-			})
-		}
-
 		// Start prediction
 		// Start prediction
 		predictReq := llm.PredictOpts{
 		predictReq := llm.PredictOpts{
 			Prompt:  prompt,
 			Prompt:  prompt,
 			Format:  req.Format,
 			Format:  req.Format,
-			Images:  imageData,
+			Images:  images,
 			Options: opts,
 			Options: opts,
 		}
 		}
 		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
 		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
@@ -1250,7 +1242,7 @@ type promptInfo struct {
 
 
 // trimmedPrompt builds a prompt to send to a running model. It ensures the prompt fits within the max context length,
 // trimmedPrompt builds a prompt to send to a running model. It ensures the prompt fits within the max context length,
 // while preserving the most recent system message.
 // while preserving the most recent system message.
-func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, map[int]api.ImageData, error) {
+func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, []llm.ImageData, error) {
 	if len(chat.Prompts) == 0 {
 	if len(chat.Prompts) == 0 {
 		return "", nil, nil
 		return "", nil, nil
 	}
 	}
@@ -1259,8 +1251,7 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
 	var totalTokenLength int
 	var totalTokenLength int
 	var systemPromptIncluded bool
 	var systemPromptIncluded bool
 
 
-	images := make(map[int]api.ImageData)
-
+	var images []llm.ImageData
 	// reverse iterate through the prompts to build the prompt string in a way that fits the max context length
 	// reverse iterate through the prompts to build the prompt string in a way that fits the max context length
 	for i := len(chat.Prompts) - 1; i >= 0; i-- {
 	for i := len(chat.Prompts) - 1; i >= 0; i-- {
 		promptText, err := promptString(model, chat.Prompts[i], i == len(chat.Prompts)-1)
 		promptText, err := promptString(model, chat.Prompts[i], i == len(chat.Prompts)-1)
@@ -1281,9 +1272,7 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
 		systemPromptIncluded = systemPromptIncluded || chat.Prompts[i].System != ""
 		systemPromptIncluded = systemPromptIncluded || chat.Prompts[i].System != ""
 		promptsToAdd = append(promptsToAdd, promptInfo{vars: chat.Prompts[i], tokenLen: len(encodedTokens)})
 		promptsToAdd = append(promptsToAdd, promptInfo{vars: chat.Prompts[i], tokenLen: len(encodedTokens)})
 
 
-		for _, image := range chat.Prompts[i].Images {
-			images[image.Rank] = image.ImageData
-		}
+		images = append(images, chat.Prompts[i].Images...)
 	}
 	}
 
 
 	// ensure the system prompt is included, if not already
 	// ensure the system prompt is included, if not already
@@ -1306,6 +1295,7 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
 		}
 		}
 		result = promptText + result
 		result = promptText + result
 	}
 	}
+
 	return result, images, nil
 	return result, images, nil
 }
 }