123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442 |
- package server
- import (
- "bytes"
- "strings"
- "testing"
- "github.com/jmorganca/ollama/api"
- )
- func TestPrompt(t *testing.T) {
- tests := []struct {
- name string
- template string
- vars PromptVars
- want string
- wantErr bool
- }{
- {
- name: "System Prompt",
- template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
- vars: PromptVars{
- System: "You are a Wizard.",
- Prompt: "What are the potion ingredients?",
- },
- want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
- },
- {
- name: "System Prompt with Response",
- template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
- vars: PromptVars{
- System: "You are a Wizard.",
- Prompt: "What are the potion ingredients?",
- Response: "I don't know.",
- },
- want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
- },
- {
- name: "Conditional Logic Nodes",
- template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
- vars: PromptVars{
- First: true,
- System: "You are a Wizard.",
- Prompt: "What are the potion ingredients?",
- Response: "I don't know.",
- },
- want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got, err := Prompt(tt.template, tt.vars)
- if (err != nil) != tt.wantErr {
- t.Errorf("Prompt() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if got != tt.want {
- t.Errorf("Prompt() got = %v, want %v", got, tt.want)
- }
- })
- }
- }
- func TestModel_PreResponsePrompt(t *testing.T) {
- tests := []struct {
- name string
- template string
- vars PromptVars
- want string
- wantErr bool
- }{
- {
- name: "No Response in Template",
- template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
- vars: PromptVars{
- System: "You are a Wizard.",
- Prompt: "What are the potion ingredients?",
- },
- want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
- },
- {
- name: "Response in Template",
- template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
- vars: PromptVars{
- System: "You are a Wizard.",
- Prompt: "What are the potion ingredients?",
- },
- want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] ",
- },
- {
- name: "Response in Template with Trailing Formatting",
- template: "<|im_start|>user\n{{ .Prompt }}<|im_end|><|im_start|>assistant\n{{ .Response }}<|im_end|>",
- vars: PromptVars{
- Prompt: "What are the potion ingredients?",
- },
- want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\n",
- },
- {
- name: "Response in Template with Alternative Formatting",
- template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>",
- vars: PromptVars{
- Prompt: "What are the potion ingredients?",
- },
- want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\n",
- },
- }
- for _, tt := range tests {
- m := Model{Template: tt.template}
- t.Run(tt.name, func(t *testing.T) {
- got, err := m.PreResponsePrompt(tt.vars)
- if (err != nil) != tt.wantErr {
- t.Errorf("PreResponsePrompt() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if got != tt.want {
- t.Errorf("PreResponsePrompt() got = %v, want %v", got, tt.want)
- }
- })
- }
- }
- func TestModel_PostResponsePrompt(t *testing.T) {
- tests := []struct {
- name string
- template string
- vars PromptVars
- want string
- wantErr bool
- }{
- {
- name: "No Response in Template",
- template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
- vars: PromptVars{
- Response: "I don't know.",
- },
- want: "I don't know.",
- },
- {
- name: "Response in Template",
- template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
- vars: PromptVars{
- Response: "I don't know.",
- },
- want: "I don't know.",
- },
- {
- name: "Response in Template with Trailing Formatting",
- template: "<|im_start|>user\n{{ .Prompt }}<|im_end|><|im_start|>assistant\n{{ .Response }}<|im_end|>",
- vars: PromptVars{
- Response: "I don't know.",
- },
- want: "I don't know.<|im_end|>",
- },
- {
- name: "Response in Template with Alternative Formatting",
- template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>",
- vars: PromptVars{
- Response: "I don't know.",
- },
- want: "I don't know.<|im_end|>",
- },
- }
- for _, tt := range tests {
- m := Model{Template: tt.template}
- t.Run(tt.name, func(t *testing.T) {
- got, err := m.PostResponseTemplate(tt.vars)
- if (err != nil) != tt.wantErr {
- t.Errorf("PostResponseTemplate() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if got != tt.want {
- t.Errorf("PostResponseTemplate() got = %v, want %v", got, tt.want)
- }
- })
- }
- }
- func TestModel_PreResponsePrompt_PostResponsePrompt(t *testing.T) {
- tests := []struct {
- name string
- template string
- preVars PromptVars
- postVars PromptVars
- want string
- wantErr bool
- }{
- {
- name: "Response in Template",
- template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>",
- preVars: PromptVars{
- Prompt: "What are the potion ingredients?",
- },
- postVars: PromptVars{
- Prompt: "What are the potion ingredients?",
- Response: "Sugar.",
- },
- want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\nSugar.<|im_end|>",
- },
- {
- name: "No Response in Template",
- template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n",
- preVars: PromptVars{
- Prompt: "What are the potion ingredients?",
- },
- postVars: PromptVars{
- Prompt: "What are the potion ingredients?",
- Response: "Spice.",
- },
- want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\nSpice.",
- },
- }
- for _, tt := range tests {
- m := Model{Template: tt.template}
- t.Run(tt.name, func(t *testing.T) {
- pre, err := m.PreResponsePrompt(tt.preVars)
- if (err != nil) != tt.wantErr {
- t.Errorf("PreResponsePrompt() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- post, err := m.PostResponseTemplate(tt.postVars)
- if err != nil {
- t.Errorf("PostResponseTemplate() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- result := pre + post
- if result != tt.want {
- t.Errorf("Prompt() got = %v, want %v", result, tt.want)
- }
- })
- }
- }
- func chatHistoryEqual(a, b ChatHistory) bool {
- if len(a.Prompts) != len(b.Prompts) {
- return false
- }
- for i, v := range a.Prompts {
- if v.First != b.Prompts[i].First {
- return false
- }
- if v.Response != b.Prompts[i].Response {
- return false
- }
- if v.Prompt != b.Prompts[i].Prompt {
- return false
- }
- if v.System != b.Prompts[i].System {
- return false
- }
- if len(v.Images) != len(b.Prompts[i].Images) {
- return false
- }
- for j, img := range v.Images {
- if img.ID != b.Prompts[i].Images[j].ID {
- return false
- }
- if !bytes.Equal(img.Data, b.Prompts[i].Images[j].Data) {
- return false
- }
- }
- }
- return a.LastSystem == b.LastSystem
- }
- func TestChat(t *testing.T) {
- tests := []struct {
- name string
- model Model
- msgs []api.Message
- want ChatHistory
- wantErr string
- }{
- {
- name: "Single Message",
- model: Model{
- Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
- },
- msgs: []api.Message{
- {
- Role: "system",
- Content: "You are a Wizard.",
- },
- {
- Role: "user",
- Content: "What are the potion ingredients?",
- },
- },
- want: ChatHistory{
- Prompts: []PromptVars{
- {
- System: "You are a Wizard.",
- Prompt: "What are the potion ingredients?",
- First: true,
- },
- },
- LastSystem: "You are a Wizard.",
- },
- },
- {
- name: "Message History",
- model: Model{
- Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
- },
- msgs: []api.Message{
- {
- Role: "system",
- Content: "You are a Wizard.",
- },
- {
- Role: "user",
- Content: "What are the potion ingredients?",
- },
- {
- Role: "assistant",
- Content: "sugar",
- },
- {
- Role: "user",
- Content: "Anything else?",
- },
- },
- want: ChatHistory{
- Prompts: []PromptVars{
- {
- System: "You are a Wizard.",
- Prompt: "What are the potion ingredients?",
- Response: "sugar",
- First: true,
- },
- {
- Prompt: "Anything else?",
- },
- },
- LastSystem: "You are a Wizard.",
- },
- },
- {
- name: "Assistant Only",
- model: Model{
- Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
- },
- msgs: []api.Message{
- {
- Role: "assistant",
- Content: "everything nice",
- },
- },
- want: ChatHistory{
- Prompts: []PromptVars{
- {
- Response: "everything nice",
- First: true,
- },
- },
- },
- },
- {
- name: "Last system message is preserved from modelfile",
- model: Model{
- Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
- System: "You are Mojo Jojo.",
- },
- msgs: []api.Message{
- {
- Role: "user",
- Content: "hi",
- },
- },
- want: ChatHistory{
- Prompts: []PromptVars{
- {
- System: "You are Mojo Jojo.",
- Prompt: "hi",
- First: true,
- },
- },
- LastSystem: "You are Mojo Jojo.",
- },
- },
- {
- name: "Last system message is preserved from messages",
- model: Model{
- Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
- System: "You are Mojo Jojo.",
- },
- msgs: []api.Message{
- {
- Role: "system",
- Content: "You are Professor Utonium.",
- },
- },
- want: ChatHistory{
- Prompts: []PromptVars{
- {
- System: "You are Professor Utonium.",
- First: true,
- },
- },
- LastSystem: "You are Professor Utonium.",
- },
- },
- {
- name: "Invalid Role",
- msgs: []api.Message{
- {
- Role: "not-a-role",
- Content: "howdy",
- },
- },
- wantErr: "invalid role: not-a-role",
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got, err := tt.model.ChatPrompts(tt.msgs)
- if tt.wantErr != "" {
- if err == nil {
- t.Errorf("ChatPrompt() expected error, got nil")
- }
- if !strings.Contains(err.Error(), tt.wantErr) {
- t.Errorf("ChatPrompt() error = %v, wantErr %v", err, tt.wantErr)
- }
- return
- }
- if !chatHistoryEqual(*got, tt.want) {
- t.Errorf("ChatPrompt() got = %#v, want %#v", got, tt.want)
- }
- })
- }
- }
|