1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798 |
- package server
- import (
- "strings"
- "testing"
- "github.com/jmorganca/ollama/api"
- )
- func TestChat(t *testing.T) {
- tests := []struct {
- name string
- template string
- msgs []api.Message
- want string
- wantErr string
- }{
- {
- name: "Single Message",
- template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
- msgs: []api.Message{
- {
- Role: "system",
- Content: "You are a Wizard.",
- },
- {
- Role: "user",
- Content: "What are the potion ingredients?",
- },
- },
- want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
- },
- {
- name: "Message History",
- template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
- msgs: []api.Message{
- {
- Role: "system",
- Content: "You are a Wizard.",
- },
- {
- Role: "user",
- Content: "What are the potion ingredients?",
- },
- {
- Role: "assistant",
- Content: "sugar",
- },
- {
- Role: "user",
- Content: "Anything else?",
- },
- },
- want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]sugar[INST] Anything else? [/INST]",
- },
- {
- name: "Assistant Only",
- template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
- msgs: []api.Message{
- {
- Role: "assistant",
- Content: "everything nice",
- },
- },
- want: "[INST] [/INST]everything nice",
- },
- {
- name: "Invalid Role",
- msgs: []api.Message{
- {
- Role: "not-a-role",
- Content: "howdy",
- },
- },
- wantErr: "invalid role: not-a-role",
- },
- }
- for _, tt := range tests {
- m := Model{
- Template: tt.template,
- }
- t.Run(tt.name, func(t *testing.T) {
- got, err := m.ChatPrompt(tt.msgs)
- if tt.wantErr != "" {
- if err == nil {
- t.Errorf("ChatPrompt() expected error, got nil")
- }
- if !strings.Contains(err.Error(), tt.wantErr) {
- t.Errorf("ChatPrompt() error = %v, wantErr %v", err, tt.wantErr)
- }
- }
- if got != tt.want {
- t.Errorf("ChatPrompt() got = %v, want %v", got, tt.want)
- }
- })
- }
- }
|