Ver código fonte

only allow a single image to be passed

Patrick Devine 6 meses atrás
pai
commit
3a1c8da5e4
2 arquivos alterados com 82 adições e 24 exclusões
  1. 35 19
      server/prompt.go
  2. 47 5
      server/prompt_test.go

+ 35 - 19
server/prompt.go

@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"bytes"
 	"context"
 	"context"
 	"encoding/binary"
 	"encoding/binary"
+	"errors"
 	"fmt"
 	"fmt"
 	"log/slog"
 	"log/slog"
 	"strings"
 	"strings"
@@ -16,16 +17,28 @@ import (
 
 
 type tokenizeFunc func(context.Context, string) ([]int, error)
 type tokenizeFunc func(context.Context, string) ([]int, error)
 
 
+var errTooManyImages = errors.New("vision model only supports a single image per message")
+
 // chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
 // chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
 // chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
 // chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
 // latest message and 2) system messages
 // latest message and 2) system messages
 func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
 func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
 	var system []api.Message
 	var system []api.Message
 
 
-	// always include the last message
+	isMllama := checkMllamaModelFamily(m)
+
 	n := len(msgs) - 1
 	n := len(msgs) - 1
 	// in reverse, find all messages that fit into context window
 	// in reverse, find all messages that fit into context window
-	for i := n - 1; i >= 0; i-- {
+	for i := n; i >= 0; i-- {
+		if isMllama && len(msgs[i].Images) > 1 {
+			return "", nil, errTooManyImages
+		}
+
+		// always include the last message
+		if i == n {
+			continue
+		}
+
 		system = make([]api.Message, 0)
 		system = make([]api.Message, 0)
 		for j := range i {
 		for j := range i {
 			if msgs[j].Role == "system" {
 			if msgs[j].Role == "system" {
@@ -62,27 +75,30 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 
 
 	currMsgIdx := n
 	currMsgIdx := n
 
 
-	if checkMllamaModelFamily(m) {
+	if isMllama {
 		lastMsgIdx := len(msgs) - 1
 		lastMsgIdx := len(msgs) - 1
-		if len(msgs[lastMsgIdx].Images) == 1 {
-			data, aspectRatioID, err := imageproc.Preprocess(msgs[lastMsgIdx].Images[0])
-			if err != nil {
-				return "", nil, err
-			}
+		for i := lastMsgIdx; i >= currMsgIdx; i-- {
+			if len(msgs[i].Images) > 0 {
+				data, aspectRatioID, err := imageproc.Preprocess(msgs[i].Images[0])
+				if err != nil {
+					return "", nil, err
+				}
 
 
-			buf := new(bytes.Buffer)
-			err = binary.Write(buf, binary.LittleEndian, data)
-			if err != nil {
-				return "", nil, err
-			}
+				buf := new(bytes.Buffer)
+				err = binary.Write(buf, binary.LittleEndian, data)
+				if err != nil {
+					return "", nil, err
+				}
 
 
-			imgData := llm.ImageData{
-				Data:          buf.Bytes(),
-				AspectRatioID: aspectRatioID,
-			}
+				imgData := llm.ImageData{
+					Data:          buf.Bytes(),
+					AspectRatioID: aspectRatioID,
+				}
 
 
-			msgs[lastMsgIdx].Content = strings.TrimSpace("<|image|>" + msgs[lastMsgIdx].Content)
-			images = append(images, imgData)
+				msgs[i].Content = strings.TrimSpace("<|image|>" + msgs[i].Content)
+				images = append(images, imgData)
+				break
+			}
 		}
 		}
 	} else {
 	} else {
 		for cnt, msg := range msgs[currMsgIdx:] {
 		for cnt, msg := range msgs[currMsgIdx:] {

+ 47 - 5
server/prompt_test.go

@@ -18,6 +18,7 @@ func TestChatPrompt(t *testing.T) {
 		prompt        string
 		prompt        string
 		images        [][]byte
 		images        [][]byte
 		aspectRatioID int
 		aspectRatioID int
+		error         error
 	}
 	}
 
 
 	tmpl, err := template.Parse(`
 	tmpl, err := template.Parse(`
@@ -30,15 +31,26 @@ func TestChatPrompt(t *testing.T) {
 	visionModel := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
 	visionModel := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
 	mllamaModel := Model{Template: tmpl, ProjectorPaths: []string{"vision"}, Config: ConfigV2{ModelFamilies: []string{"mllama"}}}
 	mllamaModel := Model{Template: tmpl, ProjectorPaths: []string{"vision"}, Config: ConfigV2{ModelFamilies: []string{"mllama"}}}
 
 
-	img := image.NewRGBA(image.Rect(0, 0, 5, 5))
-	var buf bytes.Buffer
+	createImg := func(width, height int) ([]byte, error) {
+		img := image.NewRGBA(image.Rect(0, 0, width, height))
+		var buf bytes.Buffer
 
 
-	err = png.Encode(&buf, img)
+		if err := png.Encode(&buf, img); err != nil {
+			return nil, err
+		}
+
+		return buf.Bytes(), nil
+	}
+
+	imgBuf, err := createImg(5, 5)
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
 
 
-	imgBuf := buf.Bytes()
+	imgBuf2, err := createImg(6, 6)
+	if err != nil {
+		t.Fatal(err)
+	}
 
 
 	cases := []struct {
 	cases := []struct {
 		name  string
 		name  string
@@ -232,6 +244,34 @@ func TestChatPrompt(t *testing.T) {
 				aspectRatioID: 1,
 				aspectRatioID: 1,
 			},
 			},
 		},
 		},
+		{
+			name:  "multiple messages with mllama",
+			model: mllamaModel,
+			limit: 2048,
+			msgs: []api.Message{
+				{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{imgBuf}},
+				{Role: "assistant", Content: "I-I'm a what?"},
+				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{imgBuf2}},
+			},
+			expect: expect{
+				prompt:        "You're a test, Harry! I-I'm a what? <|image|>A test. And a thumping good one at that, I'd wager. ",
+				images:        [][]byte{imgBuf2},
+				aspectRatioID: 1,
+			},
+		},
+		{
+			name:  "too many images with mllama",
+			model: mllamaModel,
+			limit: 2048,
+			msgs: []api.Message{
+				{Role: "user", Content: "You're a test, Harry!"},
+				{Role: "assistant", Content: "I-I'm a what?"},
+				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{imgBuf, imgBuf}},
+			},
+			expect: expect{
+				error: errTooManyImages,
+			},
+		},
 	}
 	}
 
 
 	for _, tt := range cases {
 	for _, tt := range cases {
@@ -239,8 +279,10 @@ func TestChatPrompt(t *testing.T) {
 			model := tt.model
 			model := tt.model
 			opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
 			opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
 			prompt, images, err := chatPrompt(context.TODO(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil)
 			prompt, images, err := chatPrompt(context.TODO(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil)
-			if err != nil {
+			if tt.error == nil && err != nil {
 				t.Fatal(err)
 				t.Fatal(err)
+			} else if tt.error != nil && err != tt.error {
+				t.Fatalf("expected err '%q', got '%q'", tt.error, err)
 			}
 			}
 
 
 			if diff := cmp.Diff(prompt, tt.prompt); diff != "" {
 			if diff := cmp.Diff(prompt, tt.prompt); diff != "" {