Browse Source

post-response templating (#1427)

Bruce MacDonald 1 năm trước cách đây
mục cha
commit
db356c8519
3 tập tin đã thay đổi với 333 bổ sung15 xóa
  1. 72 12
      server/images.go
  2. 249 0
      server/images_test.go
  3. 12 3
      server/routes.go

+ 72 - 12
server/images.go

@@ -18,6 +18,7 @@ import (
 	"strconv"
 	"strings"
 	"text/template"
+	"text/template/parse"
 
 	"golang.org/x/exp/slices"
 
@@ -57,19 +58,37 @@ type PromptVars struct {
 	First    bool
 }
 
-func (m *Model) Prompt(p PromptVars) (string, error) {
+// extractParts extracts the parts of the template before and after the {{.Response}} node.
+func extractParts(tmplStr string) (pre string, post string, err error) {
+	tmpl, err := template.New("").Parse(tmplStr)
+	if err != nil {
+		return "", "", err
+	}
+
+	var foundResponse bool
+
+	for _, node := range tmpl.Tree.Root.Nodes {
+		if node.Type() == parse.NodeAction && node.String() == "{{.Response}}" {
+			foundResponse = true
+		}
+		if !foundResponse {
+			pre += node.String()
+		} else {
+			post += node.String()
+		}
+	}
+
+	return pre, post, nil
+}
+
+func Prompt(promptTemplate string, p PromptVars) (string, error) {
 	var prompt strings.Builder
 	// Use the "missingkey=zero" option to handle missing variables without panicking
-	tmpl, err := template.New("").Option("missingkey=zero").Parse(m.Template)
+	tmpl, err := template.New("").Option("missingkey=zero").Parse(promptTemplate)
 	if err != nil {
 		return "", err
 	}
 
-	if p.System == "" {
-		// use the default system message for this model if one is not specified
-		p.System = m.System
-	}
-
 	vars := map[string]any{
 		"System":   p.System,
 		"Prompt":   p.Prompt,
@@ -82,20 +101,59 @@ func (m *Model) Prompt(p PromptVars) (string, error) {
 		return "", err
 	}
 	prompt.WriteString(sb.String())
-	prompt.WriteString(p.Response)
+
+	if !strings.Contains(prompt.String(), p.Response) {
+		// if the response is not in the prompt template, append it to the end
+		prompt.WriteString(p.Response)
+	}
+
 	return prompt.String(), nil
 }
 
+// PreResponsePrompt returns the prompt before the response tag
+func (m *Model) PreResponsePrompt(p PromptVars) (string, error) {
+	if p.System == "" {
+		// use the default system prompt for this model if one is not specified
+		p.System = m.System
+	}
+	pre, _, err := extractParts(m.Template)
+	if err != nil {
+		return "", err
+	}
+
+	return Prompt(pre, p)
+}
+
+// PostResponseTemplate returns the template after the response tag
+func (m *Model) PostResponseTemplate(p PromptVars) (string, error) {
+	if p.System == "" {
+		// use the default system prompt for this model if one is not specified
+		p.System = m.System
+	}
+	_, post, err := extractParts(m.Template)
+	if err != nil {
+		return "", err
+	}
+
+	if post == "" {
+		// if there is no post-response template, return the provided response
+		return p.Response, nil
+	}
+
+	return Prompt(post, p)
+}
+
 func (m *Model) ChatPrompt(msgs []api.Message) (string, []api.ImageData, error) {
 	// build the prompt from the list of messages
 	var prompt strings.Builder
 	var currentImages []api.ImageData
 	currentVars := PromptVars{
-		First: true,
+		First:  true,
+		System: m.System,
 	}
 
 	writePrompt := func() error {
-		p, err := m.Prompt(currentVars)
+		p, err := Prompt(m.Template, currentVars)
 		if err != nil {
 			return err
 		}
@@ -133,9 +191,11 @@ func (m *Model) ChatPrompt(msgs []api.Message) (string, []api.ImageData, error)
 
 	// Append the last set of vars if they are non-empty
 	if currentVars.Prompt != "" || currentVars.System != "" {
-		if err := writePrompt(); err != nil {
-			return "", nil, err
+		p, err := m.PreResponsePrompt(currentVars)
+		if err != nil {
+			return "", nil, fmt.Errorf("pre-response template: %w", err)
 		}
+		prompt.WriteString(p)
 	}
 
 	return prompt.String(), currentImages, nil

+ 249 - 0
server/images_test.go

@@ -7,6 +7,232 @@ import (
 	"github.com/jmorganca/ollama/api"
 )
 
+func TestPrompt(t *testing.T) {
+	tests := []struct {
+		name     string
+		template string
+		vars     PromptVars
+		want     string
+		wantErr  bool
+	}{
+		{
+			name:     "System Prompt",
+			template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
+			vars: PromptVars{
+				System: "You are a Wizard.",
+				Prompt: "What are the potion ingredients?",
+			},
+			want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
+		},
+		{
+			name:     "System Prompt with Response",
+			template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
+			vars: PromptVars{
+				System:   "You are a Wizard.",
+				Prompt:   "What are the potion ingredients?",
+				Response: "I don't know.",
+			},
+			want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
+		},
+		{
+			name:     "Conditional Logic Nodes",
+			template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
+			vars: PromptVars{
+				First:    true,
+				System:   "You are a Wizard.",
+				Prompt:   "What are the potion ingredients?",
+				Response: "I don't know.",
+			},
+			want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got, err := Prompt(tt.template, tt.vars)
+			if (err != nil) != tt.wantErr {
+				t.Errorf("Prompt() error = %v, wantErr %v", err, tt.wantErr)
+				return
+			}
+			if got != tt.want {
+				t.Errorf("Prompt() got = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+func TestModel_PreResponsePrompt(t *testing.T) {
+	tests := []struct {
+		name     string
+		template string
+		vars     PromptVars
+		want     string
+		wantErr  bool
+	}{
+		{
+			name:     "No Response in Template",
+			template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
+			vars: PromptVars{
+				System: "You are a Wizard.",
+				Prompt: "What are the potion ingredients?",
+			},
+			want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
+		},
+		{
+			name:     "Response in Template",
+			template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
+			vars: PromptVars{
+				System: "You are a Wizard.",
+				Prompt: "What are the potion ingredients?",
+			},
+			want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] ",
+		},
+		{
+			name:     "Response in Template with Trailing Formatting",
+			template: "<|im_start|>user\n{{ .Prompt }}<|im_end|><|im_start|>assistant\n{{ .Response }}<|im_end|>",
+			vars: PromptVars{
+				Prompt: "What are the potion ingredients?",
+			},
+			want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\n",
+		},
+		{
+			name:     "Response in Template with Alternative Formatting",
+			template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>",
+			vars: PromptVars{
+				Prompt: "What are the potion ingredients?",
+			},
+			want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\n",
+		},
+	}
+
+	for _, tt := range tests {
+		m := Model{Template: tt.template}
+		t.Run(tt.name, func(t *testing.T) {
+			got, err := m.PreResponsePrompt(tt.vars)
+			if (err != nil) != tt.wantErr {
+				t.Errorf("PreResponsePrompt() error = %v, wantErr %v", err, tt.wantErr)
+				return
+			}
+			if got != tt.want {
+				t.Errorf("PreResponsePrompt() got = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+func TestModel_PostResponsePrompt(t *testing.T) {
+	tests := []struct {
+		name     string
+		template string
+		vars     PromptVars
+		want     string
+		wantErr  bool
+	}{
+		{
+			name:     "No Response in Template",
+			template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
+			vars: PromptVars{
+				Response: "I don't know.",
+			},
+			want: "I don't know.",
+		},
+		{
+			name:     "Response in Template",
+			template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
+			vars: PromptVars{
+				Response: "I don't know.",
+			},
+			want: "I don't know.",
+		},
+		{
+			name:     "Response in Template with Trailing Formatting",
+			template: "<|im_start|>user\n{{ .Prompt }}<|im_end|><|im_start|>assistant\n{{ .Response }}<|im_end|>",
+			vars: PromptVars{
+				Response: "I don't know.",
+			},
+			want: "I don't know.<|im_end|>",
+		},
+		{
+			name:     "Response in Template with Alternative Formatting",
+			template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>",
+			vars: PromptVars{
+				Response: "I don't know.",
+			},
+			want: "I don't know.<|im_end|>",
+		},
+	}
+
+	for _, tt := range tests {
+		m := Model{Template: tt.template}
+		t.Run(tt.name, func(t *testing.T) {
+			got, err := m.PostResponseTemplate(tt.vars)
+			if (err != nil) != tt.wantErr {
+				t.Errorf("PostResponseTemplate() error = %v, wantErr %v", err, tt.wantErr)
+				return
+			}
+			if got != tt.want {
+				t.Errorf("PostResponseTemplate() got = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+func TestModel_PreResponsePrompt_PostResponsePrompt(t *testing.T) {
+	tests := []struct {
+		name     string
+		template string
+		preVars  PromptVars
+		postVars PromptVars
+		want     string
+		wantErr  bool
+	}{
+		{
+			name:     "Response in Template",
+			template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>",
+			preVars: PromptVars{
+				Prompt: "What are the potion ingredients?",
+			},
+			postVars: PromptVars{
+				Prompt:   "What are the potion ingredients?",
+				Response: "Sugar.",
+			},
+			want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\nSugar.<|im_end|>",
+		},
+		{
+			name:     "No Response in Template",
+			template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n",
+			preVars: PromptVars{
+				Prompt: "What are the potion ingredients?",
+			},
+			postVars: PromptVars{
+				Prompt:   "What are the potion ingredients?",
+				Response: "Spice.",
+			},
+			want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\nSpice.",
+		},
+	}
+
+	for _, tt := range tests {
+		m := Model{Template: tt.template}
+		t.Run(tt.name, func(t *testing.T) {
+			pre, err := m.PreResponsePrompt(tt.preVars)
+			if (err != nil) != tt.wantErr {
+				t.Errorf("PreResponsePrompt() error = %v, wantErr %v", err, tt.wantErr)
+				return
+			}
+			post, err := m.PostResponseTemplate(tt.postVars)
+			if err != nil {
+				t.Errorf("PostResponseTemplate() error = %v, wantErr %v", err, tt.wantErr)
+				return
+			}
+			result := pre + post
+			if result != tt.want {
+				t.Errorf("Prompt() got = %v, want %v", result, tt.want)
+			}
+		})
+	}
+}
+
 func TestChat(t *testing.T) {
 	tests := []struct {
 		name     string
@@ -30,6 +256,29 @@ func TestChat(t *testing.T) {
 			},
 			want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
 		},
+		{
+			name:     "First Message",
+			template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST]",
+			msgs: []api.Message{
+				{
+					Role:    "system",
+					Content: "You are a Wizard.",
+				},
+				{
+					Role:    "user",
+					Content: "What are the potion ingredients?",
+				},
+				{
+					Role:    "assistant",
+					Content: "eye of newt",
+				},
+				{
+					Role:    "user",
+					Content: "Anything else?",
+				},
+			},
+			want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST]eye of newt[INST]   Anything else? [/INST]",
+		},
 		{
 			name:     "Message History",
 			template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",

+ 12 - 3
server/routes.go

@@ -195,6 +195,7 @@ func GenerateHandler(c *gin.Context) {
 	checkpointLoaded := time.Now()
 
 	var prompt string
+	var promptVars PromptVars
 	switch {
 	case req.Raw:
 		prompt = req.Prompt
@@ -217,11 +218,12 @@ func GenerateHandler(c *gin.Context) {
 			prevCtx = strings.TrimPrefix(prevCtx, " ")
 			rebuild.WriteString(prevCtx)
 		}
-		p, err := model.Prompt(PromptVars{
+		promptVars = PromptVars{
 			System: req.System,
 			Prompt: req.Prompt,
 			First:  len(req.Context) == 0,
-		})
+		}
+		p, err := model.PreResponsePrompt(promptVars)
 		if err != nil {
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 			return
@@ -264,7 +266,14 @@ func GenerateHandler(c *gin.Context) {
 				resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
 
 				if !req.Raw {
-					embd, err := loaded.runner.Encode(c.Request.Context(), prompt+generated.String())
+					// append the generated text to the history and template it if needed
+					promptVars.Response = generated.String()
+					result, err := model.PostResponseTemplate(promptVars)
+					if err != nil {
+						ch <- gin.H{"error": err.Error()}
+						return
+					}
+					embd, err := loaded.runner.Encode(c.Request.Context(), prompt+result)
 					if err != nil {
 						ch <- gin.H{"error": err.Error()}
 						return