瀏覽代碼

do no automatically aggregate system messages

Michael Yang 9 月之前
父節點
當前提交
e64f9ebb44
共有 2 個文件被更改,包括 27 次插入23 次删除
  1. 20 19
      template/template.go
  2. 7 4
      template/template_test.go

+ 20 - 19
template/template.go

@@ -102,8 +102,21 @@ var response = parse.ActionNode{
 	},
 }
 
+var funcs = template.FuncMap{
+	"aggregate": func(v []*api.Message, role string) string {
+		var aggregated []string
+		for _, m := range v {
+			if m.Role == role {
+				aggregated = append(aggregated, m.Content)
+			}
+		}
+
+		return strings.Join(aggregated, "\n\n")
+	},
+}
+
 func Parse(s string) (*Template, error) {
-	tmpl := template.New("").Option("missingkey=zero")
+	tmpl := template.New("").Option("missingkey=zero").Funcs(funcs)
 
 	tmpl, err := tmpl.Parse(s)
 	if err != nil {
@@ -149,23 +162,21 @@ type Values struct {
 }
 
 func (t *Template) Execute(w io.Writer, v Values) error {
-	system, collated := collate(v.Messages)
+	collated := collate(v.Messages)
 	if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
 		return t.Template.Execute(w, map[string]any{
-			"System":   system,
 			"Messages": collated,
 		})
 	}
 
 	var b bytes.Buffer
-	var prompt, response string
+	var system, prompt, response string
 	for i, m := range collated {
 		switch m.Role {
+		case "system":
+			system = m.Content
 		case "user":
 			prompt = m.Content
-			if i != 0 {
-				system = ""
-			}
 		case "assistant":
 			response = m.Content
 		}
@@ -179,6 +190,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
 				return err
 			}
 
+			system = ""
 			prompt = ""
 			response = ""
 		}
@@ -209,25 +221,14 @@ func (t *Template) Execute(w io.Writer, v Values) error {
 	return err
 }
 
-type messages []*api.Message
-
 // collate messages based on role. consecutive messages of the same role are merged
 // into a single message. collate also pulls out and merges messages with Role == "system"
 // which are templated separately. As a side effect, it mangles message content adding image
 // tags ([img-%d]) as needed
-func collate(msgs []api.Message) (system string, collated messages) {
+func collate(msgs []api.Message) (collated []*api.Message) {
 	var n int
 	for i := range msgs {
 		msg := msgs[i]
-		if msg.Role == "system" {
-			if system != "" {
-				system += "\n\n"
-			}
-
-			system += msg.Content
-			continue
-		}
-
 		for range msg.Images {
 			imageTag := fmt.Sprintf("[img-%d]", n)
 			if !strings.Contains(msg.Content, "[img]") {

+ 7 - 4
template/template_test.go

@@ -122,6 +122,7 @@ func TestTemplate(t *testing.T) {
 				})
 
 				t.Run("legacy", func(t *testing.T) {
+					t.Skip("legacy outputs are currently default outputs")
 					var legacy bytes.Buffer
 					if err := tmpl.Execute(&legacy, Values{Messages: tt, forceLegacy: true}); err != nil {
 						t.Fatal(err)
@@ -154,11 +155,13 @@ func TestParse(t *testing.T) {
 		{"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}},
 		{"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system", "tools"}},
 		{"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}},
-		{"{{ range .Messages }}{{ if eq .Role \"system\" }}SYSTEM: {{ .Content }}{{ else if eq .Role \"user\" }}USER: {{ .Content }}{{ else if eq .Role \"assistant\" }}ASSISTANT: {{ .Content }}{{ end }}{{ end }}", []string{"content", "messages", "role"}},
+		{`{{- range .Messages }}
+{{- if eq .Role "system" }}SYSTEM:
+{{- else if eq .Role "user" }}USER:
+{{- else if eq .Role "assistant" }}ASSISTANT:
+{{- end }} {{ .Content }}
+{{- end }}`, []string{"content", "messages", "role"}},
 		{`{{- if .Messages }}
-{{- if .System }}<|im_start|>system
-{{ .System }}<|im_end|>
-{{ end }}
 {{- range .Messages }}<|im_start|>{{ .Role }}
 {{ .Content }}<|im_end|>
 {{ end }}<|im_start|>assistant