|
@@ -1,21 +1,98 @@
|
|
package server
|
|
package server
|
|
|
|
|
|
import (
|
|
import (
|
|
|
|
+ "strings"
|
|
"testing"
|
|
"testing"
|
|
|
|
+
|
|
|
|
+ "github.com/jmorganca/ollama/api"
|
|
)
|
|
)
|
|
|
|
|
|
-func TestModelPrompt(t *testing.T) {
|
|
|
|
- m := Model{
|
|
|
|
- Template: "a{{ .Prompt }}b",
|
|
|
|
- }
|
|
|
|
- s, err := m.Prompt(PromptVars{
|
|
|
|
- Prompt: "<h1>",
|
|
|
|
- })
|
|
|
|
- if err != nil {
|
|
|
|
- t.Fatal(err)
|
|
|
|
|
|
+func TestChat(t *testing.T) {
|
|
|
|
+ tests := []struct {
|
|
|
|
+ name string
|
|
|
|
+ template string
|
|
|
|
+ msgs []api.Message
|
|
|
|
+ want string
|
|
|
|
+ wantErr string
|
|
|
|
+ }{
|
|
|
|
+ {
|
|
|
|
+ name: "Single Message",
|
|
|
|
+ template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
|
|
|
+ msgs: []api.Message{
|
|
|
|
+ {
|
|
|
|
+ Role: "system",
|
|
|
|
+ Content: "You are a Wizard.",
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ Role: "user",
|
|
|
|
+ Content: "What are the potion ingredients?",
|
|
|
|
+ },
|
|
|
|
+ },
|
|
|
|
+ want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ name: "Message History",
|
|
|
|
+ template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
|
|
|
+ msgs: []api.Message{
|
|
|
|
+ {
|
|
|
|
+ Role: "system",
|
|
|
|
+ Content: "You are a Wizard.",
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ Role: "user",
|
|
|
|
+ Content: "What are the potion ingredients?",
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ Role: "assistant",
|
|
|
|
+ Content: "sugar",
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ Role: "user",
|
|
|
|
+ Content: "Anything else?",
|
|
|
|
+ },
|
|
|
|
+ },
|
|
|
|
+ want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]sugar[INST] Anything else? [/INST]",
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ name: "Assistant Only",
|
|
|
|
+ template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
|
|
|
+ msgs: []api.Message{
|
|
|
|
+ {
|
|
|
|
+ Role: "assistant",
|
|
|
|
+ Content: "everything nice",
|
|
|
|
+ },
|
|
|
|
+ },
|
|
|
|
+ want: "[INST] [/INST]everything nice",
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ name: "Invalid Role",
|
|
|
|
+ msgs: []api.Message{
|
|
|
|
+ {
|
|
|
|
+ Role: "not-a-role",
|
|
|
|
+ Content: "howdy",
|
|
|
|
+ },
|
|
|
|
+ },
|
|
|
|
+ wantErr: "invalid role: not-a-role",
|
|
|
|
+ },
|
|
}
|
|
}
|
|
- want := "a<h1>b"
|
|
|
|
- if s != want {
|
|
|
|
- t.Errorf("got %q, want %q", s, want)
|
|
|
|
|
|
+
|
|
|
|
+ for _, tt := range tests {
|
|
|
|
+ m := Model{
|
|
|
|
+ Template: tt.template,
|
|
|
|
+ }
|
|
|
|
+ t.Run(tt.name, func(t *testing.T) {
|
|
|
|
+ got, err := m.ChatPrompt(tt.msgs)
|
|
|
|
+ if tt.wantErr != "" {
|
|
|
|
+ if err == nil {
|
|
|
|
+ t.Errorf("ChatPrompt() expected error, got nil")
|
|
|
|
+ }
|
|
|
|
+ if !strings.Contains(err.Error(), tt.wantErr) {
|
|
|
|
+ t.Errorf("ChatPrompt() error = %v, wantErr %v", err, tt.wantErr)
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ if got != tt.want {
|
|
|
|
+ t.Errorf("ChatPrompt() got = %v, want %v", got, tt.want)
|
|
|
|
+ }
|
|
|
|
+ })
|
|
}
|
|
}
|
|
}
|
|
}
|