瀏覽代碼

preserve last system message from modelfile (#2289)

Bruce MacDonald 1 年之前
父節點
當前提交
a896079705
共有 2 個文件被更改,包括 66 次插入17 次删除
  1. 3 2
      server/images.go
  2. 63 15
      server/images_test.go

+ 3 - 2
server/images.go

@@ -156,7 +156,7 @@ type ChatHistory struct {
 func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
 func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
 	// build the prompt from the list of messages
 	// build the prompt from the list of messages
 	var currentImages []api.ImageData
 	var currentImages []api.ImageData
-	var lastSystem string
+	lastSystem := m.System
 	currentVars := PromptVars{
 	currentVars := PromptVars{
 		First:  true,
 		First:  true,
 		System: m.System,
 		System: m.System,
@@ -167,7 +167,8 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
 	for _, msg := range msgs {
 	for _, msg := range msgs {
 		switch strings.ToLower(msg.Role) {
 		switch strings.ToLower(msg.Role) {
 		case "system":
 		case "system":
-			if currentVars.System != "" {
+			// if this is the first message it overrides the system prompt in the modelfile
+			if !currentVars.First && currentVars.System != "" {
 				prompts = append(prompts, currentVars)
 				prompts = append(prompts, currentVars)
 				currentVars = PromptVars{}
 				currentVars = PromptVars{}
 			}
 			}

+ 63 - 15
server/images_test.go

@@ -256,15 +256,17 @@ func chatHistoryEqual(a, b ChatHistory) bool {
 
 
 func TestChat(t *testing.T) {
 func TestChat(t *testing.T) {
 	tests := []struct {
 	tests := []struct {
-		name     string
-		template string
-		msgs     []api.Message
-		want     ChatHistory
-		wantErr  string
+		name    string
+		model   Model
+		msgs    []api.Message
+		want    ChatHistory
+		wantErr string
 	}{
 	}{
 		{
 		{
-			name:     "Single Message",
-			template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
+			name: "Single Message",
+			model: Model{
+				Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
+			},
 			msgs: []api.Message{
 			msgs: []api.Message{
 				{
 				{
 					Role:    "system",
 					Role:    "system",
@@ -287,8 +289,10 @@ func TestChat(t *testing.T) {
 			},
 			},
 		},
 		},
 		{
 		{
-			name:     "Message History",
-			template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
+			name: "Message History",
+			model: Model{
+				Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
+			},
 			msgs: []api.Message{
 			msgs: []api.Message{
 				{
 				{
 					Role:    "system",
 					Role:    "system",
@@ -323,8 +327,10 @@ func TestChat(t *testing.T) {
 			},
 			},
 		},
 		},
 		{
 		{
-			name:     "Assistant Only",
-			template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
+			name: "Assistant Only",
+			model: Model{
+				Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
+			},
 			msgs: []api.Message{
 			msgs: []api.Message{
 				{
 				{
 					Role:    "assistant",
 					Role:    "assistant",
@@ -340,6 +346,51 @@ func TestChat(t *testing.T) {
 				},
 				},
 			},
 			},
 		},
 		},
+		{
+			name: "Last system message is preserved from modelfile",
+			model: Model{
+				Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
+				System:   "You are Mojo Jojo.",
+			},
+			msgs: []api.Message{
+				{
+					Role:    "user",
+					Content: "hi",
+				},
+			},
+			want: ChatHistory{
+				Prompts: []PromptVars{
+					{
+						System: "You are Mojo Jojo.",
+						Prompt: "hi",
+						First:  true,
+					},
+				},
+				LastSystem: "You are Mojo Jojo.",
+			},
+		},
+		{
+			name: "Last system message is preserved from messages",
+			model: Model{
+				Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
+				System:   "You are Mojo Jojo.",
+			},
+			msgs: []api.Message{
+				{
+					Role:    "system",
+					Content: "You are Professor Utonium.",
+				},
+			},
+			want: ChatHistory{
+				Prompts: []PromptVars{
+					{
+						System: "You are Professor Utonium.",
+						First:  true,
+					},
+				},
+				LastSystem: "You are Professor Utonium.",
+			},
+		},
 		{
 		{
 			name: "Invalid Role",
 			name: "Invalid Role",
 			msgs: []api.Message{
 			msgs: []api.Message{
@@ -353,11 +404,8 @@ func TestChat(t *testing.T) {
 	}
 	}
 
 
 	for _, tt := range tests {
 	for _, tt := range tests {
-		m := Model{
-			Template: tt.template,
-		}
 		t.Run(tt.name, func(t *testing.T) {
 		t.Run(tt.name, func(t *testing.T) {
-			got, err := m.ChatPrompts(tt.msgs)
+			got, err := tt.model.ChatPrompts(tt.msgs)
 			if tt.wantErr != "" {
 			if tt.wantErr != "" {
 				if err == nil {
 				if err == nil {
 					t.Errorf("ChatPrompt() expected error, got nil")
 					t.Errorf("ChatPrompt() expected error, got nil")