浏览代码

fix system prompt (#5662)

* fix system prompt

* execute template when hitting previous roles

* fix tests

---------

Co-authored-by: jmorganca <jmorganca@gmail.com>
Michael Yang 9 个月前
父节点
当前提交
22c5451fc2
共有 3 个文件被更改,包括 51 次插入和 30 次删除
  1. +7 -16
      server/prompt.go
  2. +18 -0
      server/prompt_test.go
  3. +26 -14
      template/template.go

+ 7 - 16
server/prompt.go

@@ -4,7 +4,6 @@ import (
 	"bytes"
 	"context"
 	"log/slog"
-	"slices"
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/llm"
@@ -17,26 +16,18 @@ type tokenizeFunc func(context.Context, string) ([]int, error)
 // chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
 // latest message and 2) system messages
 func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) {
-	// pull out any system messages which should always be included in the prompt
 	var system []api.Message
-	msgs = slices.DeleteFunc(msgs, func(m api.Message) bool {
-		if m.Role == "system" {
-			system = append(system, m)
-			return true
-		}
-
-		return false
-	})
-
-	if len(system) == 0 && m.System != "" {
-		// add model system prompt since it wasn't provided
-		system = append(system, api.Message{Role: "system", Content: m.System})
-	}
-
 	// always include the last message
 	n := len(msgs) - 1
 	// in reverse, find all messages that fit into context window
 	for i := n - 1; i >= 0; i-- {
+		system = make([]api.Message, 0)
+		for j := range i {
+			if msgs[j].Role == "system" {
+				system = append(system, msgs[j])
+			}
+		}
+
 		var b bytes.Buffer
 		if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil {
 			return "", nil, err

+ 18 - 0
server/prompt_test.go

@@ -6,6 +6,7 @@ import (
 	"strings"
 	"testing"
 
+	"github.com/google/go-cmp/cmp"
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/template"
 )
@@ -164,6 +165,19 @@ func TestChatPrompt(t *testing.T) {
 				prompt: "You are the Test Who Lived. You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
 			},
 		},
+		{
+			name:  "out of order system",
+			limit: 2048,
+			msgs: []api.Message{
+				{Role: "user", Content: "You're a test, Harry!"},
+				{Role: "assistant", Content: "I-I'm a what?"},
+				{Role: "system", Content: "You are the Test Who Lived."},
+				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
+			},
+			expect: expect{
+				prompt: "You're a test, Harry! I-I'm a what? You are the Test Who Lived. A test. And a thumping good one at that, I'd wager. ",
+			},
+		},
 	}
 
 	tmpl, err := template.Parse(`
@@ -187,6 +201,10 @@ func TestChatPrompt(t *testing.T) {
 				t.Errorf("expected %q, got %q", tt.prompt, prompt)
 			}
 
+			if diff := cmp.Diff(prompt, tt.prompt); diff != "" {
+				t.Errorf("mismatch (-got +want):\n%s", diff)
+			}
+
 			if len(images) != len(tt.images) {
 				t.Fatalf("expected %d images, got %d", len(tt.images), len(images))
 			}

+ 26 - 14
template/template.go

@@ -149,27 +149,19 @@ type Values struct {
 }
 
 func (t *Template) Execute(w io.Writer, v Values) error {
-	system, collated := collate(v.Messages)
+	system, messages := collate(v.Messages)
 	if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
 		return t.Template.Execute(w, map[string]any{
 			"System":   system,
-			"Messages": collated,
+			"Messages": messages,
 		})
 	}
 
+	system = ""
 	var b bytes.Buffer
 	var prompt, response string
-	for i, m := range collated {
-		switch m.Role {
-		case "system":
-			system = m.Content
-		case "user":
-			prompt = m.Content
-		case "assistant":
-			response = m.Content
-		}
-
-		if i != len(collated)-1 && prompt != "" && response != "" {
+	for _, m := range messages {
+		execute := func () error {
 			if err := t.Template.Execute(&b, map[string]any{
 				"System":   system,
 				"Prompt":   prompt,
@@ -181,6 +173,26 @@ func (t *Template) Execute(w io.Writer, v Values) error {
 			system = ""
 			prompt = ""
 			response = ""
+			return nil
+		}
+
+		switch m.Role {
+		case "system":
+			if prompt != "" || response != "" {
+				if err := execute(); err != nil {
+					return err
+				}
+			}
+			system = m.Content
+		case "user":
+			if response != "" {
+				if err := execute(); err != nil {
+					return err
+				}
+			}
+			prompt = m.Content
+		case "assistant":
+			response = m.Content
 		}
 	}
 
@@ -199,7 +211,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
 
 	tree := parse.Tree{Root: nodes.(*parse.ListNode)}
 	if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{
-		"System": "",
+		"System": system,
 		"Prompt": prompt,
 	}); err != nil {
 		return err