Ver Fonte

fix: chat system prompting overrides (#2542)

Bruce MacDonald há 1 ano atrás
pai
commit
88622847c6
4 ficheiros alterados com 24 adições e 41 exclusões
  1. 9 2
      cmd/interactive.go
  2. 1 6
      server/prompt.go
  3. 1 30
      server/prompt_test.go
  4. 13 3
      server/routes.go

+ 9 - 2
cmd/interactive.go

@@ -354,8 +354,15 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 					}
 
 					if args[1] == "system" {
-						opts.System = sb.String()
-						opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System})
+						opts.System = sb.String() // for display in modelfile
+						newMessage := api.Message{Role: "system", Content: sb.String()}
+						// Check if the slice is not empty and the last message is from 'system'
+						if len(opts.Messages) > 0 && opts.Messages[len(opts.Messages)-1].Role == "system" {
+							// Replace the last message
+							opts.Messages[len(opts.Messages)-1] = newMessage
+						} else {
+							opts.Messages = append(opts.Messages, newMessage)
+						}
 						fmt.Println("Set system message.")
 						sb.Reset()
 					} else if args[1] == "template" {

+ 1 - 6
server/prompt.go

@@ -91,7 +91,7 @@ func countTokens(tmpl string, system string, prompt string, response string, enc
 }
 
 // ChatPrompt builds up a prompt from a series of messages, truncating based on context window size
-func ChatPrompt(tmpl string, system string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
+func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
 	type prompt struct {
 		System   string
 		Prompt   string
@@ -103,11 +103,6 @@ func ChatPrompt(tmpl string, system string, messages []api.Message, window int,
 
 	var p prompt
 
-	// Set the first system prompt to the model's system prompt
-	if system != "" {
-		p.System = system
-	}
-
 	// iterate through messages to build up {system,user,response} prompts
 	var imgId int
 	var prompts []prompt

+ 1 - 30
server/prompt_test.go

@@ -77,7 +77,6 @@ func TestChatPrompt(t *testing.T) {
 	tests := []struct {
 		name     string
 		template string
-		system   string
 		messages []api.Message
 		window   int
 		want     string
@@ -91,16 +90,6 @@ func TestChatPrompt(t *testing.T) {
 			window: 1024,
 			want:   "[INST] Hello [/INST]",
 		},
-		{
-			name:     "with default system message",
-			system:   "You are a Wizard.",
-			template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
-			messages: []api.Message{
-				{Role: "user", Content: "Hello"},
-			},
-			window: 1024,
-			want:   "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]",
-		},
 		{
 			name:     "with system message",
 			template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
@@ -185,24 +174,6 @@ func TestChatPrompt(t *testing.T) {
 			window:   1024,
 			want:     "",
 		},
-		{
-			name:     "empty list default system",
-			system:   "You are a Wizard.",
-			template: "{{ .System }} {{ .Prompt }}",
-			messages: []api.Message{},
-			window:   1024,
-			want:     "You are a Wizard. ",
-		},
-		{
-			name:     "empty user message",
-			system:   "You are a Wizard.",
-			template: "{{ .System }} {{ .Prompt }}",
-			messages: []api.Message{
-				{Role: "user", Content: ""},
-			},
-			window: 1024,
-			want:   "You are a Wizard. ",
-		},
 		{
 			name:     "empty prompt",
 			template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ",
@@ -221,7 +192,7 @@ func TestChatPrompt(t *testing.T) {
 
 	for _, tc := range tests {
 		t.Run(tc.name, func(t *testing.T) {
-			got, err := ChatPrompt(tc.template, tc.system, tc.messages, tc.window, encode)
+			got, err := ChatPrompt(tc.template, tc.messages, tc.window, encode)
 			if err != nil {
 				t.Errorf("error = %v", err)
 			}

+ 13 - 3
server/routes.go

@@ -1092,12 +1092,12 @@ func streamResponse(c *gin.Context, ch chan any) {
 }
 
 // chatPrompt builds up a prompt from a series of messages for the currently `loaded` model
-func chatPrompt(ctx context.Context, messages []api.Message) (string, error) {
+func chatPrompt(ctx context.Context, template string, messages []api.Message, numCtx int) (string, error) {
 	encode := func(s string) ([]int, error) {
 		return loaded.runner.Encode(ctx, s)
 	}
 
-	prompt, err := ChatPrompt(loaded.Model.Template, loaded.Model.System, messages, loaded.Options.NumCtx, encode)
+	prompt, err := ChatPrompt(template, messages, numCtx, encode)
 	if err != nil {
 		return "", err
 	}
@@ -1167,7 +1167,17 @@ func ChatHandler(c *gin.Context) {
 
 	checkpointLoaded := time.Now()
 
-	prompt, err := chatPrompt(c.Request.Context(), req.Messages)
+	// if the first message is not a system message, then add the model's default system message
+	if len(req.Messages) > 0 && req.Messages[0].Role != "system" {
+		req.Messages = append([]api.Message{
+			{
+				Role:    "system",
+				Content: model.System,
+			},
+		}, req.Messages...)
+	}
+
+	prompt, err := chatPrompt(c.Request.Context(), model.Template, req.Messages, opts.NumCtx)
 	if err != nil {
 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return