|
@@ -216,7 +216,7 @@ func TestExecuteWithMessages(t *testing.T) {
|
|
{"response", `[INST] {{ if .System }}{{ .System }}
|
|
{"response", `[INST] {{ if .System }}{{ .System }}
|
|
|
|
|
|
{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
|
|
{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
|
|
- {"messages", `{{- $system := aggregate $.Messages "system" -}}
|
|
|
|
|
|
+ {"messages", `{{- $system := contents .Messages "system" -}}
|
|
{{- range $index, $_ := .Messages }}
|
|
{{- range $index, $_ := .Messages }}
|
|
{{- if eq .Role "user" }}[INST] {{ if $system }}{{ $system }}
|
|
{{- if eq .Role "user" }}[INST] {{ if $system }}{{ $system }}
|
|
{{- $system = "" }}
|
|
{{- $system = "" }}
|
|
@@ -243,7 +243,7 @@ func TestExecuteWithMessages(t *testing.T) {
|
|
{"response", `[INST] {{ if .System }}{{ .System }}
|
|
{"response", `[INST] {{ if .System }}{{ .System }}
|
|
|
|
|
|
{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
|
|
{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
|
|
- {"messages", `{{- $system := aggregate $.Messages "system" -}}
|
|
|
|
|
|
+ {"messages", `{{- $system := contents .Messages "system" -}}
|
|
{{- range $index, $_ := .Messages }}
|
|
{{- range $index, $_ := .Messages }}
|
|
{{- if eq .Role "user" }}[INST] {{ if $system }}{{ $system }}
|
|
{{- if eq .Role "user" }}[INST] {{ if $system }}{{ $system }}
|
|
{{- $system = "" }}
|
|
{{- $system = "" }}
|
|
@@ -363,3 +363,36 @@ Answer: `,
|
|
})
|
|
})
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
+
|
|
|
|
+func TestFuncs(t *testing.T) {
|
|
|
|
+ t.Run("contents", func(t *testing.T) {
|
|
|
|
+ cases := map[string]string{
|
|
|
|
+ "": "A\n\nB\n\nC\n\nD\n\nE\n\nF",
|
|
|
|
+ "system": "A\n\nF",
|
|
|
|
+ "user": "B\n\nE",
|
|
|
|
+ "assistant": "C\n\nD",
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ s := []*api.Message{
|
|
|
|
+ {Role: "system", Content: "A"},
|
|
|
|
+ {Role: "user", Content: "B"},
|
|
|
|
+ {Role: "assistant", Content: "C"},
|
|
|
|
+ {Role: "assistant", Content: "D"},
|
|
|
|
+ {Role: "user", Content: "E"},
|
|
|
|
+ {Role: "system", Content: "F"},
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ fn, ok := funcs["contents"].(func([]*api.Message, ...string) string)
|
|
|
|
+ if !ok {
|
|
|
|
+ t.Fatal("contents is not a function")
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ for k, v := range cases {
|
|
|
|
+ t.Run(k, func(t *testing.T) {
|
|
|
|
+ if diff := cmp.Diff(fn(s, k), v); diff != "" {
|
|
|
|
+ t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
|
|
+ }
|
|
|
|
+ })
|
|
|
|
+ }
|
|
|
|
+ })
|
|
|
|
+}
|