|
@@ -8,9 +8,10 @@ import (
|
|
|
"os"
|
|
|
"path/filepath"
|
|
|
"slices"
|
|
|
+ "strings"
|
|
|
"testing"
|
|
|
- "text/template"
|
|
|
|
|
|
+ "github.com/google/go-cmp/cmp"
|
|
|
"github.com/ollama/ollama/api"
|
|
|
"github.com/ollama/ollama/llm"
|
|
|
)
|
|
@@ -47,7 +48,7 @@ func TestNamed(t *testing.T) {
|
|
|
t.Fatal(err)
|
|
|
}
|
|
|
|
|
|
- tmpl, err := template.New(s).Parse(b.String())
|
|
|
+ tmpl, err := Parse(b.String())
|
|
|
if err != nil {
|
|
|
t.Fatal(err)
|
|
|
}
|
|
@@ -60,6 +61,70 @@ func TestNamed(t *testing.T) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+func TestTemplate(t *testing.T) {
|
|
|
+ cases := make(map[string][]api.Message)
|
|
|
+ for _, mm := range [][]api.Message{
|
|
|
+ {
|
|
|
+ {Role: "user", Content: "Hello, how are you?"},
|
|
|
+ },
|
|
|
+ {
|
|
|
+ {Role: "user", Content: "Hello, how are you?"},
|
|
|
+ {Role: "assistant", Content: "I'm doing great. How can I help you today?"},
|
|
|
+ {Role: "user", Content: "I'd like to show off how chat templating works!"},
|
|
|
+ },
|
|
|
+ {
|
|
|
+ {Role: "system", Content: "You are a helpful assistant."},
|
|
|
+ {Role: "user", Content: "Hello, how are you?"},
|
|
|
+ {Role: "assistant", Content: "I'm doing great. How can I help you today?"},
|
|
|
+ {Role: "user", Content: "I'd like to show off how chat templating works!"},
|
|
|
+ },
|
|
|
+ } {
|
|
|
+ var roles []string
|
|
|
+ for _, m := range mm {
|
|
|
+ roles = append(roles, m.Role)
|
|
|
+ }
|
|
|
+
|
|
|
+ cases[strings.Join(roles, "-")] = mm
|
|
|
+ }
|
|
|
+
|
|
|
+ matches, err := filepath.Glob("*.gotmpl")
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ for _, match := range matches {
|
|
|
+ t.Run(match, func(t *testing.T) {
|
|
|
+ bts, err := os.ReadFile(match)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ tmpl, err := Parse(string(bts))
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ for n, tt := range cases {
|
|
|
+ t.Run(n, func(t *testing.T) {
|
|
|
+ var actual bytes.Buffer
|
|
|
+ if err := tmpl.Execute(&actual, Values{Messages: tt}); err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ expect, err := os.ReadFile(filepath.Join("testdata", match, n))
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ if diff := cmp.Diff(actual.Bytes(), expect); diff != "" {
|
|
|
+ t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
|
+ }
|
|
|
+ })
|
|
|
+ }
|
|
|
+ })
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
func TestParse(t *testing.T) {
|
|
|
cases := []struct {
|
|
|
template string
|