Selaa lähdekoodia

prepend image tags (#2789)

instead of appending image tags, prepend them - this generally produces better results
Michael Yang 1 vuosi sitten
vanhempi
commit
0e19476b56
3 muutettua tiedostoa jossa 21 lisäystä ja 18 poistoa
  1. 5 3
      server/prompt.go
  2. 3 3
      server/prompt_test.go
  3. 13 12
      server/routes.go

+ 5 - 3
server/prompt.go

@@ -121,13 +121,15 @@ func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(str
 				p = prompt{}
 			}
 
-			p.Prompt = msg.Content
-
+			var sb strings.Builder
 			for range msg.Images {
-				p.Prompt += fmt.Sprintf(" [img-%d]", imgId)
+				fmt.Fprintf(&sb, "[img-%d] ", imgId)
 				p.images = append(p.images, imgId)
 				imgId += 1
 			}
+
+			sb.WriteString(msg.Content)
+			p.Prompt = sb.String()
 		case "assistant":
 			if p.Response != "" {
 				prompts = append(prompts, p)

+ 3 - 3
server/prompt_test.go

@@ -155,7 +155,7 @@ func TestChatPrompt(t *testing.T) {
 				{Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("base64")}},
 			},
 			window: 1024,
-			want:   "You are a Wizard. Hello [img-0]",
+			want:   "You are a Wizard. [img-0] Hello",
 		},
 		{
 			name:     "images truncated",
@@ -165,7 +165,7 @@ func TestChatPrompt(t *testing.T) {
 				{Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("img1"), []byte("img2")}},
 			},
 			window: 1024,
-			want:   "You are a Wizard. Hello [img-1]",
+			want:   "You are a Wizard. [img-0] [img-1] Hello",
 		},
 		{
 			name:     "empty list",
@@ -198,7 +198,7 @@ func TestChatPrompt(t *testing.T) {
 			}
 
 			if got != tc.want {
-				t.Errorf("got = %v, want %v", got, tc.want)
+				t.Errorf("got: %q, want: %q", got, tc.want)
 			}
 		})
 	}

+ 13 - 12
server/routes.go

@@ -250,6 +250,19 @@ func GenerateHandler(c *gin.Context) {
 		slog.Debug("generate handler", "system", req.System)
 
 		var sb strings.Builder
+		for i := range req.Images {
+			fmt.Fprintf(&sb, "[img-%d] ", i)
+		}
+
+		sb.WriteString(req.Prompt)
+
+		p, err := Prompt(req.Template, req.System, sb.String(), "", true)
+		if err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
+
+		sb.Reset()
 		if req.Context != nil {
 			prev, err := loaded.runner.Decode(c.Request.Context(), req.Context)
 			if err != nil {
@@ -260,18 +273,6 @@ func GenerateHandler(c *gin.Context) {
 			sb.WriteString(prev)
 		}
 
-		// write image tags
-		// TODO: limit the number of images to fit in the context similar to the chat endpoint
-		for i := range req.Images {
-			req.Prompt += fmt.Sprintf(" [img-%d]", i)
-		}
-
-		p, err := Prompt(req.Template, req.System, req.Prompt, "", true)
-		if err != nil {
-			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-			return
-		}
-
 		sb.WriteString(p)
 
 		prompt = sb.String()