|
@@ -116,7 +116,14 @@ func TestTemplate(t *testing.T) {
|
|
|
t.Fatal(err)
|
|
|
}
|
|
|
|
|
|
- if diff := cmp.Diff(actual.Bytes(), expect); diff != "" {
|
|
|
+ bts := actual.Bytes()
|
|
|
+
|
|
|
+ if slices.Contains([]string{"chatqa.gotmpl", "llama2-chat.gotmpl", "mistral-instruct.gotmpl", "openchat.gotmpl", "vicuna.gotmpl"}, match) && bts[len(bts)-1] == ' ' {
|
|
|
+ t.Log("removing trailing space from output")
|
|
|
+ bts = bts[:len(bts)-1]
|
|
|
+ }
|
|
|
+
|
|
|
+ if diff := cmp.Diff(bts, expect); diff != "" {
|
|
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
|
}
|
|
|
})
|
|
@@ -203,11 +210,18 @@ func TestExecuteWithMessages(t *testing.T) {
|
|
|
{
|
|
|
"mistral",
|
|
|
[]template{
|
|
|
- {"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `},
|
|
|
- {"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
|
|
|
- {"messages", `{{- range $index, $_ := .Messages }}
|
|
|
-{{- if eq .Role "user" }}[INST] {{ if and (eq $index 0) $.System }}{{ $.System }}{{ "\n\n" }}
|
|
|
-{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
|
|
|
+ {"no response", `[INST] {{ if .System }}{{ .System }}
|
|
|
+
|
|
|
+{{ end }}{{ .Prompt }}[/INST] `},
|
|
|
+ {"response", `[INST] {{ if .System }}{{ .System }}
|
|
|
+
|
|
|
+{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
|
|
|
+ {"messages", `{{- $system := aggregate $.Messages "system" -}}
|
|
|
+{{- range $index, $_ := .Messages }}
|
|
|
+{{- if eq .Role "user" }}[INST] {{ if $system }}{{ $system }}
|
|
|
+{{- $system = "" }}
|
|
|
+
|
|
|
+{{ end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
|
|
|
{{- end }}
|
|
|
{{- end }}`},
|
|
|
},
|
|
@@ -223,12 +237,18 @@ func TestExecuteWithMessages(t *testing.T) {
|
|
|
{
|
|
|
"mistral system",
|
|
|
[]template{
|
|
|
- {"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `},
|
|
|
- {"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
|
|
|
- {"messages", `
|
|
|
+ {"no response", `[INST] {{ if .System }}{{ .System }}
|
|
|
+
|
|
|
+{{ end }}{{ .Prompt }}[/INST] `},
|
|
|
+ {"response", `[INST] {{ if .System }}{{ .System }}
|
|
|
+
|
|
|
+{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
|
|
|
+ {"messages", `{{- $system := aggregate $.Messages "system" -}}
|
|
|
{{- range $index, $_ := .Messages }}
|
|
|
-{{- if eq .Role "user" }}[INST] {{ if and (eq $index 0) $.System }}{{ $.System }}{{ "\n\n" }}
|
|
|
-{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
|
|
|
+{{- if eq .Role "user" }}[INST] {{ if $system }}{{ $system }}
|
|
|
+{{- $system = "" }}
|
|
|
+
|
|
|
+{{ end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
|
|
|
{{- end }}
|
|
|
{{- end }}`},
|
|
|
},
|
|
@@ -256,12 +276,9 @@ Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
|
|
|
{{ .Response }}<|im_end|>
|
|
|
`},
|
|
|
{"messages", `
|
|
|
-{{- range $index, $_ := .Messages }}
|
|
|
-{{- if and (eq .Role "user") (eq $index 0) $.System }}<|im_start|>system
|
|
|
-{{ $.System }}<|im_end|>{{ "\n" }}
|
|
|
-{{- end }}<|im_start|>{{ .Role }}
|
|
|
-{{ .Content }}<|im_end|>{{ "\n" }}
|
|
|
-{{- end }}<|im_start|>assistant
|
|
|
+{{- range $index, $_ := .Messages }}<|im_start|>{{ .Role }}
|
|
|
+{{ .Content }}<|im_end|>
|
|
|
+{{ end }}<|im_start|>assistant
|
|
|
`},
|
|
|
},
|
|
|
Values{
|
|
@@ -294,9 +311,11 @@ What is your name?<|im_end|>
|
|
|
`},
|
|
|
{"messages", `
|
|
|
{{- range .Messages }}
|
|
|
-{{- if eq .Role "user" }}Question: {{ .Content }}{{ "\n\n" }}
|
|
|
-{{- else if eq .Role "assistant" }}Answer: {{ .Content }}{{ "\n\n" }}
|
|
|
-{{- end }}
|
|
|
+{{- if eq .Role "user" }}Question: {{ .Content }}
|
|
|
+
|
|
|
+{{ else if eq .Role "assistant" }}Answer: {{ .Content }}
|
|
|
+
|
|
|
+{{ end }}
|
|
|
{{- end }}Answer: `},
|
|
|
},
|
|
|
Values{
|