Bläddra i källkod

Merge pull request #5653 from ollama/mxyng/collect-system

template: preprocess message and collect system
Michael Yang 9 månader sedan
förälder
incheckning
e5c65a85df
2 ändrade filer med 23 tillägg och 67 borttagningar
  1. 15 22
      template/template.go
  2. 8 45
      template/template_test.go

+ 15 - 22
template/template.go

@@ -102,22 +102,8 @@ var response = parse.ActionNode{
 	},
 	},
 }
 }
 
 
-var funcs = template.FuncMap{
-	// contents returns the contents of messages with an optional role filter
-	"contents": func(v []*api.Message, role ...string) string {
-		var parts []string
-		for _, m := range v {
-			if len(role) == 0 || role[0] == "" || m.Role == role[0] {
-				parts = append(parts, m.Content)
-			}
-		}
-
-		return strings.Join(parts, "\n\n")
-	},
-}
-
 func Parse(s string) (*Template, error) {
 func Parse(s string) (*Template, error) {
-	tmpl := template.New("").Option("missingkey=zero").Funcs(funcs)
+	tmpl := template.New("").Option("missingkey=zero")
 
 
 	tmpl, err := tmpl.Parse(s)
 	tmpl, err := tmpl.Parse(s)
 	if err != nil {
 	if err != nil {
@@ -163,15 +149,16 @@ type Values struct {
 }
 }
 
 
 func (t *Template) Execute(w io.Writer, v Values) error {
 func (t *Template) Execute(w io.Writer, v Values) error {
-	collated := collate(v.Messages)
+	system, collated := collate(v.Messages)
 	if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
 	if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
 		return t.Template.Execute(w, map[string]any{
 		return t.Template.Execute(w, map[string]any{
+			"System":   system,
 			"Messages": collated,
 			"Messages": collated,
 		})
 		})
 	}
 	}
 
 
 	var b bytes.Buffer
 	var b bytes.Buffer
-	var system, prompt, response string
+	var prompt, response string
 	for i, m := range collated {
 	for i, m := range collated {
 		switch m.Role {
 		switch m.Role {
 		case "system":
 		case "system":
@@ -223,11 +210,13 @@ func (t *Template) Execute(w io.Writer, v Values) error {
 }
 }
 
 
 // collate messages based on role. consecutive messages of the same role are merged
 // collate messages based on role. consecutive messages of the same role are merged
-// into a single message. collate also pulls out and merges messages with Role == "system"
-// which are templated separately. As a side effect, it mangles message content adding image
-// tags ([img-%d]) as needed
-func collate(msgs []api.Message) (collated []*api.Message) {
+// into a single message. collate also collects and returns all system messages.
+// collate mutates message content adding image tags ([img-%d]) as needed
+func collate(msgs []api.Message) (string, []*api.Message) {
 	var n int
 	var n int
+
+	var system []string
+	var collated []*api.Message
 	for i := range msgs {
 	for i := range msgs {
 		msg := msgs[i]
 		msg := msgs[i]
 		for range msg.Images {
 		for range msg.Images {
@@ -240,6 +229,10 @@ func collate(msgs []api.Message) (collated []*api.Message) {
 			n++
 			n++
 		}
 		}
 
 
+		if msg.Role == "system" {
+			system = append(system, msg.Content)
+		}
+
 		if len(collated) > 0 && collated[len(collated)-1].Role == msg.Role {
 		if len(collated) > 0 && collated[len(collated)-1].Role == msg.Role {
 			collated[len(collated)-1].Content += "\n\n" + msg.Content
 			collated[len(collated)-1].Content += "\n\n" + msg.Content
 		} else {
 		} else {
@@ -247,7 +240,7 @@ func collate(msgs []api.Message) (collated []*api.Message) {
 		}
 		}
 	}
 	}
 
 
-	return
+	return strings.Join(system, "\n\n"), collated
 }
 }
 
 
 func parseNode(n parse.Node) []string {
 func parseNode(n parse.Node) []string {

+ 8 - 45
template/template_test.go

@@ -216,13 +216,11 @@ func TestExecuteWithMessages(t *testing.T) {
 				{"response", `[INST] {{ if .System }}{{ .System }}
 				{"response", `[INST] {{ if .System }}{{ .System }}
 
 
 {{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
 {{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
-				{"messages", `{{- $system := contents .Messages "system" -}}
-{{- range $index, $_ := .Messages }}
-{{- if eq .Role "user" }}[INST] {{ if $system }}{{ $system }}
-{{- $system = "" }}
+				{"messages", `[INST] {{ if .System }}{{ .System }}
 
 
-{{ end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
-{{- end }}
+{{ end }}
+{{- range .Messages }}
+{{- if eq .Role "user" }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}[INST] {{ end }}
 {{- end }}`},
 {{- end }}`},
 			},
 			},
 			Values{
 			Values{
@@ -243,13 +241,11 @@ func TestExecuteWithMessages(t *testing.T) {
 				{"response", `[INST] {{ if .System }}{{ .System }}
 				{"response", `[INST] {{ if .System }}{{ .System }}
 
 
 {{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
 {{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
-				{"messages", `{{- $system := contents .Messages "system" -}}
-{{- range $index, $_ := .Messages }}
-{{- if eq .Role "user" }}[INST] {{ if $system }}{{ $system }}
-{{- $system = "" }}
+				{"messages", `[INST] {{ if .System }}{{ .System }}
 
 
-{{ end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
-{{- end }}
+{{ end }}
+{{- range .Messages }}
+{{- if eq .Role "user" }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}[INST] {{ end }}
 {{- end }}`},
 {{- end }}`},
 			},
 			},
 			Values{
 			Values{
@@ -363,36 +359,3 @@ Answer: `,
 		})
 		})
 	}
 	}
 }
 }
-
-func TestFuncs(t *testing.T) {
-	t.Run("contents", func(t *testing.T) {
-		cases := map[string]string{
-			"":          "A\n\nB\n\nC\n\nD\n\nE\n\nF",
-			"system":    "A\n\nF",
-			"user":      "B\n\nE",
-			"assistant": "C\n\nD",
-		}
-
-		s := []*api.Message{
-			{Role: "system", Content: "A"},
-			{Role: "user", Content: "B"},
-			{Role: "assistant", Content: "C"},
-			{Role: "assistant", Content: "D"},
-			{Role: "user", Content: "E"},
-			{Role: "system", Content: "F"},
-		}
-
-		fn, ok := funcs["contents"].(func([]*api.Message, ...string) string)
-		if !ok {
-			t.Fatal("contents is not a function")
-		}
-
-		for k, v := range cases {
-			t.Run(k, func(t *testing.T) {
-				if diff := cmp.Diff(fn(s, k), v); diff != "" {
-					t.Errorf("mismatch (-got +want):\n%s", diff)
-				}
-			})
-		}
-	})
-}