瀏覽代碼

update message processing

Michael Yang 10 月之前
父節點
當前提交
269ed6e6a2
共有 6 個文件被更改,包括 679 次插入714 次删除
  1. 13 4
      server/images.go
  2. 44 187
      server/prompt.go
  3. 165 166
      server/prompt_test.go
  4. 157 345
      server/routes.go
  5. 151 8
      template/template.go
  6. 149 4
      template/template_test.go

+ 13 - 4
server/images.go

@@ -34,6 +34,8 @@ import (
 	"github.com/ollama/ollama/version"
 )
 
+var errCapabilityCompletion = errors.New("completion")
+
 type Capability string
 
 const CapabilityCompletion = Capability("completion")
@@ -62,7 +64,10 @@ type Model struct {
 	Template *template.Template
 }
 
-func (m *Model) Has(caps ...Capability) bool {
+// CheckCapabilities checks if the model has the specified capabilities returning an error describing
+// any missing or unknown capabilities
+func (m *Model) CheckCapabilities(caps ...Capability) error {
+	var errs []error
 	for _, cap := range caps {
 		switch cap {
 		case CapabilityCompletion:
@@ -81,15 +86,19 @@ func (m *Model) Has(caps ...Capability) bool {
 			}
 
 			if _, ok := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]; ok {
-				return false
+				errs = append(errs, errCapabilityCompletion)
 			}
 		default:
 			slog.Error("unknown capability", "capability", cap)
-			return false
+			return fmt.Errorf("unknown capability: %s", cap)
 		}
 	}
 
-	return true
+	if err := errors.Join(errs...); err != nil {
+		return fmt.Errorf("missing capabilities: %w", errors.Join(errs...))
+	}
+
+	return nil
 }
 
 func (m *Model) String() string {

+ 44 - 187
server/prompt.go

@@ -1,217 +1,74 @@
 package server
 
 import (
-	"fmt"
+	"bytes"
+	"context"
 	"log/slog"
-	"strings"
-
-	"text/template/parse"
+	"slices"
 
 	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/template"
 )
 
-// isResponseNode checks if the node contains .Response
-func isResponseNode(node *parse.ActionNode) bool {
-	for _, cmd := range node.Pipe.Cmds {
-		for _, arg := range cmd.Args {
-			if fieldNode, ok := arg.(*parse.FieldNode); ok && len(fieldNode.Ident) > 0 {
-				if fieldNode.Ident[0] == "Response" {
-					return true
-				}
-			}
-		}
-	}
-	return false
-}
-
-// formatTemplateForResponse formats the template AST to:
-// 1. remove all nodes after the first .Response (if generate=true)
-// 2. add a .Response node to the end if it doesn't exist
-// TODO(jmorganca): this should recursively cut the template before the first .Response
-func formatTemplateForResponse(tmpl *template.Template, generate bool) {
-	var found bool
-	for i, node := range tmpl.Tree.Root.Nodes {
-		if actionNode, ok := node.(*parse.ActionNode); ok {
-			if isResponseNode(actionNode) {
-				found = true
-				if generate {
-					tmpl.Tree.Root.Nodes = tmpl.Tree.Root.Nodes[:i+1]
-					break
-				}
-			}
+func chatPrompt(ctx context.Context, r *runnerRef, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) {
+	// extract system messages which should always be included
+	var system []api.Message
+	msgs = slices.DeleteFunc(msgs, func(m api.Message) bool {
+		if m.Role == "system" {
+			system = append(system, m)
+			return true
 		}
-	}
-
-	if !found {
-		// add the response node if it doesn't exist
-		responseFieldNode := &parse.FieldNode{NodeType: parse.NodeField, Ident: []string{"Response"}}
-		responsePipeNode := &parse.PipeNode{NodeType: parse.NodePipe, Cmds: []*parse.CommandNode{{NodeType: parse.NodeCommand, Args: []parse.Node{responseFieldNode}}}}
-		responseActionNode := &parse.ActionNode{NodeType: parse.NodeAction, Pipe: responsePipeNode}
-		tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, responseActionNode)
-	}
-}
-
-// Prompt renders a prompt from a template. If generate is set to true,
-// the response and parts of the template following it are not rendered
-func Prompt(tmpl *template.Template, system, prompt, response string, generate bool) (string, error) {
-	formatTemplateForResponse(tmpl, generate)
-
-	vars := map[string]any{
-		"System":   system,
-		"Prompt":   prompt,
-		"Response": response,
-	}
-
-	var sb strings.Builder
-	if err := tmpl.Execute(&sb, vars); err != nil {
-		return "", err
-	}
 
-	return sb.String(), nil
-}
+		return false
+	})
 
-func countTokens(tmpl *template.Template, system string, prompt string, response string, encode func(string) ([]int, error)) (int, error) {
-	rendered, err := Prompt(tmpl, system, prompt, response, false)
-	if err != nil {
-		return 0, err
+	if len(system) == 0 && r.model.System != "" {
+		// add model system prompt since it wasn't provided
+		system = append(system, api.Message{Role: "system", Content: r.model.System})
 	}
 
-	tokens, err := encode(rendered)
-	if err != nil {
-		slog.Error("failed to encode prompt", "err", err)
-		return 0, err
-	}
-
-	return len(tokens), err
-}
-
-// ChatPrompt builds up a prompt from a series of messages, truncating based on context window size
-func ChatPrompt(tmpl *template.Template, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
-	type prompt struct {
-		System   string
-		Prompt   string
-		Response string
-
-		images []int
-		tokens int
-	}
-
-	var p prompt
-
-	// iterate through messages to build up {system,user,response} prompts
-	var imgId int
-	var prompts []prompt
-	for _, msg := range messages {
-		switch strings.ToLower(msg.Role) {
-		case "system":
-			if p.System != "" || p.Prompt != "" || p.Response != "" {
-				prompts = append(prompts, p)
-				p = prompt{}
-			}
-
-			p.System = msg.Content
-		case "user":
-			if p.Prompt != "" || p.Response != "" {
-				prompts = append(prompts, p)
-				p = prompt{}
-			}
-
-			var sb strings.Builder
-			for range msg.Images {
-				fmt.Fprintf(&sb, "[img-%d] ", imgId)
-				p.images = append(p.images, imgId)
-				imgId += 1
-			}
-
-			sb.WriteString(msg.Content)
-			p.Prompt = sb.String()
-		case "assistant":
-			if p.Response != "" {
-				prompts = append(prompts, p)
-				p = prompt{}
-			}
-
-			p.Response = msg.Content
-		default:
-			return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
+	n := len(msgs) - 1
+	for i := n - 1; i >= 0; i-- {
+		var b bytes.Buffer
+		if err := r.model.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil {
+			return "", nil, err
 		}
-	}
-
-	// add final prompt
-	if p.System != "" || p.Prompt != "" || p.Response != "" {
-		prompts = append(prompts, p)
-	}
 
-	// calculate token lengths for each prompt, estimating 768 tokens per images
-	for i, p := range prompts {
-		tokens, err := countTokens(tmpl, p.System, p.Prompt, p.Response, encode)
+		s, err := r.llama.Tokenize(ctx, b.String())
 		if err != nil {
-			return "", err
+			return "", nil, err
 		}
 
-		prompts[i].tokens = tokens + len(prompts[i].images)*768
-	}
-
-	// truncate images and prompts starting from the beginning of the list
-	// until either one prompt remains or the total tokens fits the context window
-	// TODO (jmorganca): this doesn't account for the context window room required for the response
-	for {
-		var required int
-		for _, p := range prompts {
-			required += p.tokens
+		c := len(s)
+		if r.model.ProjectorPaths != nil {
+			for _, m := range msgs[i:] {
+				// TODO: get image embedding length from project metadata
+				c += 768 * len(m.Images)
+			}
 		}
 
-		required += 1 // for bos token
-
-		if required <= window {
-			slog.Debug("prompt now fits in context window", "required", required, "window", window)
+		if c > r.NumCtx {
+			slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
 			break
+		} else {
+			n = i
 		}
+	}
 
-		prompt := &prompts[0]
-
-		if len(prompt.images) > 1 {
-			img := prompt.images[0]
-			slog.Debug("prompt longer than context window, removing image", "id", img, "required", required, "window", window)
-			prompt.images = prompt.images[1:]
-			prompt.Prompt = strings.Replace(prompt.Prompt, fmt.Sprintf(" [img-%d]", img), "", 1)
-			prompt.tokens -= 768
-			continue
-		}
-
-		if len(prompts) > 1 {
-			slog.Debug("required tokens longer than context window, removing first prompt", "prompt", prompts[0].tokens, "required", required, "window", window)
-			system := prompt.System
-			prompts = prompts[1:]
-
-			if system != "" && prompts[0].System == "" {
-				prompts[0].System = system
-
-				tokens, err := countTokens(tmpl, prompts[0].System, prompts[0].Prompt, prompts[0].Response, encode)
-				if err != nil {
-					return "", err
-				}
-
-				prompts[0].tokens = tokens + len(prompts[0].images)*768
-			}
-
-			continue
-		}
-
-		// stop truncating if there's only one prompt left
-		break
+	var b bytes.Buffer
+	if err := r.model.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...)}); err != nil {
+		return "", nil, err
 	}
 
-	var sb strings.Builder
-	for i, p := range prompts {
-		// last prompt should leave the response unrendered (for completion)
-		rendered, err := Prompt(tmpl, p.System, p.Prompt, p.Response, i == len(prompts)-1)
-		if err != nil {
-			return "", err
+	for _, m := range msgs[n:] {
+		for _, i := range m.Images {
+			images = append(images, llm.ImageData{
+				ID:   len(images),
+				Data: i,
+			})
 		}
-		sb.WriteString(rendered)
 	}
 
-	return sb.String(), nil
+	return b.String(), images, nil
 }

+ 165 - 166
server/prompt_test.go

@@ -1,215 +1,214 @@
 package server
 
 import (
+	"bytes"
+	"context"
 	"strings"
 	"testing"
 
 	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/template"
 )
 
-func TestPrompt(t *testing.T) {
-	tests := []struct {
-		name     string
-		template string
-		system   string
-		prompt   string
-		response string
-		generate bool
-		want     string
-	}{
-		{
-			name:     "simple prompt",
-			template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
-			system:   "You are a Wizard.",
-			prompt:   "What are the potion ingredients?",
-			want:     "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
-		},
-		{
-			name:     "implicit response",
-			template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
-			system:   "You are a Wizard.",
-			prompt:   "What are the potion ingredients?",
-			response: "I don't know.",
-			want:     "[INST] You are a Wizard. What are the potion ingredients? [/INST]I don't know.",
-		},
-		{
-			name:     "response",
-			template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
-			system:   "You are a Wizard.",
-			prompt:   "What are the potion ingredients?",
-			response: "I don't know.",
-			want:     "[INST] You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
-		},
-		{
-			name:     "cut",
-			template: "<system>{{ .System }}</system><user>{{ .Prompt }}</user><assistant>{{ .Response }}</assistant>",
-			system:   "You are a Wizard.",
-			prompt:   "What are the potion ingredients?",
-			response: "I don't know.",
-			generate: true,
-			want:     "<system>You are a Wizard.</system><user>What are the potion ingredients?</user><assistant>I don't know.",
-		},
-		{
-			name:     "nocut",
-			template: "<system>{{ .System }}</system><user>{{ .Prompt }}</user><assistant>{{ .Response }}</assistant>",
-			system:   "You are a Wizard.",
-			prompt:   "What are the potion ingredients?",
-			response: "I don't know.",
-			want:     "<system>You are a Wizard.</system><user>What are the potion ingredients?</user><assistant>I don't know.</assistant>",
-		},
-	}
-
-	for _, tc := range tests {
-		t.Run(tc.name, func(t *testing.T) {
-			tmpl, err := template.Parse(tc.template)
-			if err != nil {
-				t.Fatal(err)
-			}
-
-			got, err := Prompt(tmpl, tc.system, tc.prompt, tc.response, tc.generate)
-			if err != nil {
-				t.Errorf("error = %v", err)
-			}
+type mock struct {
+	llm.LlamaServer
+}
 
-			if got != tc.want {
-				t.Errorf("got = %v, want %v", got, tc.want)
-			}
-		})
+func (m mock) Tokenize(_ context.Context, s string) (tokens []int, err error) {
+	for range strings.Fields(s) {
+		tokens = append(tokens, len(tokens))
 	}
+
+	return
 }
 
 func TestChatPrompt(t *testing.T) {
-	tests := []struct {
-		name     string
-		template string
-		messages []api.Message
-		window   int
-		want     string
+	type expect struct {
+		prompt string
+		images [][]byte
+	}
+
+	cases := []struct {
+		name  string
+		limit int
+		msgs  []api.Message
+		expect
 	}{
 		{
-			name:     "simple prompt",
-			template: "[INST] {{ .Prompt }} [/INST]",
-			messages: []api.Message{
-				{Role: "user", Content: "Hello"},
+			name:  "messages",
+			limit: 64,
+			msgs: []api.Message{
+				{Role: "user", Content: "You're a test, Harry!"},
+				{Role: "assistant", Content: "I-I'm a what?"},
+				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
+			},
+			expect: expect{
+				prompt: "You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
 			},
-			window: 1024,
-			want:   "[INST] Hello [/INST]",
-		},
-		{
-			name:     "with system message",
-			template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
-			messages: []api.Message{
-				{Role: "system", Content: "You are a Wizard."},
-				{Role: "user", Content: "Hello"},
-			},
-			window: 1024,
-			want:   "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]",
 		},
 		{
-			name:     "with response",
-			template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }}",
-			messages: []api.Message{
-				{Role: "system", Content: "You are a Wizard."},
-				{Role: "user", Content: "Hello"},
-				{Role: "assistant", Content: "I am?"},
-			},
-			window: 1024,
-			want:   "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST] I am?",
+			name: "truncate messages",
+			limit: 1,
+			msgs: []api.Message{
+				{Role: "user", Content: "You're a test, Harry!"},
+				{Role: "assistant", Content: "I-I'm a what?"},
+				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
+			},
+			expect: expect{
+				prompt: "A test. And a thumping good one at that, I'd wager. ",
+			},
 		},
 		{
-			name:     "with implicit response",
-			template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
-			messages: []api.Message{
-				{Role: "system", Content: "You are a Wizard."},
-				{Role: "user", Content: "Hello"},
-				{Role: "assistant", Content: "I am?"},
-			},
-			window: 1024,
-			want:   "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]I am?",
+			name: "truncate messages with image",
+			limit: 64,
+			msgs: []api.Message{
+				{Role: "user", Content: "You're a test, Harry!"},
+				{Role: "assistant", Content: "I-I'm a what?"},
+				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("something")}},
+			},
+			expect: expect{
+				prompt: "[img-0] A test. And a thumping good one at that, I'd wager. ",
+				images: [][]byte{
+					[]byte("something"),
+				},
+			},
 		},
 		{
-			name:     "with conversation",
-			template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ",
-			messages: []api.Message{
-				{Role: "system", Content: "You are a Wizard."},
-				{Role: "user", Content: "What are the potion ingredients?"},
-				{Role: "assistant", Content: "sugar"},
-				{Role: "user", Content: "Anything else?"},
-			},
-			window: 1024,
-			want:   "[INST] <<SYS>>You are a Wizard.<</SYS>> What are the potion ingredients? [/INST] sugar [INST] Anything else? [/INST] ",
+			name: "truncate messages with images",
+			limit: 64,
+			msgs: []api.Message{
+				{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
+				{Role: "assistant", Content: "I-I'm a what?"},
+				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
+			},
+			expect: expect{
+				prompt: "[img-0] A test. And a thumping good one at that, I'd wager. ",
+				images: [][]byte{
+					[]byte("somethingelse"),
+				},
+			},
 		},
 		{
-			name:     "with truncation",
-			template: "{{ .System }} {{ .Prompt }} {{ .Response }} ",
-			messages: []api.Message{
-				{Role: "system", Content: "You are a Wizard."},
-				{Role: "user", Content: "Hello"},
-				{Role: "assistant", Content: "I am?"},
-				{Role: "user", Content: "Why is the sky blue?"},
-				{Role: "assistant", Content: "The sky is blue from rayleigh scattering"},
-			},
-			window: 10,
-			want:   "You are a Wizard. Why is the sky blue? The sky is blue from rayleigh scattering",
+			name: "messages with images",
+			limit: 2048,
+			msgs: []api.Message{
+				{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
+				{Role: "assistant", Content: "I-I'm a what?"},
+				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
+			},
+			expect: expect{
+				prompt: "[img-0] You're a test, Harry! I-I'm a what? [img-1] A test. And a thumping good one at that, I'd wager. ",
+				images: [][]byte{
+					[]byte("something"),
+					[]byte("somethingelse"),
+				},
+			},
 		},
 		{
-			name:     "images",
-			template: "{{ .System }} {{ .Prompt }}",
-			messages: []api.Message{
-				{Role: "system", Content: "You are a Wizard."},
-				{Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("base64")}},
-			},
-			window: 1024,
-			want:   "You are a Wizard. [img-0] Hello",
+			name: "message with image tag",
+			limit: 2048,
+			msgs: []api.Message{
+				{Role: "user", Content: "You're a test, Harry! [img]", Images: []api.ImageData{[]byte("something")}},
+				{Role: "assistant", Content: "I-I'm a what?"},
+				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
+			},
+			expect: expect{
+				prompt: "You're a test, Harry! [img-0] I-I'm a what? [img-1] A test. And a thumping good one at that, I'd wager. ",
+				images: [][]byte{
+					[]byte("something"),
+					[]byte("somethingelse"),
+				},
+			},
 		},
 		{
-			name:     "images truncated",
-			template: "{{ .System }} {{ .Prompt }}",
-			messages: []api.Message{
-				{Role: "system", Content: "You are a Wizard."},
-				{Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("img1"), []byte("img2")}},
-			},
-			window: 1024,
-			want:   "You are a Wizard. [img-0] [img-1] Hello",
+			name: "messages with interleaved images",
+			limit: 2048,
+			msgs: []api.Message{
+				{Role: "user", Content: "You're a test, Harry!"},
+				{Role: "user", Images: []api.ImageData{[]byte("something")}},
+				{Role: "user", Images: []api.ImageData{[]byte("somethingelse")}},
+				{Role: "assistant", Content: "I-I'm a what?"},
+				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
+			},
+			expect: expect{
+				prompt: "You're a test, Harry!\n\n[img-0]\n\n[img-1] I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
+				images: [][]byte{
+					[]byte("something"),
+					[]byte("somethingelse"),
+				},
+			},
 		},
 		{
-			name:     "empty list",
-			template: "{{ .System }} {{ .Prompt }}",
-			messages: []api.Message{},
-			window:   1024,
-			want:     "",
+			name: "truncate message with interleaved images",
+			limit: 1024,
+			msgs: []api.Message{
+				{Role: "user", Content: "You're a test, Harry!"},
+				{Role: "user", Images: []api.ImageData{[]byte("something")}},
+				{Role: "user", Images: []api.ImageData{[]byte("somethingelse")}},
+				{Role: "assistant", Content: "I-I'm a what?"},
+				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
+			},
+			expect: expect{
+				prompt: "[img-0] I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
+				images: [][]byte{
+					[]byte("somethingelse"),
+				},
+			},
 		},
 		{
-			name:     "empty prompt",
-			template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ",
-			messages: []api.Message{
-				{Role: "user", Content: ""},
+			name: "message with system prompt",
+			limit: 2048,
+			msgs: []api.Message{
+				{Role: "system", Content: "You are the Test Who Lived."},
+				{Role: "user", Content: "You're a test, Harry!"},
+				{Role: "assistant", Content: "I-I'm a what?"},
+				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
+			},
+			expect: expect{
+				prompt: "You're a test, Harry! I-I'm a what? You are the Test Who Lived. A test. And a thumping good one at that, I'd wager. ",
 			},
-			window: 1024,
-			want:   "",
 		},
 	}
 
-	encode := func(s string) ([]int, error) {
-		words := strings.Fields(s)
-		return make([]int, len(words)), nil
+	tmpl, err := template.Parse(`
+{{- if .System }}{{ .System }} {{ end }}
+{{- if .Prompt }}{{ .Prompt }} {{ end }}
+{{- if .Response }}{{ .Response }} {{ end }}`)
+	if err != nil {
+		t.Fatal(err)
 	}
 
-	for _, tc := range tests {
-		t.Run(tc.name, func(t *testing.T) {
-			tmpl, err := template.Parse(tc.template)
+	for _, tt := range cases {
+		t.Run(tt.name, func(t *testing.T) {
+			r := runnerRef{
+				llama:   mock{},
+				model:   &Model{Template: tmpl, ProjectorPaths: []string{"vision"}},
+				Options: &api.Options{},
+			}
+
+			r.NumCtx = tt.limit
+			prompt, images, err := chatPrompt(context.TODO(), &r, tt.msgs)
 			if err != nil {
 				t.Fatal(err)
 			}
 
-			got, err := ChatPrompt(tmpl, tc.messages, tc.window, encode)
-			if err != nil {
-				t.Errorf("error = %v", err)
+			if tt.prompt != prompt {
+				t.Errorf("expected %q, got %q", tt.prompt, prompt)
 			}
 
-			if got != tc.want {
-				t.Errorf("got: %q, want: %q", got, tc.want)
+			if len(images) != len(tt.images) {
+				t.Fatalf("expected %d images, got %d", len(tt.images), len(images))
+			}
+
+			for i := range images {
+				if images[i].ID != i {
+					t.Errorf("expected ID %d, got %d", i, images[i].ID)
+				}
+
+				if !bytes.Equal(images[i].Data, tt.images[i]) {
+					t.Errorf("expected %q, got %q", tt.images[i], images[i])
+				}
 			}
 		})
 	}

+ 157 - 345
server/routes.go

@@ -1,13 +1,13 @@
 package server
 
 import (
+	"bytes"
 	"cmp"
 	"context"
 	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
-	"io/fs"
 	"log/slog"
 	"net"
 	"net/http"
@@ -67,163 +67,140 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options
 	return opts, nil
 }
 
-func isSupportedImageType(image []byte) bool {
-	contentType := http.DetectContentType(image)
-	allowedTypes := []string{"image/jpeg", "image/jpg", "image/png"}
-	return slices.Contains(allowedTypes, contentType)
-}
-
-func (s *Server) GenerateHandler(c *gin.Context) {
-	checkpointStart := time.Now()
-	var req api.GenerateRequest
-	err := c.ShouldBindJSON(&req)
-
-	switch {
-	case errors.Is(err, io.EOF):
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
-		return
-	case err != nil:
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-		return
+func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (*runnerRef, error) {
+	if name == "" {
+		return nil, errors.New("model is required")
 	}
 
-	// validate the request
-	switch {
-	case req.Model == "":
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
-		return
-	case len(req.Format) > 0 && req.Format != "json":
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
-		return
-	case req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0):
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
-		return
-	}
-
-	for _, img := range req.Images {
-		if !isSupportedImageType(img) {
-			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
-			return
-		}
-	}
-
-	model, err := GetModel(req.Model)
+	model, err := GetModel(name)
 	if err != nil {
-		var pErr *fs.PathError
-		if errors.As(err, &pErr) {
-			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
-			return
-		}
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
+		return nil, err
 	}
 
-	if !model.Has(CapabilityCompletion) {
-		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support generate", req.Model)})
-		return
+	if err := model.CheckCapabilities(caps...); err != nil {
+		return nil, fmt.Errorf("%s %w", name, err)
 	}
 
-	opts, err := modelOptions(model, req.Options)
+	opts, err := modelOptions(model, requestOpts)
 	if err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
+		return nil, err
 	}
 
-	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
+	runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive)
 	var runner *runnerRef
 	select {
-	case runner = <-rCh:
-	case err = <-eCh:
-		handleErrorResponse(c, err)
-		return
+	case runner = <-runnerCh:
+	case err = <-errCh:
+		return nil, err
 	}
 
-	// an empty request loads the model
-	// note: for a short while template was used in lieu
-	// of `raw` mode so we need to check for it too
-	if req.Prompt == "" && req.Template == "" && req.System == "" {
-		c.JSON(http.StatusOK, api.GenerateResponse{
-			CreatedAt:  time.Now().UTC(),
-			Model:      req.Model,
-			Done:       true,
-			DoneReason: "load",
-		})
+	return runner, nil
+}
+
+func (s *Server) GenerateHandler(c *gin.Context) {
+	var req api.GenerateRequest
+	if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
+		return
+	} else if err != nil {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
 
-	tmpl, err := template.Parse(req.Template)
-	if err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+	if req.Format != "" && req.Format != "json" {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be empty or \"json\""})
+		return
+	} else if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
 		return
 	}
 
-	checkpointLoaded := time.Now()
+	caps := []Capability{CapabilityCompletion}
+	r, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
+	if errors.Is(err, errCapabilityCompletion) {
+		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
+		return
+	} else if err != nil {
+		handleScheduleError(c, err)
+		return
+	}
 
-	var prompt string
-	switch {
-	case req.Raw:
-		prompt = req.Prompt
-	case req.Prompt != "":
-		if req.Template == "" {
-			tmpl = model.Template
-		}
+	images := make([]llm.ImageData, len(req.Images))
+	for i := range req.Images {
+		images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
+	}
 
-		if req.System == "" {
-			req.System = model.System
+	prompt := req.Prompt
+	if !req.Raw {
+		var msgs []api.Message
+		if req.System != "" {
+			msgs = append(msgs, api.Message{Role: "system", Content: req.System})
+		} else if r.model.System != "" {
+			msgs = append(msgs, api.Message{Role: "system", Content: r.model.System})
 		}
 
-		slog.Debug("generate handler", "prompt", req.Prompt)
-		slog.Debug("generate handler", "template", req.Template)
-		slog.Debug("generate handler", "system", req.System)
+		if req.Prompt != "" {
+			for _, i := range images {
+				msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
+			}
 
-		var sb strings.Builder
-		for i := range req.Images {
-			fmt.Fprintf(&sb, "[img-%d] ", i)
+			msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt})
 		}
 
-		sb.WriteString(req.Prompt)
-
-		p, err := Prompt(tmpl, req.System, sb.String(), "", true)
-		if err != nil {
-			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		if len(msgs) == 0 {
+			c.JSON(http.StatusOK, api.GenerateResponse{
+				Model:      req.Model,
+				CreatedAt:  time.Now().UTC(),
+				Done:       true,
+				DoneReason: "load",
+			})
 			return
 		}
 
-		sb.Reset()
+		tmpl := r.model.Template
+		if req.Template != "" {
+			tmpl, err = template.Parse(req.Template)
+			if err != nil {
+				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+				return
+			}
+		}
+
+		var b bytes.Buffer
 		if req.Context != nil {
-			prev, err := runner.llama.Detokenize(c.Request.Context(), req.Context)
+			s, err := r.llama.Detokenize(c.Request.Context(), req.Context)
 			if err != nil {
 				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 				return
 			}
 
-			sb.WriteString(prev)
+			b.WriteString(s)
 		}
 
-		sb.WriteString(p)
+		if err := tmpl.Execute(&b, template.Values{Messages: msgs}); err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
 
-		prompt = sb.String()
+		prompt = b.String()
 	}
 
-	slog.Debug("generate handler", "prompt", prompt)
+	slog.Debug("generate request", "prompt", prompt, "images", images)
 
 	ch := make(chan any)
-	var generated strings.Builder
 	go func() {
 		defer close(ch)
-
-		fn := func(r llm.CompletionResponse) {
-			// Build up the full response
-			if _, err := generated.WriteString(r.Content); err != nil {
-				ch <- gin.H{"error": err.Error()}
-				return
-			}
-
-			resp := api.GenerateResponse{
+		if err := r.llama.Completion(c.Request.Context(), llm.CompletionRequest{
+			Prompt:  prompt,
+			Images:  images,
+			Format:  req.Format,
+			Options: *r.Options,
+		}, func(r llm.CompletionResponse) {
+			ch <- api.GenerateResponse{
 				Model:      req.Model,
 				CreatedAt:  time.Now().UTC(),
-				Done:       r.Done,
 				Response:   r.Content,
+				Done:       r.Done,
 				DoneReason: r.DoneReason,
 				Metrics: api.Metrics{
 					PromptEvalCount:    r.PromptEvalCount,
@@ -232,77 +209,35 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 					EvalDuration:       r.EvalDuration,
 				},
 			}
-
-			if r.Done {
-				resp.TotalDuration = time.Since(checkpointStart)
-				resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
-
-				if !req.Raw {
-					p, err := Prompt(tmpl, req.System, req.Prompt, generated.String(), false)
-					if err != nil {
-						c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-						return
-					}
-
-					// TODO (jmorganca): encode() should not strip special tokens
-					tokens, err := runner.llama.Tokenize(c.Request.Context(), p)
-					if err != nil {
-						ch <- gin.H{"error": err.Error()}
-						return
-					}
-
-					resp.Context = append(req.Context, tokens...)
-				}
-			}
-
-			ch <- resp
-		}
-
-		var images []llm.ImageData
-		for i := range req.Images {
-			images = append(images, llm.ImageData{
-				ID:   i,
-				Data: req.Images[i],
-			})
-		}
-
-		// Start prediction
-		req := llm.CompletionRequest{
-			Prompt:  prompt,
-			Format:  req.Format,
-			Images:  images,
-			Options: opts,
-		}
-		if err := runner.llama.Completion(c.Request.Context(), req, fn); err != nil {
+		}); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()
 
 	if req.Stream != nil && !*req.Stream {
-		// Accumulate responses into the final response
-		var final api.GenerateResponse
+		var r api.GenerateResponse
 		var sb strings.Builder
-		for resp := range ch {
-			switch r := resp.(type) {
+		for rr := range ch {
+			switch t := rr.(type) {
 			case api.GenerateResponse:
-				sb.WriteString(r.Response)
-				final = r
+				sb.WriteString(t.Response)
+				r = t
 			case gin.H:
-				if errorMsg, ok := r["error"].(string); ok {
-					c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
-					return
-				} else {
-					c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
-					return
+				msg, ok := t["error"].(string)
+				if !ok {
+					msg = "unexpected error format in response"
 				}
+
+				c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
+				return
 			default:
-				c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"})
+				c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
 				return
 			}
 		}
 
-		final.Response = sb.String()
-		c.JSON(http.StatusOK, final)
+		r.Response = sb.String()
+		c.JSON(http.StatusOK, r)
 		return
 	}
 
@@ -311,44 +246,17 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 
 func (s *Server) EmbeddingsHandler(c *gin.Context) {
 	var req api.EmbeddingRequest
-	err := c.ShouldBindJSON(&req)
-	switch {
-	case errors.Is(err, io.EOF):
+	if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
 		return
-	case err != nil:
+	} else if err != nil {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
 
-	if req.Model == "" {
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
-		return
-	}
-
-	model, err := GetModel(req.Model)
+	r, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
 	if err != nil {
-		var pErr *fs.PathError
-		if errors.As(err, &pErr) {
-			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
-			return
-		}
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
-
-	opts, err := modelOptions(model, req.Options)
-	if err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
-
-	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
-	var runner *runnerRef
-	select {
-	case runner = <-rCh:
-	case err = <-eCh:
-		handleErrorResponse(c, err)
+		handleScheduleError(c, err)
 		return
 	}
 
@@ -358,17 +266,14 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 		return
 	}
 
-	embedding, err := runner.llama.Embedding(c.Request.Context(), req.Prompt)
+	embedding, err := r.llama.Embedding(c.Request.Context(), req.Prompt)
 	if err != nil {
 		slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
 		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
 		return
 	}
 
-	resp := api.EmbeddingResponse{
-		Embedding: embedding,
-	}
-	c.JSON(http.StatusOK, resp)
+	c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: embedding})
 }
 
 func (s *Server) PullModelHandler(c *gin.Context) {
@@ -649,9 +554,9 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 		}
 	}
 
-	msgs := make([]api.Message, 0)
-	for _, msg := range m.Messages {
-		msgs = append(msgs, api.Message{Role: msg.Role, Content: msg.Content})
+	msgs := make([]api.Message, len(m.Messages))
+	for i, msg := range m.Messages {
+		msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
 	}
 
 	n := model.ParseName(req.Model)
@@ -1214,132 +1119,55 @@ func (s *Server) ProcessHandler(c *gin.Context) {
 	c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
 }
 
-// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
-func chatPrompt(ctx context.Context, runner *runnerRef, template *template.Template, messages []api.Message, numCtx int) (string, error) {
-	encode := func(s string) ([]int, error) {
-		return runner.llama.Tokenize(ctx, s)
-	}
-
-	prompt, err := ChatPrompt(template, messages, numCtx, encode)
-	if err != nil {
-		return "", err
-	}
-
-	return prompt, nil
-}
-
 func (s *Server) ChatHandler(c *gin.Context) {
-	checkpointStart := time.Now()
-
 	var req api.ChatRequest
-	err := c.ShouldBindJSON(&req)
-	switch {
-	case errors.Is(err, io.EOF):
+	if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
 		return
-	case err != nil:
+	} else if err != nil {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
 
-	// validate the request
-	switch {
-	case req.Model == "":
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
-		return
-	case len(req.Format) > 0 && req.Format != "json":
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
-		return
-	}
-
-	model, err := GetModel(req.Model)
-	if err != nil {
-		var pErr *fs.PathError
-		if errors.As(err, &pErr) {
-			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
-			return
-		}
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
-
-	if !model.Has(CapabilityCompletion) {
-		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support chat", req.Model)})
-		return
-	}
-
-	opts, err := modelOptions(model, req.Options)
-	if err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
-
-	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
-	var runner *runnerRef
-	select {
-	case runner = <-rCh:
-	case err = <-eCh:
-		handleErrorResponse(c, err)
+	caps := []Capability{CapabilityCompletion}
+	r, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
+	if errors.Is(err, errCapabilityCompletion) {
+		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
 		return
-	}
-
-	checkpointLoaded := time.Now()
-
-	// if the first message is not a system message, then add the model's default system message
-	if len(req.Messages) > 0 && req.Messages[0].Role != "system" {
-		req.Messages = append([]api.Message{
-			{
-				Role:    "system",
-				Content: model.System,
-			},
-		}, req.Messages...)
-	}
-
-	prompt, err := chatPrompt(c.Request.Context(), runner, model.Template, req.Messages, opts.NumCtx)
-	if err != nil {
-		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+	} else if err != nil {
+		handleScheduleError(c, err)
 		return
 	}
 
-	// an empty request loads the model
-	if len(req.Messages) == 0 || prompt == "" {
-		resp := api.ChatResponse{
-			CreatedAt:  time.Now().UTC(),
+	if len(req.Messages) == 0 {
+		c.JSON(http.StatusOK, api.ChatResponse{
 			Model:      req.Model,
+			CreatedAt:  time.Now().UTC(),
+			Message:    api.Message{Role: "assistant"},
 			Done:       true,
 			DoneReason: "load",
-			Message:    api.Message{Role: "assistant"},
-		}
-		c.JSON(http.StatusOK, resp)
+		})
 		return
 	}
 
-	// only send images that are in the prompt
-	var i int
-	var images []llm.ImageData
-	for _, m := range req.Messages {
-		for _, img := range m.Images {
-			if !isSupportedImageType(img) {
-				c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
-				return
-			}
-
-			if strings.Contains(prompt, fmt.Sprintf("[img-%d]", i)) {
-				images = append(images, llm.ImageData{Data: img, ID: i})
-			}
-			i += 1
-		}
+	prompt, images, err := chatPrompt(c.Request.Context(), r, req.Messages)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
 	}
 
-	slog.Debug("chat handler", "prompt", prompt, "images", len(images))
+	slog.Debug("chat request", "images", len(images), "prompt", prompt)
 
 	ch := make(chan any)
-
 	go func() {
 		defer close(ch)
-
-		fn := func(r llm.CompletionResponse) {
-			resp := api.ChatResponse{
+		if err := r.llama.Completion(c.Request.Context(), llm.CompletionRequest{
+			Prompt:  prompt,
+			Images:  images,
+			Format:  req.Format,
+			Options: *r.Options,
+		}, func(r llm.CompletionResponse) {
+			ch <- api.ChatResponse{
 				Model:      req.Model,
 				CreatedAt:  time.Now().UTC(),
 				Message:    api.Message{Role: "assistant", Content: r.Content},
@@ -1352,64 +1180,48 @@ func (s *Server) ChatHandler(c *gin.Context) {
 					EvalDuration:       r.EvalDuration,
 				},
 			}
-
-			if r.Done {
-				resp.TotalDuration = time.Since(checkpointStart)
-				resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
-			}
-
-			ch <- resp
-		}
-
-		if err := runner.llama.Completion(c.Request.Context(), llm.CompletionRequest{
-			Prompt:  prompt,
-			Format:  req.Format,
-			Images:  images,
-			Options: opts,
-		}, fn); err != nil {
+		}); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()
 
 	if req.Stream != nil && !*req.Stream {
-		// Accumulate responses into the final response
-		var final api.ChatResponse
+		var r api.ChatResponse
 		var sb strings.Builder
-		for resp := range ch {
-			switch r := resp.(type) {
+		for rr := range ch {
+			switch t := rr.(type) {
 			case api.ChatResponse:
-				sb.WriteString(r.Message.Content)
-				final = r
+				sb.WriteString(t.Message.Content)
+				r = t
 			case gin.H:
-				if errorMsg, ok := r["error"].(string); ok {
-					c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
-					return
-				} else {
-					c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
-					return
+				msg, ok := t["error"].(string)
+				if !ok {
+					msg = "unexpected error format in response"
 				}
+
+				c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
+				return
 			default:
-				c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"})
+				c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
 				return
 			}
 		}
 
-		final.Message = api.Message{Role: "assistant", Content: sb.String()}
-		c.JSON(http.StatusOK, final)
+		r.Message.Content = sb.String()
+		c.JSON(http.StatusOK, r)
 		return
 	}
 
 	streamResponse(c, ch)
 }
 
-func handleErrorResponse(c *gin.Context, err error) {
-	if errors.Is(err, context.Canceled) {
+func handleScheduleError(c *gin.Context, err error) {
+	switch {
+	case errors.Is(err, context.Canceled):
 		c.JSON(499, gin.H{"error": "request canceled"})
-		return
-	}
-	if errors.Is(err, ErrMaxQueue) {
+	case errors.Is(err, ErrMaxQueue):
 		c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()})
-		return
+	default:
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 	}
-	c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 }

+ 151 - 8
template/template.go

@@ -5,6 +5,7 @@ import (
 	"embed"
 	"encoding/json"
 	"errors"
+	"fmt"
 	"io"
 	"math"
 	"slices"
@@ -14,6 +15,7 @@ import (
 	"text/template/parse"
 
 	"github.com/agnivade/levenshtein"
+	"github.com/ollama/ollama/api"
 	"golang.org/x/exp/maps"
 )
 
@@ -74,30 +76,78 @@ func Named(s string) (*named, error) {
 	return nil, errors.New("no matching template found")
 }
 
+var DefaultTemplate, _ = Parse("{{ .Prompt }}")
+
 type Template struct {
 	*template.Template
 	raw string
 }
 
-func (t *Template) String() string {
-	return t.raw
+var response = parse.ActionNode{
+	NodeType: parse.NodeAction,
+	Pipe: &parse.PipeNode{
+		NodeType: parse.NodePipe,
+		Cmds: []*parse.CommandNode{
+			{
+				NodeType: parse.NodeCommand,
+				Args: []parse.Node{
+					&parse.FieldNode{
+						NodeType: parse.NodeField,
+						Ident:    []string{"Response"},
+					},
+				},
+			},
+		},
+	},
 }
 
-var DefaultTemplate, _ = Parse("{{ .Prompt }}")
-
 func Parse(s string) (*Template, error) {
-	t, err := template.New("").Option("missingkey=zero").Parse(s)
+	tmpl := template.New("").Option("missingkey=zero").Funcs(template.FuncMap{
+		"toJson": func(v any) string {
+			b, err := json.Marshal(v)
+			if err != nil {
+				return ""
+			}
+
+			return string(b)
+		},
+		"isLastMessage": func(s []*api.Message, m *api.Message) bool {
+			for i := len(s) - 1; i >= 0; i-- {
+				if m.Role != s[i].Role {
+					continue
+				}
+
+				return m == s[i]
+			}
+
+			return false
+		},
+	})
+
+	tmpl, err := tmpl.Parse(s)
 	if err != nil {
 		return nil, err
 	}
 
-	return &Template{Template: t, raw: s}, nil
+	t := Template{Template: tmpl, raw: s}
+	if vars := t.Vars(); !slices.Contains(vars, "messages") && !slices.Contains(vars, "response") {
+		// touch up the template and append {{ .Response }}
+		tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, &response)
+	}
+
+	return &t, nil
+}
+
+func (t *Template) String() string {
+	return t.raw
 }
 
 func (t *Template) Vars() []string {
 	var vars []string
-	for _, n := range t.Tree.Root.Nodes {
-		vars = append(vars, parseNode(n)...)
+	for _, tt := range t.Templates() {
+		for _, n := range tt.Root.Nodes {
+			vars = append(vars, parseNode(n)...)
+		}
 	}
 
 	set := make(map[string]struct{})
@@ -110,6 +160,97 @@ func (t *Template) Vars() []string {
 	return vars
 }
 
+type Values struct {
+	Messages []api.Message
+}
+
+func (t *Template) Execute(w io.Writer, v Values) error {
+	system, collated := collate(v.Messages)
+	if slices.Contains(t.Vars(), "messages") {
+		return t.Template.Execute(w, map[string]any{
+			"System":   system,
+			"Messages": collated,
+		})
+	}
+
+	var b bytes.Buffer
+	var prompt, response string
+	for i, m := range collated {
+		if m.Role == "user" {
+			prompt = m.Content
+		} else {
+			response = m.Content
+		}
+
+		if i != len(collated)-1 && prompt != "" && response != "" {
+			if err := t.Template.Execute(&b, map[string]any{
+				"System":   "",
+				"Prompt":   prompt,
+				"Response": response,
+			}); err != nil {
+				return err
+			}
+
+			prompt = ""
+			response = ""
+		}
+	}
+
+	var cut bool
+	tree := t.Template.Copy()
+	// for the last message, cut everything after "{{ .Response }}"
+	tree.Root.Nodes = slices.DeleteFunc(tree.Root.Nodes, func(n parse.Node) bool {
+		if slices.Contains(parseNode(n), "Response") {
+			cut = true
+		}
+
+		return cut
+	})
+
+	if err := template.Must(template.New("").AddParseTree("", tree)).Execute(&b, map[string]any{
+		"System": system,
+		"Prompt": prompt,
+	}); err != nil {
+		return err
+	}
+
+	_, err := io.Copy(w, &b)
+	return err
+}
+
+func collate(msgs []api.Message) (system string, collated []*api.Message) {
+	var n int
+	for i := range msgs {
+		msg := msgs[i]
+		if msg.Role == "system" {
+			if system != "" {
+				system += "\n\n"
+			}
+
+			system += msg.Content
+			continue
+		}
+
+		for range msg.Images {
+			imageTag := fmt.Sprintf("[img-%d]", n)
+			if !strings.Contains(msg.Content, "[img]") {
+				msg.Content = strings.TrimSpace("[img] " + msg.Content)
+			}
+
+			msg.Content = strings.Replace(msg.Content, "[img]", imageTag, 1)
+			n++
+		}
+
+		if len(collated) > 0 && collated[len(collated)-1].Role == msg.Role {
+			collated[len(collated)-1].Content += "\n\n" + msg.Content
+		} else {
+			collated = append(collated, &msg)
+		}
+	}
+
+	return
+}
+
 func parseNode(n parse.Node) []string {
 	switch n := n.(type) {
 	case *parse.ActionNode:
@@ -152,6 +293,8 @@ func parseNode(n parse.Node) []string {
 		return names
 	case *parse.FieldNode:
 		return n.Ident
+	case *parse.TemplateNode:
+		return parseNode(n.Pipe)
 	}
 
 	return nil

+ 149 - 4
template/template_test.go

@@ -11,6 +11,7 @@ import (
 	"testing"
 	"text/template"
 
+	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/llm"
 )
 
@@ -64,13 +65,12 @@ func TestParse(t *testing.T) {
 		template string
 		vars     []string
 	}{
-		{"{{ .Prompt }}", []string{"prompt"}},
-		{"{{ .System }} {{ .Prompt }}", []string{"prompt", "system"}},
+		{"{{ .Prompt }}", []string{"prompt", "response"}},
+		{"{{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system"}},
 		{"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}},
-		{"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "system", "tools"}},
+		{"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system", "tools"}},
 		{"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}},
 		{"{{ range .Messages }}{{ if eq .Role \"system\" }}SYSTEM: {{ .Content }}{{ else if eq .Role \"user\" }}USER: {{ .Content }}{{ else if eq .Role \"assistant\" }}ASSISTANT: {{ .Content }}{{ end }}{{ end }}", []string{"content", "messages", "role"}},
-		{"{{ .Prompt }} {{ .Suffix }}", []string{"prompt", "suffix"}},
 	}
 
 	for _, tt := range cases {
@@ -87,3 +87,148 @@ func TestParse(t *testing.T) {
 		})
 	}
 }
+
+func TestExecuteWithMessages(t *testing.T) {
+	cases := []struct {
+		templates []string
+		values    Values
+		expected  string
+	}{
+		{
+			[]string{
+				`[INST] {{ if .System }}{{ .System }}{{ print "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `,
+				`[INST] {{ if .System }}{{ .System }}{{ print "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`,
+				`{{- range .Messages }}
+{{- if eq .Role "user" }}[INST] {{ if and (isLastMessage $.Messages .) $.System }}{{ $.System }}{{ print "\n\n" }}
+{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
+{{- end }}
+{{- end }}`,
+			},
+			Values{
+				Messages: []api.Message{
+					{Role: "user", Content: "Hello friend!"},
+					{Role: "assistant", Content: "Hello human!"},
+					{Role: "user", Content: "Yay!"},
+				},
+			},
+			`[INST] Hello friend![/INST] Hello human![INST] Yay![/INST] `,
+		},
+		{
+			[]string{
+				`[INST] {{ if .System }}{{ .System }}{{ print "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `,
+				`[INST] {{ if .System }}{{ .System }}{{ print "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`,
+				`
+{{- range .Messages }}
+{{- if eq .Role "user" }}[INST] {{ if and (isLastMessage $.Messages .) $.System }}{{ $.System }}{{ print "\n\n" }}
+{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
+{{- end }}
+{{- end }}`,
+			},
+			Values{
+				Messages: []api.Message{
+					{Role: "system", Content: "You are a helpful assistant!"},
+					{Role: "user", Content: "Hello friend!"},
+					{Role: "assistant", Content: "Hello human!"},
+					{Role: "user", Content: "Yay!"},
+				},
+			},
+			`[INST] Hello friend![/INST] Hello human![INST] You are a helpful assistant!
+
+Yay![/INST] `,
+		},
+		{
+			[]string{
+				`{{ if .System }}<|im_start|>system
+{{ .System }}<|im_end|>
+{{ end }}{{ if .Prompt }}<|im_start|>user
+{{ .Prompt }}<|im_end|>
+{{ end }}<|im_start|>assistant
+{{ .Response }}<|im_end|>
+`,
+				`
+{{- range .Messages }}
+{{- if and (eq .Role "user") (isLastMessage $.Messages .) $.System }}<|im_start|>system
+{{ $.System }}<|im_end|>{{ print "\n" }}
+{{- end }}<|im_start|>{{ .Role }}
+{{ .Content }}<|im_end|>{{ print "\n" }}
+{{- end }}<|im_start|>assistant
+`,
+			},
+			Values{
+				Messages: []api.Message{
+					{Role: "system", Content: "You are a helpful assistant!"},
+					{Role: "user", Content: "Hello friend!"},
+					{Role: "assistant", Content: "Hello human!"},
+					{Role: "user", Content: "Yay!"},
+				},
+			},
+			`<|im_start|>user
+Hello friend!<|im_end|>
+<|im_start|>assistant
+Hello human!<|im_end|>
+<|im_start|>system
+You are a helpful assistant!<|im_end|>
+<|im_start|>user
+Yay!<|im_end|>
+<|im_start|>assistant
+`,
+		},
+		{
+			[]string{
+				`{{ if .Prompt }}Question: {{ .Prompt }}
+
+{{ end }}Answer: {{ .Response }}
+
+`,
+				`
+{{- range .Messages }}
+{{- if eq .Role "user" }}Question: {{ .Content }}{{ print "\n\n" }}
+{{- else if eq .Role "assistant" }}Answer: {{ .Content }}{{ print "\n\n" }}
+{{- end }}
+{{- end }}Answer: `,
+			},
+			Values{
+				Messages: []api.Message{
+					{Role: "user", Content: "What's in this image?", Images: []api.ImageData{[]byte("")}},
+					{Role: "assistant", Content: "It's a hot dog."},
+					{Role: "user", Content: "What's in _this_ image?"},
+					{Role: "user", Images: []api.ImageData{[]byte("")}},
+					{Role: "user", Content: "Is it a hot dog?"},
+				},
+			},
+			`Question: [img-0] What's in this image?
+
+Answer: It's a hot dog.
+
+Question: What's in _this_ image?
+
+[img-1]
+
+Is it a hot dog?
+
+Answer: `,
+		},
+	}
+
+	for _, tt := range cases {
+		t.Run("", func(t *testing.T) {
+			for _, tmpl := range tt.templates {
+				t.Run("", func(t *testing.T) {
+					tmpl, err := Parse(tmpl)
+					if err != nil {
+						t.Fatal(err)
+					}
+
+					var b bytes.Buffer
+					if err := tmpl.Execute(&b, tt.values); err != nil {
+						t.Fatal(err)
+					}
+
+					if b.String() != tt.expected {
+						t.Errorf("expected\n%s,\ngot\n%s", tt.expected, b.String())
+					}
+				})
+			}
+		})
+	}
+}