Explorar o código

fix: only flush template in chat when current role encountered (#1426)

Bruce MacDonald hai 1 ano
pai
achega
3b0b8930d4
Modificáronse 2 ficheiros con 92 adicións e 15 borrados
  1. 3 3
      server/images.go
  2. 89 12
      server/images_test.go

+ 3 - 3
server/images.go

@@ -103,16 +103,16 @@ func (m *Model) ChatPrompt(msgs []api.Message) (string, error) {
 	}
 	}
 
 
 	for _, msg := range msgs {
 	for _, msg := range msgs {
-		switch msg.Role {
+		switch strings.ToLower(msg.Role) {
 		case "system":
 		case "system":
-			if currentVars.Prompt != "" || currentVars.System != "" {
+			if currentVars.System != "" {
 				if err := writePrompt(); err != nil {
 				if err := writePrompt(); err != nil {
 					return "", err
 					return "", err
 				}
 				}
 			}
 			}
 			currentVars.System = msg.Content
 			currentVars.System = msg.Content
 		case "user":
 		case "user":
-			if currentVars.Prompt != "" || currentVars.System != "" {
+			if currentVars.Prompt != "" {
 				if err := writePrompt(); err != nil {
 				if err := writePrompt(); err != nil {
 					return "", err
 					return "", err
 				}
 				}

+ 89 - 12
server/images_test.go

@@ -1,21 +1,98 @@
 package server
 package server
 
 
 import (
 import (
+	"strings"
 	"testing"
 	"testing"
+
+	"github.com/jmorganca/ollama/api"
 )
 )
 
 
-func TestModelPrompt(t *testing.T) {
-	m := Model{
-		Template: "a{{ .Prompt }}b",
-	}
-	s, err := m.Prompt(PromptVars{
-		Prompt: "<h1>",
-	})
-	if err != nil {
-		t.Fatal(err)
+func TestChat(t *testing.T) {
+	tests := []struct {
+		name     string
+		template string
+		msgs     []api.Message
+		want     string
+		wantErr  string
+	}{
+		{
+			name:     "Single Message",
+			template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
+			msgs: []api.Message{
+				{
+					Role:    "system",
+					Content: "You are a Wizard.",
+				},
+				{
+					Role:    "user",
+					Content: "What are the potion ingredients?",
+				},
+			},
+			want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
+		},
+		{
+			name:     "Message History",
+			template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
+			msgs: []api.Message{
+				{
+					Role:    "system",
+					Content: "You are a Wizard.",
+				},
+				{
+					Role:    "user",
+					Content: "What are the potion ingredients?",
+				},
+				{
+					Role:    "assistant",
+					Content: "sugar",
+				},
+				{
+					Role:    "user",
+					Content: "Anything else?",
+				},
+			},
+			want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]sugar[INST]  Anything else? [/INST]",
+		},
+		{
+			name:     "Assistant Only",
+			template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
+			msgs: []api.Message{
+				{
+					Role:    "assistant",
+					Content: "everything nice",
+				},
+			},
+			want: "[INST]   [/INST]everything nice",
+		},
+		{
+			name: "Invalid Role",
+			msgs: []api.Message{
+				{
+					Role:    "not-a-role",
+					Content: "howdy",
+				},
+			},
+			wantErr: "invalid role: not-a-role",
+		},
 	}
 	}
-	want := "a<h1>b"
-	if s != want {
-		t.Errorf("got %q, want %q", s, want)
+
+	for _, tt := range tests {
+		m := Model{
+			Template: tt.template,
+		}
+		t.Run(tt.name, func(t *testing.T) {
+			got, err := m.ChatPrompt(tt.msgs)
+			if tt.wantErr != "" {
+				if err == nil {
+					t.Errorf("ChatPrompt() expected error, got nil")
+				}
+				if !strings.Contains(err.Error(), tt.wantErr) {
+					t.Errorf("ChatPrompt() error = %v, wantErr %v", err, tt.wantErr)
+				}
+			}
+			if got != tt.want {
+				t.Errorf("ChatPrompt() got = %v, want %v", got, tt.want)
+			}
+		})
 	}
 	}
 }
 }