瀏覽代碼

rename aggregate to contents

Michael Yang 9 月之前
父節點
當前提交
5056bb9c01
共有 2 個文件被更改,包括 41 次插入7 次删除
  1. 6 5
      template/template.go
  2. 35 2
      template/template_test.go

+ 6 - 5
template/template.go

@@ -103,15 +103,16 @@ var response = parse.ActionNode{
 }
 
 var funcs = template.FuncMap{
-	"aggregate": func(v []*api.Message, role string) string {
-		var aggregated []string
+	// contents returns the contents of messages with an optional role filter
+	"contents": func(v []*api.Message, role ...string) string {
+		var parts []string
 		for _, m := range v {
-			if m.Role == role {
-				aggregated = append(aggregated, m.Content)
+			if len(role) == 0 || role[0] == "" || m.Role == role[0] {
+				parts = append(parts, m.Content)
 			}
 		}
 
-		return strings.Join(aggregated, "\n\n")
+		return strings.Join(parts, "\n\n")
 	},
 }
 

+ 35 - 2
template/template_test.go

@@ -216,7 +216,7 @@ func TestExecuteWithMessages(t *testing.T) {
 				{"response", `[INST] {{ if .System }}{{ .System }}
 
 {{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
-				{"messages", `{{- $system := aggregate $.Messages "system" -}}
+				{"messages", `{{- $system := contents .Messages "system" -}}
 {{- range $index, $_ := .Messages }}
 {{- if eq .Role "user" }}[INST] {{ if $system }}{{ $system }}
 {{- $system = "" }}
@@ -243,7 +243,7 @@ func TestExecuteWithMessages(t *testing.T) {
 				{"response", `[INST] {{ if .System }}{{ .System }}
 
 {{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
-				{"messages", `{{- $system := aggregate $.Messages "system" -}}
+				{"messages", `{{- $system := contents .Messages "system" -}}
 {{- range $index, $_ := .Messages }}
 {{- if eq .Role "user" }}[INST] {{ if $system }}{{ $system }}
 {{- $system = "" }}
@@ -363,3 +363,36 @@ Answer: `,
 		})
 	}
 }
+
+func TestFuncs(t *testing.T) {
+	t.Run("contents", func(t *testing.T) {
+		cases := map[string]string{
+			"":          "A\n\nB\n\nC\n\nD\n\nE\n\nF",
+			"system":    "A\n\nF",
+			"user":      "B\n\nE",
+			"assistant": "C\n\nD",
+		}
+
+		s := []*api.Message{
+			{Role: "system", Content: "A"},
+			{Role: "user", Content: "B"},
+			{Role: "assistant", Content: "C"},
+			{Role: "assistant", Content: "D"},
+			{Role: "user", Content: "E"},
+			{Role: "system", Content: "F"},
+		}
+
+		fn, ok := funcs["contents"].(func([]*api.Message, ...string) string)
+		if !ok {
+			t.Fatal("contents is not a function")
+		}
+
+		for k, v := range cases {
+			t.Run(k, func(t *testing.T) {
+				if diff := cmp.Diff(fn(s, k), v); diff != "" {
+					t.Errorf("mismatch (-got +want):\n%s", diff)
+				}
+			})
+		}
+	})
+}