|
@@ -256,15 +256,17 @@ func chatHistoryEqual(a, b ChatHistory) bool {
|
|
|
|
|
|
func TestChat(t *testing.T) {
|
|
|
tests := []struct {
|
|
|
- name string
|
|
|
- template string
|
|
|
- msgs []api.Message
|
|
|
- want ChatHistory
|
|
|
- wantErr string
|
|
|
+ name string
|
|
|
+ model Model
|
|
|
+ msgs []api.Message
|
|
|
+ want ChatHistory
|
|
|
+ wantErr string
|
|
|
}{
|
|
|
{
|
|
|
- name: "Single Message",
|
|
|
- template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
|
|
+ name: "Single Message",
|
|
|
+ model: Model{
|
|
|
+ Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
|
|
+ },
|
|
|
msgs: []api.Message{
|
|
|
{
|
|
|
Role: "system",
|
|
@@ -287,8 +289,10 @@ func TestChat(t *testing.T) {
|
|
|
},
|
|
|
},
|
|
|
{
|
|
|
- name: "Message History",
|
|
|
- template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
|
|
+ name: "Message History",
|
|
|
+ model: Model{
|
|
|
+ Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
|
|
+ },
|
|
|
msgs: []api.Message{
|
|
|
{
|
|
|
Role: "system",
|
|
@@ -323,8 +327,10 @@ func TestChat(t *testing.T) {
|
|
|
},
|
|
|
},
|
|
|
{
|
|
|
- name: "Assistant Only",
|
|
|
- template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
|
|
+ name: "Assistant Only",
|
|
|
+ model: Model{
|
|
|
+ Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
|
|
+ },
|
|
|
msgs: []api.Message{
|
|
|
{
|
|
|
Role: "assistant",
|
|
@@ -340,6 +346,51 @@ func TestChat(t *testing.T) {
|
|
|
},
|
|
|
},
|
|
|
},
|
|
|
+ {
|
|
|
+ name: "Last system message is preserved from modelfile",
|
|
|
+ model: Model{
|
|
|
+ Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
|
|
+ System: "You are Mojo Jojo.",
|
|
|
+ },
|
|
|
+ msgs: []api.Message{
|
|
|
+ {
|
|
|
+ Role: "user",
|
|
|
+ Content: "hi",
|
|
|
+ },
|
|
|
+ },
|
|
|
+ want: ChatHistory{
|
|
|
+ Prompts: []PromptVars{
|
|
|
+ {
|
|
|
+ System: "You are Mojo Jojo.",
|
|
|
+ Prompt: "hi",
|
|
|
+ First: true,
|
|
|
+ },
|
|
|
+ },
|
|
|
+ LastSystem: "You are Mojo Jojo.",
|
|
|
+ },
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "Last system message is preserved from messages",
|
|
|
+ model: Model{
|
|
|
+ Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
|
|
+ System: "You are Mojo Jojo.",
|
|
|
+ },
|
|
|
+ msgs: []api.Message{
|
|
|
+ {
|
|
|
+ Role: "system",
|
|
|
+ Content: "You are Professor Utonium.",
|
|
|
+ },
|
|
|
+ },
|
|
|
+ want: ChatHistory{
|
|
|
+ Prompts: []PromptVars{
|
|
|
+ {
|
|
|
+ System: "You are Professor Utonium.",
|
|
|
+ First: true,
|
|
|
+ },
|
|
|
+ },
|
|
|
+ LastSystem: "You are Professor Utonium.",
|
|
|
+ },
|
|
|
+ },
|
|
|
{
|
|
|
name: "Invalid Role",
|
|
|
msgs: []api.Message{
|
|
@@ -353,11 +404,8 @@ func TestChat(t *testing.T) {
|
|
|
}
|
|
|
|
|
|
for _, tt := range tests {
|
|
|
- m := Model{
|
|
|
- Template: tt.template,
|
|
|
- }
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
- got, err := m.ChatPrompts(tt.msgs)
|
|
|
+ got, err := tt.model.ChatPrompts(tt.msgs)
|
|
|
if tt.wantErr != "" {
|
|
|
if err == nil {
|
|
|
t.Errorf("ChatPrompt() expected error, got nil")
|