Browse Source

trim chat prompt based on llm context size (#1963)

Bruce MacDonald 1 year ago
parent
commit
0632dff3f8
4 changed files with 440 additions and 57 deletions
  1. 24 27
      server/images.go
  2. 56 28
      server/images_test.go
  3. 105 2
      server/routes.go
  4. 255 0
      server/routes_test.go

+ 24 - 27
server/images.go

@@ -146,62 +146,59 @@ func (m *Model) PostResponseTemplate(p PromptVars) (string, error) {
 	return Prompt(post, p)
 }
 
-func (m *Model) ChatPrompt(msgs []api.Message) (string, []api.ImageData, error) {
+type ChatHistory struct {
+	Prompts       []PromptVars
+	CurrentImages []api.ImageData
+	LastSystem    string
+}
+
+// ChatPrompts returns a list of formatted chat prompts from a list of messages
+func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
 	// build the prompt from the list of messages
-	var prompt strings.Builder
 	var currentImages []api.ImageData
+	var lastSystem string
 	currentVars := PromptVars{
 		First:  true,
 		System: m.System,
 	}
 
-	writePrompt := func() error {
-		p, err := Prompt(m.Template, currentVars)
-		if err != nil {
-			return err
-		}
-		prompt.WriteString(p)
-		currentVars = PromptVars{}
-		return nil
-	}
+	prompts := []PromptVars{}
 
 	for _, msg := range msgs {
 		switch strings.ToLower(msg.Role) {
 		case "system":
 			if currentVars.System != "" {
-				if err := writePrompt(); err != nil {
-					return "", nil, err
-				}
+				prompts = append(prompts, currentVars)
+				currentVars = PromptVars{}
 			}
 			currentVars.System = msg.Content
+			lastSystem = msg.Content
 		case "user":
 			if currentVars.Prompt != "" {
-				if err := writePrompt(); err != nil {
-					return "", nil, err
-				}
+				prompts = append(prompts, currentVars)
+				currentVars = PromptVars{}
 			}
 			currentVars.Prompt = msg.Content
 			currentImages = msg.Images
 		case "assistant":
 			currentVars.Response = msg.Content
-			if err := writePrompt(); err != nil {
-				return "", nil, err
-			}
+			prompts = append(prompts, currentVars)
+			currentVars = PromptVars{}
 		default:
-			return "", nil, fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
+			return nil, fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
 		}
 	}
 
 	// Append the last set of vars if they are non-empty
 	if currentVars.Prompt != "" || currentVars.System != "" {
-		p, err := m.PreResponsePrompt(currentVars)
-		if err != nil {
-			return "", nil, fmt.Errorf("pre-response template: %w", err)
-		}
-		prompt.WriteString(p)
+		prompts = append(prompts, currentVars)
 	}
 
-	return prompt.String(), currentImages, nil
+	return &ChatHistory{
+		Prompts:       prompts,
+		CurrentImages: currentImages,
+		LastSystem:    lastSystem,
+	}, nil
 }
 
 type ManifestV2 struct {

+ 56 - 28
server/images_test.go

@@ -1,6 +1,7 @@
 package server
 
 import (
+	"bytes"
 	"strings"
 	"testing"
 
@@ -233,12 +234,32 @@ func TestModel_PreResponsePrompt_PostResponsePrompt(t *testing.T) {
 	}
 }
 
+func chatHistoryEqual(a, b ChatHistory) bool {
+	if len(a.Prompts) != len(b.Prompts) {
+		return false
+	}
+	if len(a.CurrentImages) != len(b.CurrentImages) {
+		return false
+	}
+	for i, v := range a.Prompts {
+		if v != b.Prompts[i] {
+			return false
+		}
+	}
+	for i, v := range a.CurrentImages {
+		if !bytes.Equal(v, b.CurrentImages[i]) {
+			return false
+		}
+	}
+	return a.LastSystem == b.LastSystem
+}
+
 func TestChat(t *testing.T) {
 	tests := []struct {
 		name     string
 		template string
 		msgs     []api.Message
-		want     string
+		want     ChatHistory
 		wantErr  string
 	}{
 		{
@@ -254,30 +275,16 @@ func TestChat(t *testing.T) {
 					Content: "What are the potion ingredients?",
 				},
 			},
-			want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
-		},
-		{
-			name:     "First Message",
-			template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST]",
-			msgs: []api.Message{
-				{
-					Role:    "system",
-					Content: "You are a Wizard.",
-				},
-				{
-					Role:    "user",
-					Content: "What are the potion ingredients?",
-				},
-				{
-					Role:    "assistant",
-					Content: "eye of newt",
-				},
-				{
-					Role:    "user",
-					Content: "Anything else?",
+			want: ChatHistory{
+				Prompts: []PromptVars{
+					{
+						System: "You are a Wizard.",
+						Prompt: "What are the potion ingredients?",
+						First:  true,
+					},
 				},
+				LastSystem: "You are a Wizard.",
 			},
-			want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST]eye of newt[INST]   Anything else? [/INST]",
 		},
 		{
 			name:     "Message History",
@@ -300,7 +307,20 @@ func TestChat(t *testing.T) {
 					Content: "Anything else?",
 				},
 			},
-			want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]sugar[INST]  Anything else? [/INST]",
+			want: ChatHistory{
+				Prompts: []PromptVars{
+					{
+						System:   "You are a Wizard.",
+						Prompt:   "What are the potion ingredients?",
+						Response: "sugar",
+						First:    true,
+					},
+					{
+						Prompt: "Anything else?",
+					},
+				},
+				LastSystem: "You are a Wizard.",
+			},
 		},
 		{
 			name:     "Assistant Only",
@@ -311,7 +331,14 @@ func TestChat(t *testing.T) {
 					Content: "everything nice",
 				},
 			},
-			want: "[INST]   [/INST]everything nice",
+			want: ChatHistory{
+				Prompts: []PromptVars{
+					{
+						Response: "everything nice",
+						First:    true,
+					},
+				},
+			},
 		},
 		{
 			name: "Invalid Role",
@@ -330,7 +357,7 @@ func TestChat(t *testing.T) {
 			Template: tt.template,
 		}
 		t.Run(tt.name, func(t *testing.T) {
-			got, _, err := m.ChatPrompt(tt.msgs)
+			got, err := m.ChatPrompts(tt.msgs)
 			if tt.wantErr != "" {
 				if err == nil {
 					t.Errorf("ChatPrompt() expected error, got nil")
@@ -338,9 +365,10 @@ func TestChat(t *testing.T) {
 				if !strings.Contains(err.Error(), tt.wantErr) {
 					t.Errorf("ChatPrompt() error = %v, wantErr %v", err, tt.wantErr)
 				}
+				return
 			}
-			if got != tt.want {
-				t.Errorf("ChatPrompt() got = %v, want %v", got, tt.want)
+			if !chatHistoryEqual(*got, tt.want) {
+				t.Errorf("ChatPrompt() got = %#v, want %#v", got, tt.want)
 			}
 		})
 	}

+ 105 - 2
server/routes.go

@@ -1121,11 +1121,16 @@ func ChatHandler(c *gin.Context) {
 
 	checkpointLoaded := time.Now()
 
-	prompt, images, err := model.ChatPrompt(req.Messages)
+	chat, err := model.ChatPrompts(req.Messages)
 	if err != nil {
 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
+	prompt, err := trimmedPrompt(c.Request.Context(), chat, model)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
 
 	slog.Debug(fmt.Sprintf("prompt: %s", prompt))
 
@@ -1164,7 +1169,7 @@ func ChatHandler(c *gin.Context) {
 		predictReq := llm.PredictOpts{
 			Prompt:  prompt,
 			Format:  req.Format,
-			Images:  images,
+			Images:  chat.CurrentImages,
 			Options: opts,
 		}
 		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
@@ -1202,3 +1207,101 @@ func ChatHandler(c *gin.Context) {
 
 	streamResponse(c, ch)
 }
+
+// promptInfo stores the variables used to template a prompt, and the token length of the resulting template for some model
+type promptInfo struct {
+	vars     PromptVars
+	tokenLen int
+}
+
+// trimmedPrompt builds a prompt to send to a running model. It ensures the prompt fits within the max context length,
+// while preserving the most recent system message.
+func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, error) {
+	if len(chat.Prompts) == 0 {
+		return "", nil
+	}
+
+	var promptsToAdd []promptInfo
+	var totalTokenLength int
+	var systemPromptIncluded bool
+
+	// reverse iterate through the prompts to build the prompt string in a way that fits the max context length
+	for i := len(chat.Prompts) - 1; i >= 0; i-- {
+		promptText, err := promptString(model, chat.Prompts[i], i == len(chat.Prompts)-1)
+		if err != nil {
+			return "", err
+		}
+
+		encodedTokens, err := loaded.runner.Encode(ctx, promptText)
+		if err != nil {
+			return "", err
+		}
+
+		if totalTokenLength+len(encodedTokens) > loaded.NumCtx && i != len(chat.Prompts)-1 {
+			break // reached max context length, stop adding more prompts
+		}
+
+		totalTokenLength += len(encodedTokens)
+		systemPromptIncluded = systemPromptIncluded || chat.Prompts[i].System != ""
+		promptsToAdd = append(promptsToAdd, promptInfo{vars: chat.Prompts[i], tokenLen: len(encodedTokens)})
+	}
+
+	// ensure the system prompt is included, if not already
+	if chat.LastSystem != "" && !systemPromptIncluded {
+		var err error
+		promptsToAdd, err = includeSystemPrompt(ctx, chat.LastSystem, totalTokenLength, promptsToAdd)
+		if err != nil {
+			return "", err
+		}
+	}
+
+	promptsToAdd[len(promptsToAdd)-1].vars.First = true
+
+	// construct the final prompt string from the prompts which fit within the context window
+	var result string
+	for i, prompt := range promptsToAdd {
+		promptText, err := promptString(model, prompt.vars, i == 0)
+		if err != nil {
+			return "", err
+		}
+		result = promptText + result
+	}
+	return result, nil
+}
+
+// promptString applies the model template to the prompt
+func promptString(model *Model, vars PromptVars, isMostRecent bool) (string, error) {
+	if isMostRecent {
+		p, err := model.PreResponsePrompt(vars)
+		if err != nil {
+			return "", fmt.Errorf("pre-response template: %w", err)
+		}
+		return p, nil
+	}
+	p, err := Prompt(model.Template, vars)
+	if err != nil {
+		return "", err
+	}
+	return p, nil
+}
+
+// includeSystemPrompt adjusts the prompts to include the system prompt.
+func includeSystemPrompt(ctx context.Context, systemPrompt string, totalTokenLength int, promptsToAdd []promptInfo) ([]promptInfo, error) {
+	systemTokens, err := loaded.runner.Encode(ctx, systemPrompt)
+	if err != nil {
+		return nil, err
+	}
+
+	for i := len(promptsToAdd) - 1; i >= 0; i-- {
+		if totalTokenLength+len(systemTokens) <= loaded.NumCtx {
+			promptsToAdd[i].vars.System = systemPrompt
+			return promptsToAdd[:i+1], nil
+		}
+		totalTokenLength -= promptsToAdd[i].tokenLen
+	}
+
+	// if got here, system did not fit anywhere, so return the most recent prompt with the system message set
+	recent := promptsToAdd[len(promptsToAdd)-1]
+	recent.vars.System = systemPrompt
+	return []promptInfo{recent}, nil
+}

+ 255 - 0
server/routes_test.go

@@ -16,6 +16,7 @@ import (
 	"github.com/stretchr/testify/assert"
 
 	"github.com/jmorganca/ollama/api"
+	"github.com/jmorganca/ollama/llm"
 	"github.com/jmorganca/ollama/parser"
 	"github.com/jmorganca/ollama/version"
 )
@@ -239,3 +240,257 @@ func Test_Routes(t *testing.T) {
 
 	}
 }
+
+func Test_ChatPrompt(t *testing.T) {
+	tests := []struct {
+		name     string
+		template string
+		chat     *ChatHistory
+		numCtx   int
+		runner   MockLLM
+		want     string
+		wantErr  string
+	}{
+		{
+			name:     "Single Message",
+			template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
+			chat: &ChatHistory{
+				Prompts: []PromptVars{
+					{
+						System: "You are a Wizard.",
+						Prompt: "What are the potion ingredients?",
+						First:  true,
+					},
+				},
+				LastSystem: "You are a Wizard.",
+			},
+			numCtx: 1,
+			runner: MockLLM{
+				encoding: []int{1}, // fit the ctxLen
+			},
+			want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
+		},
+		{
+			name:     "First Message",
+			template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST]",
+			chat: &ChatHistory{
+				Prompts: []PromptVars{
+					{
+						System:   "You are a Wizard.",
+						Prompt:   "What are the potion ingredients?",
+						Response: "eye of newt",
+						First:    true,
+					},
+					{
+						Prompt: "Anything else?",
+					},
+				},
+				LastSystem: "You are a Wizard.",
+			},
+			numCtx: 2,
+			runner: MockLLM{
+				encoding: []int{1}, // fit the ctxLen
+			},
+			want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST]eye of newt[INST]   Anything else? [/INST]",
+		},
+		{
+			name:     "Message History",
+			template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
+			chat: &ChatHistory{
+				Prompts: []PromptVars{
+					{
+						System:   "You are a Wizard.",
+						Prompt:   "What are the potion ingredients?",
+						Response: "sugar",
+						First:    true,
+					},
+					{
+						Prompt: "Anything else?",
+					},
+				},
+				LastSystem: "You are a Wizard.",
+			},
+			numCtx: 4,
+			runner: MockLLM{
+				encoding: []int{1}, // fit the ctxLen, 1 for each message
+			},
+			want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]sugar[INST]  Anything else? [/INST]",
+		},
+		{
+			name:     "Assistant Only",
+			template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
+			chat: &ChatHistory{
+				Prompts: []PromptVars{
+					{
+						Response: "everything nice",
+						First:    true,
+					},
+				},
+			},
+			numCtx: 1,
+			runner: MockLLM{
+				encoding: []int{1},
+			},
+			want: "[INST]   [/INST]everything nice",
+		},
+		{
+			name:     "Message History Truncated, No System",
+			template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
+			chat: &ChatHistory{
+				Prompts: []PromptVars{
+					{
+						Prompt:   "What are the potion ingredients?",
+						Response: "sugar",
+						First:    true,
+					},
+					{
+						Prompt:   "Anything else?",
+						Response: "spice",
+					},
+					{
+						Prompt: "... and?",
+					},
+				},
+			},
+			numCtx: 2, // only 1 message from history and most recent message
+			runner: MockLLM{
+				encoding: []int{1},
+			},
+			want: "[INST]  Anything else? [/INST]spice[INST]  ... and? [/INST]",
+		},
+		{
+			name:     "System is Preserved when Truncated",
+			template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
+			chat: &ChatHistory{
+				Prompts: []PromptVars{
+					{
+						Prompt:   "What are the magic words?",
+						Response: "abracadabra",
+					},
+					{
+						Prompt: "What is the spell for invisibility?",
+					},
+				},
+				LastSystem: "You are a wizard.",
+			},
+			numCtx: 2,
+			runner: MockLLM{
+				encoding: []int{1},
+			},
+			want: "[INST] You are a wizard. What is the spell for invisibility? [/INST]",
+		},
+		{
+			name:     "System is Preserved when Length Exceeded",
+			template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
+			chat: &ChatHistory{
+				Prompts: []PromptVars{
+					{
+						Prompt:   "What are the magic words?",
+						Response: "abracadabra",
+					},
+					{
+						Prompt: "What is the spell for invisibility?",
+					},
+				},
+				LastSystem: "You are a wizard.",
+			},
+			numCtx: 1,
+			runner: MockLLM{
+				encoding: []int{1},
+			},
+			want: "[INST] You are a wizard. What is the spell for invisibility? [/INST]",
+		},
+		{
+			name:     "First is Preserved when Truncated",
+			template: "[INST] {{ if .First }}{{ .System }} {{ end }}{{ .Prompt }} [/INST]",
+
+			chat: &ChatHistory{
+				Prompts: []PromptVars{
+					// first message omitted for test
+					{
+						Prompt:   "Do you have a magic hat?",
+						Response: "Of course.",
+					},
+					{
+						Prompt: "What is the spell for invisibility?",
+					},
+				},
+				LastSystem: "You are a wizard.",
+			},
+			numCtx: 3, // two most recent messages and room for system message
+			runner: MockLLM{
+				encoding: []int{1},
+			},
+			want: "[INST] You are a wizard. Do you have a magic hat? [/INST]Of course.[INST] What is the spell for invisibility? [/INST]",
+		},
+		{
+			name:     "Most recent message is returned when longer than ctxLen",
+			template: "[INST] {{ .Prompt }} [/INST]",
+
+			chat: &ChatHistory{
+				Prompts: []PromptVars{
+					{
+						Prompt: "What is the spell for invisibility?",
+						First:  true,
+					},
+				},
+			},
+			numCtx: 1, // two most recent messages
+			runner: MockLLM{
+				encoding: []int{1, 2},
+			},
+			want: "[INST] What is the spell for invisibility? [/INST]",
+		},
+	}
+
+	for _, testCase := range tests {
+		tt := testCase
+		m := &Model{
+			Template: tt.template,
+		}
+		t.Run(tt.name, func(t *testing.T) {
+			loaded.runner = &tt.runner
+			loaded.Options = &api.Options{
+				Runner: api.Runner{
+					NumCtx: tt.numCtx,
+				},
+			}
+			got, err := trimmedPrompt(context.Background(), tt.chat, m)
+			if tt.wantErr != "" {
+				if err == nil {
+					t.Errorf("ChatPrompt() expected error, got nil")
+				}
+				if !strings.Contains(err.Error(), tt.wantErr) {
+					t.Errorf("ChatPrompt() error = %v, wantErr %v", err, tt.wantErr)
+				}
+			}
+			if got != tt.want {
+				t.Errorf("ChatPrompt() got = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+type MockLLM struct {
+	encoding []int
+}
+
+func (llm *MockLLM) Predict(ctx context.Context, pred llm.PredictOpts, fn func(llm.PredictResult)) error {
+	return nil
+}
+
+func (llm *MockLLM) Encode(ctx context.Context, prompt string) ([]int, error) {
+	return llm.encoding, nil
+}
+
+func (llm *MockLLM) Decode(ctx context.Context, tokens []int) (string, error) {
+	return "", nil
+}
+
+func (llm *MockLLM) Embedding(ctx context.Context, input string) ([]float64, error) {
+	return []float64{}, nil
+}
+
+func (llm *MockLLM) Close() {
+	// do nothing
+}