
Merge branch 'main' into jyan/reord-g

Josh · 9 months ago
commit d069cf753b

+ 1 - 1
.github/workflows/test.yaml

@@ -126,7 +126,7 @@ jobs:
     strategy:
       matrix:
         rocm-version:
-          - '6.1.1'
+          - '6.1.2'
     runs-on: linux
     container: rocm/dev-ubuntu-20.04:${{ matrix.rocm-version }}
     steps:

+ 1 - 1
Dockerfile

@@ -2,7 +2,7 @@ ARG GOLANG_VERSION=1.22.1
 ARG CMAKE_VERSION=3.22.1
 # this CUDA_VERSION corresponds with the one specified in docs/gpu.md
 ARG CUDA_VERSION=11.3.1
-ARG ROCM_VERSION=6.1.1
+ARG ROCM_VERSION=6.1.2
 
 # Copy the minimal context we need to run the generate scripts
 FROM scratch AS llm-code

+ 2 - 0
README.md

@@ -293,6 +293,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS)
 - [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama)
 - [Ollama with Google Mesop](https://github.com/rapidarchitect/ollama_mesop/) (Mesop Chat Client implementation with Ollama)
+- [Kerlig AI](https://www.kerlig.com/) (AI writing assistant for macOS)
+- [AI Studio](https://github.com/MindWorkAI/AI-Studio)
 
 ### Terminal
 

+ 10 - 1
api/client.go

@@ -347,7 +347,16 @@ func (c *Client) Heartbeat(ctx context.Context) error {
 	return nil
 }
 
-// Embeddings generates embeddings from a model.
+// Embed generates embeddings from a model.
+func (c *Client) Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) {
+	var resp EmbedResponse
+	if err := c.do(ctx, http.MethodPost, "/api/embed", req, &resp); err != nil {
+		return nil, err
+	}
+	return &resp, nil
+}
+
+// Embeddings generates an embedding from a model.
 func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) {
 	var resp EmbeddingResponse
 	if err := c.do(ctx, http.MethodPost, "/api/embeddings", req, &resp); err != nil {
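
For reference, a minimal sketch of calling the new endpoint from client code (assumes a locally running Ollama server and the all-minilm model; everything outside the diff is illustrative):

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	// a single request can now embed a batch of inputs
	resp, err := client.Embed(context.Background(), &api.EmbedRequest{
		Model: "all-minilm",
		Input: []string{"why is the sky blue?", "why is the grass green?"},
	})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(len(resp.Embeddings)) // 2: one embedding per input
}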

+ 77 - 5
api/types.go

@@ -47,6 +47,9 @@ type GenerateRequest struct {
 	// Prompt is the textual prompt to send to the model.
 	Prompt string `json:"prompt"`
 
+	// Suffix is the text that comes after the inserted text.
+	Suffix string `json:"suffix"`
+
 	// System overrides the model's default system message/prompt.
 	System string `json:"system"`
 
@@ -97,6 +100,9 @@ type ChatRequest struct {
 	// following the request.
 	KeepAlive *Duration `json:"keep_alive,omitempty"`
 
+	// Tools is an optional list of tools the model has access to.
+	Tools []Tool `json:"tools,omitempty"`
+
 	// Options lists model-specific options.
 	Options map[string]interface{} `json:"options"`
 }
@@ -105,9 +111,46 @@ type ChatRequest struct {
 // role ("system", "user", or "assistant"), the content and an optional list
 // of images.
 type Message struct {
-	Role    string      `json:"role"`
-	Content string      `json:"content"`
-	Images  []ImageData `json:"images,omitempty"`
+	Role      string      `json:"role"`
+	Content   string      `json:"content,omitempty"`
+	Images    []ImageData `json:"images,omitempty"`
+	ToolCalls []ToolCall  `json:"tool_calls,omitempty"`
+}
+
+type ToolCall struct {
+	Function struct {
+		Name      string         `json:"name"`
+		Arguments map[string]any `json:"arguments"`
+	} `json:"function"`
+}
+
+type Tool struct {
+	Type     string `json:"type"`
+	Function struct {
+		Name        string `json:"name"`
+		Description string `json:"description"`
+		Parameters  struct {
+			Type       string   `json:"type"`
+			Required   []string `json:"required"`
+			Properties map[string]struct {
+				Type        string   `json:"type"`
+				Description string   `json:"description"`
+				Enum        []string `json:"enum,omitempty"`
+			} `json:"properties"`
+		} `json:"parameters"`
+	} `json:"function"`
+}
+
+func (m *Message) UnmarshalJSON(b []byte) error {
+	type Alias Message
+	var a Alias
+	if err := json.Unmarshal(b, &a); err != nil {
+		return err
+	}
+
+	*m = Message(a)
+	m.Role = strings.ToLower(m.Role)
+	return nil
 }
 
 // ChatResponse is the response returned by [Client.Chat]. Its fields are
@@ -173,6 +216,30 @@ type Runner struct {
 	NumThread int   `json:"num_thread,omitempty"`
 }
 
+// EmbedRequest is the request passed to [Client.Embed].
+type EmbedRequest struct {
+	// Model is the model name.
+	Model string `json:"model"`
+
+	// Input is the input to embed.
+	Input any `json:"input"`
+
+	// KeepAlive controls how long the model will stay loaded in memory following
+	// this request.
+	KeepAlive *Duration `json:"keep_alive,omitempty"`
+
+	Truncate *bool `json:"truncate,omitempty"`
+
+	// Options lists model-specific options.
+	Options map[string]interface{} `json:"options"`
+}
+
+// EmbedResponse is the response from [Client.Embed].
+type EmbedResponse struct {
+	Model      string      `json:"model"`
+	Embeddings [][]float32 `json:"embeddings"`
+}
+
 // EmbeddingRequest is the request passed to [Client.Embeddings].
 type EmbeddingRequest struct {
 	// Model is the model name.
@@ -219,8 +286,10 @@ type DeleteRequest struct {
 
 // ShowRequest is the request passed to [Client.Show].
 type ShowRequest struct {
-	Model    string `json:"model"`
-	System   string `json:"system"`
+	Model  string `json:"model"`
+	System string `json:"system"`
+
+	// Template is deprecated
 	Template string `json:"template"`
 	Verbose  bool   `json:"verbose"`
 
@@ -336,6 +405,9 @@ type GenerateResponse struct {
 	// Response is the textual response itself.
 	Response string `json:"response"`
 
+	// ToolCalls is the list of tools the model wants to call.
+	ToolCalls []ToolCall `json:"tool_calls,omitempty"`
+
 	// Done specifies if the response is complete.
 	Done bool `json:"done"`
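
The nested Tool types above are most easily populated from JSON rather than struct literals. A fragment sketching this (assumes encoding/json and this api package are imported; the tool definition mirrors the get_current_weather fixture used in the server tests later in this diff):

var tools []api.Tool
toolsJSON := `[{"type": "function", "function": {
	"name": "get_current_weather",
	"description": "Get the current weather for a location",
	"parameters": {
		"type": "object",
		"required": ["location"],
		"properties": {
			"location": {"type": "string", "description": "e.g. San Francisco, CA"},
			"format": {"type": "string", "enum": ["celsius", "fahrenheit"]}
		}
	}
}}]`
// unmarshal into the new api.Tool type, then pass via ChatRequest.Tools
if err := json.Unmarshal([]byte(toolsJSON), &tools); err != nil {
	log.Fatal(err)
}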
 

+ 23 - 0
api/types_test.go

@@ -208,3 +208,26 @@ func TestUseMmapFormatParams(t *testing.T) {
 		})
 	}
 }
+
+func TestMessage_UnmarshalJSON(t *testing.T) {
+	tests := []struct {
+		input    string
+		expected string
+	}{
+		{`{"role": "USER", "content": "Hello!"}`, "user"},
+		{`{"role": "System", "content": "Initialization complete."}`, "system"},
+		{`{"role": "assistant", "content": "How can I help you?"}`, "assistant"},
+		{`{"role": "TOOl", "content": "Access granted."}`, "tool"},
+	}
+
+	for _, test := range tests {
+		var msg Message
+		if err := json.Unmarshal([]byte(test.input), &msg); err != nil {
+			t.Errorf("Unexpected error: %v", err)
+		}
+
+		if msg.Role != test.expected {
+			t.Errorf("role not lowercased: got %v, expected %v", msg.Role, test.expected)
+		}
+	}
+}

+ 4 - 0
app/ollama.iss

@@ -127,6 +127,10 @@ Type: filesandordirs; Name: "{%USERPROFILE}\.ollama\models"
 Type: filesandordirs; Name: "{%USERPROFILE}\.ollama\history"
 ; NOTE: if the user has a custom OLLAMA_MODELS it will be preserved
 
+[InstallDelete]
+Type: filesandordirs; Name: "{%TEMP}\ollama*"
+Type: filesandordirs; Name: "{%LOCALAPPDATA}\Programs\Ollama"
+
 [Messages]
 WizardReady=Ollama Windows Preview
 ReadyLabel1=%nLet's get you up and running with your own large language models.

+ 0 - 2
cmd/cmd.go

@@ -843,7 +843,6 @@ type runOptions struct {
 	WordWrap    bool
 	Format      string
 	System      string
-	Template    string
 	Images      []api.ImageData
 	Options     map[string]interface{}
 	MultiModal  bool
@@ -1037,7 +1036,6 @@ func generate(cmd *cobra.Command, opts runOptions) error {
 		Images:    opts.Images,
 		Format:    opts.Format,
 		System:    opts.System,
-		Template:  opts.Template,
 		Options:   opts.Options,
 		KeepAlive: opts.KeepAlive,
 	}

+ 14 - 38
cmd/interactive.go

@@ -27,7 +27,6 @@ const (
 	MultilineNone MultilineState = iota
 	MultilinePrompt
 	MultilineSystem
-	MultilineTemplate
 )
 
 func loadModel(cmd *cobra.Command, opts *runOptions) error {
@@ -94,7 +93,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 		fmt.Fprintln(os.Stderr, "Available Commands:")
 		fmt.Fprintln(os.Stderr, "  /set parameter ...     Set a parameter")
 		fmt.Fprintln(os.Stderr, "  /set system <string>   Set system message")
-		fmt.Fprintln(os.Stderr, "  /set template <string> Set prompt template")
 		fmt.Fprintln(os.Stderr, "  /set history           Enable history")
 		fmt.Fprintln(os.Stderr, "  /set nohistory         Disable history")
 		fmt.Fprintln(os.Stderr, "  /set wordwrap          Enable wordwrap")
@@ -204,10 +202,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 				opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System})
 				fmt.Println("Set system message.")
 				sb.Reset()
-			case MultilineTemplate:
-				opts.Template = sb.String()
-				fmt.Println("Set prompt template.")
-				sb.Reset()
 			}
 
 			multiline = MultilineNone
@@ -326,17 +320,13 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 					}
 					fmt.Printf("Set parameter '%s' to '%s'\n", args[2], strings.Join(params, ", "))
 					opts.Options[args[2]] = fp[args[2]]
-				case "system", "template":
+				case "system":
 					if len(args) < 3 {
 						usageSet()
 						continue
 					}
 
-					if args[1] == "system" {
-						multiline = MultilineSystem
-					} else if args[1] == "template" {
-						multiline = MultilineTemplate
-					}
+					multiline = MultilineSystem
 
 					line := strings.Join(args[2:], " ")
 					line, ok := strings.CutPrefix(line, `"""`)
@@ -356,23 +346,17 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 						continue
 					}
 
-					if args[1] == "system" {
-						opts.System = sb.String() // for display in modelfile
-						newMessage := api.Message{Role: "system", Content: sb.String()}
-						// Check if the slice is not empty and the last message is from 'system'
-						if len(opts.Messages) > 0 && opts.Messages[len(opts.Messages)-1].Role == "system" {
-							// Replace the last message
-							opts.Messages[len(opts.Messages)-1] = newMessage
-						} else {
-							opts.Messages = append(opts.Messages, newMessage)
-						}
-						fmt.Println("Set system message.")
-						sb.Reset()
-					} else if args[1] == "template" {
-						opts.Template = sb.String()
-						fmt.Println("Set prompt template.")
-						sb.Reset()
+					opts.System = sb.String() // for display in modelfile
+					newMessage := api.Message{Role: "system", Content: sb.String()}
+					// Check if the slice is not empty and the last message is from 'system'
+					if len(opts.Messages) > 0 && opts.Messages[len(opts.Messages)-1].Role == "system" {
+						// Replace the last message
+						opts.Messages[len(opts.Messages)-1] = newMessage
+					} else {
+						opts.Messages = append(opts.Messages, newMessage)
 					}
+					fmt.Println("Set system message.")
+					sb.Reset()
 
 					sb.Reset()
 					continue
@@ -393,7 +377,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 				req := &api.ShowRequest{
 					Name:     opts.Model,
 					System:   opts.System,
-					Template: opts.Template,
 					Options:  opts.Options,
 				}
 				resp, err := client.Show(cmd.Context(), req)
@@ -437,12 +420,9 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 						fmt.Println("No system message was specified for this model.")
 					}
 				case "template":
-					switch {
-					case opts.Template != "":
-						fmt.Println(opts.Template + "\n")
-					case resp.Template != "":
+					if resp.Template != "" {
 						fmt.Println(resp.Template)
-					default:
+					} else {
 						fmt.Println("No prompt template was specified for this model.")
 					}
 				default:
@@ -536,10 +516,6 @@ func buildModelfile(opts runOptions) string {
 		fmt.Fprintf(&mf, "SYSTEM \"\"\"%s\"\"\"\n", opts.System)
 	}
 
-	if opts.Template != "" {
-		fmt.Fprintf(&mf, "TEMPLATE \"\"\"%s\"\"\"\n", opts.Template)
-	}
-
 	keys := make([]string, 0)
 	for k := range opts.Options {
 		keys = append(keys, k)

+ 0 - 3
cmd/interactive_test.go

@@ -59,7 +59,6 @@ func TestModelfileBuilder(t *testing.T) {
 	opts := runOptions{
 		Model:    "hork",
 		System:   "You are part horse and part shark, but all hork. Do horklike things",
-		Template: "This is a template.",
 		Messages: []api.Message{
 			{Role: "user", Content: "Hey there hork!"},
 			{Role: "assistant", Content: "Yes it is true, I am half horse, half shark."},
@@ -75,7 +74,6 @@ func TestModelfileBuilder(t *testing.T) {
 	mf := buildModelfile(opts)
 	expectedModelfile := `FROM {{.Model}}
 SYSTEM """{{.System}}"""
-TEMPLATE """{{.Template}}"""
 PARAMETER penalize_newline false
 PARAMETER seed 42
 PARAMETER stop [hi there]
@@ -97,7 +95,6 @@ MESSAGE assistant """Yes it is true, I am half horse, half shark."""
 	mf = buildModelfile(opts)
 	expectedModelfile = `FROM {{.ParentModel}}
 SYSTEM """{{.System}}"""
-TEMPLATE """{{.Template}}"""
 PARAMETER penalize_newline false
 PARAMETER seed 42
 PARAMETER stop [hi there]

+ 152 - 0
integration/embed_test.go

@@ -0,0 +1,152 @@
+//go:build integration
+
+package integration
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	"github.com/ollama/ollama/api"
+)
+
+func TestAllMiniLMEmbed(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
+	defer cancel()
+
+	req := api.EmbedRequest{
+		Model: "all-minilm",
+		Input: "why is the sky blue?",
+	}
+
+	res, err := embedTestHelper(ctx, t, req)
+
+	if err != nil {
+		t.Fatalf("error: %v", err)
+	}
+
+	if len(res.Embeddings) != 1 {
+		t.Fatalf("expected 1 embedding, got %d", len(res.Embeddings))
+	}
+
+	if len(res.Embeddings[0]) != 384 {
+		t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
+	}
+
+	if res.Embeddings[0][0] != 0.010071031 {
+		t.Fatalf("expected 0.010071031, got %f", res.Embeddings[0][0])
+	}
+}
+
+func TestAllMiniLMBatchEmbed(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
+	defer cancel()
+
+	req := api.EmbedRequest{
+		Model: "all-minilm",
+		Input: []string{"why is the sky blue?", "why is the grass green?"},
+	}
+
+	res, err := embedTestHelper(ctx, t, req)
+
+	if err != nil {
+		t.Fatalf("error: %v", err)
+	}
+
+	if len(res.Embeddings) != 2 {
+		t.Fatalf("expected 2 embeddings, got %d", len(res.Embeddings))
+	}
+
+	if len(res.Embeddings[0]) != 384 {
+		t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
+	}
+
+	if res.Embeddings[0][0] != 0.010071031 || res.Embeddings[1][0] != -0.009802706 {
+		t.Fatalf("expected 0.010071031 and -0.009802706, got %f and %f", res.Embeddings[0][0], res.Embeddings[1][0])
+	}
+}
+
+func TestAllMiniLmEmbedTruncate(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
+	defer cancel()
+
+	truncTrue, truncFalse := true, false
+
+	type testReq struct {
+		Name    string
+		Request api.EmbedRequest
+	}
+
+	reqs := []testReq{
+		{
+			Name: "Target Truncation",
+			Request: api.EmbedRequest{
+				Model: "all-minilm",
+				Input: "why",
+			},
+		},
+		{
+			Name: "Default Truncate",
+			Request: api.EmbedRequest{
+				Model:   "all-minilm",
+				Input:   "why is the sky blue?",
+				Options: map[string]any{"num_ctx": 1},
+			},
+		},
+		{
+			Name: "Explicit Truncate",
+			Request: api.EmbedRequest{
+				Model:    "all-minilm",
+				Input:    "why is the sky blue?",
+				Truncate: &truncTrue,
+				Options:  map[string]any{"num_ctx": 1},
+			},
+		},
+	}
+
+	res := make(map[string]*api.EmbedResponse)
+
+	for _, req := range reqs {
+		response, err := embedTestHelper(ctx, t, req.Request)
+		if err != nil {
+			t.Fatalf("error: %v", err)
+		}
+		res[req.Name] = response
+	}
+
+	if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] {
+		t.Fatal("expected default request to truncate correctly")
+	}
+
+	if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] {
+		t.Fatal("expected default request and truncate true request to be the same")
+	}
+
+	// check that truncate set to false returns an error if context length is exceeded
+	_, err := embedTestHelper(ctx, t, api.EmbedRequest{
+		Model:    "all-minilm",
+		Input:    "why is the sky blue?",
+		Truncate: &truncFalse,
+		Options:  map[string]any{"num_ctx": 1},
+	})
+
+	if err == nil {
+		t.Fatal("expected error, got nil")
+	}
+}
+
+func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
+	if err := PullIfMissing(ctx, client, req.Model); err != nil {
+		t.Fatalf("failed to pull model %s: %v", req.Model, err)
+	}
+
+	response, err := client.Embed(ctx, &req)
+
+	if err != nil {
+		return nil, err
+	}
+
+	return response, nil
+}

+ 23 - 16
llm/ext_server/server.cpp

@@ -3188,26 +3188,33 @@ int main(int argc, char **argv) {
                     prompt = "";
                 }
 
-                json image_data;
-                if (body.count("image_data") != 0) {
-                    image_data = body["image_data"];
-                }
-                else
-                {
-                    image_data = "";
+                if (prompt.size() == 1) {
+                    prompt = prompt[0];
                 }
 
                 // create and queue the task
-                const int task_id = llama.queue_tasks.get_new_id();
-                llama.queue_results.add_waiting_task_id(task_id);
-                llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, true, -1);
-
-                // get the result
-                task_result result = llama.queue_results.recv(task_id);
-                llama.queue_results.remove_waiting_task_id(task_id);
+                json responses;
+                {
+                    const int id_task = llama.queue_tasks.get_new_id();
+                    llama.queue_results.add_waiting_task_id(id_task);
+                    llama.request_completion(id_task, {{"prompt", prompt}}, true, -1);
+
+                    // get the result
+                    task_result result = llama.queue_results.recv(id_task);
+                    llama.queue_results.remove_waiting_task_id(id_task);
+                    if (result.error) {
+                        return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
+                    }
 
-                // send the result
-                return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
+                    responses = result.result_json.value("results", std::vector<json>{result.result_json});
+                    json embeddings = json::array();
+                    for (auto & elem : responses) {
+                        embeddings.push_back(elem.at("embedding"));
+                    }
+                    // send the result
+                    json embedding_res = json{{"embedding", embeddings}};
+                    return res.set_content(embedding_res.dump(), "application/json; charset=utf-8");
+                }
             });
 
     // GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!?

+ 1 - 0
llm/gguf.go

@@ -557,6 +557,7 @@ var ggufKVOrder = map[string][]string{
 		"tokenizer.ggml.add_bos_token",
 		"tokenizer.ggml.add_eos_token",
 		"tokenizer.chat_template",
+		"bert.pooling_type",
 	},
 }
 

+ 9 - 9
llm/server.go

@@ -33,7 +33,7 @@ type LlamaServer interface {
 	Ping(ctx context.Context) error
 	WaitUntilRunning(ctx context.Context) error
 	Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
-	Embedding(ctx context.Context, prompt string) ([]float64, error)
+	Embed(ctx context.Context, input []string) ([][]float32, error)
 	Tokenize(ctx context.Context, content string) ([]int, error)
 	Detokenize(ctx context.Context, tokens []int) (string, error)
 	Close() error
@@ -127,7 +127,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 	// On linux, over-allocating CPU memory will almost always result in an error
 	if runtime.GOOS == "linux" {
 		systemMemoryRequired := estimate.TotalSize - estimate.VRAMSize
-		available := min(systemTotalMemory, systemFreeMemory+systemSwapFreeMemory)
+		available := systemFreeMemory + systemSwapFreeMemory
 		if systemMemoryRequired > available {
 			slog.Warn("model request too large for system", "requested", format.HumanBytes2(systemMemoryRequired), "available", available, "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "swap", format.HumanBytes2(systemSwapFreeMemory))
 			return nil, fmt.Errorf("model requires more system memory (%s) than is available (%s)", format.HumanBytes2(systemMemoryRequired), format.HumanBytes2(available))
@@ -867,15 +867,15 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 	return nil
 }
 
-type EmbeddingRequest struct {
-	Content string `json:"content"`
+type EmbedRequest struct {
+	Content []string `json:"content"`
 }
 
-type EmbeddingResponse struct {
-	Embedding []float64 `json:"embedding"`
+type EmbedResponse struct {
+	Embedding [][]float32 `json:"embedding"`
 }
 
-func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
+func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float32, error) {
 	if err := s.sem.Acquire(ctx, 1); err != nil {
 		slog.Error("Failed to acquire semaphore", "error", err)
 		return nil, err
@@ -890,7 +890,7 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
 		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
 	}
 
-	data, err := json.Marshal(TokenizeRequest{Content: prompt})
+	data, err := json.Marshal(EmbedRequest{Content: input})
 	if err != nil {
 		return nil, fmt.Errorf("error marshaling embed data: %w", err)
 	}
@@ -917,7 +917,7 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
 		return nil, fmt.Errorf("%s", body)
 	}
 
-	var embedding EmbeddingResponse
+	var embedding EmbedResponse
 	if err := json.Unmarshal(body, &embedding); err != nil {
 		return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
 	}

+ 181 - 6
openai/openai.go

@@ -3,11 +3,13 @@ package openai
 
 import (
 	"bytes"
+	"encoding/base64"
 	"encoding/json"
 	"fmt"
 	"io"
 	"math/rand"
 	"net/http"
+	"strings"
 	"time"
 
 	"github.com/gin-gonic/gin"
@@ -28,7 +30,7 @@ type ErrorResponse struct {
 
 type Message struct {
 	Role    string `json:"role"`
-	Content string `json:"content"`
+	Content any    `json:"content"`
 }
 
 type Choice struct {
@@ -59,6 +61,11 @@ type ResponseFormat struct {
 	Type string `json:"type"`
 }
 
+type EmbedRequest struct {
+	Input any    `json:"input"`
+	Model string `json:"model"`
+}
+
 type ChatCompletionRequest struct {
 	Model            string          `json:"model"`
 	Messages         []Message       `json:"messages"`
@@ -132,11 +139,23 @@ type Model struct {
 	OwnedBy string `json:"owned_by"`
 }
 
+type Embedding struct {
+	Object    string    `json:"object"`
+	Embedding []float32 `json:"embedding"`
+	Index     int       `json:"index"`
+}
+
 type ListCompletion struct {
 	Object string  `json:"object"`
 	Data   []Model `json:"data"`
 }
 
+type EmbeddingList struct {
+	Object string      `json:"object"`
+	Data   []Embedding `json:"data"`
+	Model  string      `json:"model"`
+}
+
 func NewError(code int, message string) ErrorResponse {
 	var etype string
 	switch code {
@@ -260,6 +279,27 @@ func toListCompletion(r api.ListResponse) ListCompletion {
 	}
 }
 
+func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
+	if r.Embeddings != nil {
+		var data []Embedding
+		for i, e := range r.Embeddings {
+			data = append(data, Embedding{
+				Object:    "embedding",
+				Embedding: e,
+				Index:     i,
+			})
+		}
+
+		return EmbeddingList{
+			Object: "list",
+			Data:   data,
+			Model:  model,
+		}
+	}
+
+	return EmbeddingList{}
+}
+
 func toModel(r api.ShowResponse, m string) Model {
 	return Model{
 		Id:      m,
@@ -269,10 +309,66 @@ func toModel(r api.ShowResponse, m string) Model {
 	}
 }
 
-func fromChatRequest(r ChatCompletionRequest) api.ChatRequest {
+func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
 	var messages []api.Message
 	for _, msg := range r.Messages {
-		messages = append(messages, api.Message{Role: msg.Role, Content: msg.Content})
+		switch content := msg.Content.(type) {
+		case string:
+			messages = append(messages, api.Message{Role: msg.Role, Content: content})
+		case []any:
+			message := api.Message{Role: msg.Role}
+			for _, c := range content {
+				data, ok := c.(map[string]any)
+				if !ok {
+					return nil, fmt.Errorf("invalid message format")
+				}
+				switch data["type"] {
+				case "text":
+					text, ok := data["text"].(string)
+					if !ok {
+						return nil, fmt.Errorf("invalid message format")
+					}
+					message.Content = text
+				case "image_url":
+					var url string
+					if urlMap, ok := data["image_url"].(map[string]any); ok {
+						if url, ok = urlMap["url"].(string); !ok {
+							return nil, fmt.Errorf("invalid message format")
+						}
+					} else {
+						if url, ok = data["image_url"].(string); !ok {
+							return nil, fmt.Errorf("invalid message format")
+						}
+					}
+
+					types := []string{"jpeg", "jpg", "png"}
+					valid := false
+					for _, t := range types {
+						prefix := "data:image/" + t + ";base64,"
+						if strings.HasPrefix(url, prefix) {
+							url = strings.TrimPrefix(url, prefix)
+							valid = true
+							break
+						}
+					}
+
+					if !valid {
+						return nil, fmt.Errorf("invalid image input")
+					}
+
+					img, err := base64.StdEncoding.DecodeString(url)
+					if err != nil {
+						return nil, fmt.Errorf("invalid message format")
+					}
+					message.Images = append(message.Images, img)
+				default:
+					return nil, fmt.Errorf("invalid message format")
+				}
+			}
+			messages = append(messages, message)
+		default:
+			return nil, fmt.Errorf("invalid message content type: %T", content)
+		}
 	}
 
 	options := make(map[string]interface{})
@@ -323,13 +419,13 @@ func fromChatRequest(r ChatCompletionRequest) api.ChatRequest {
 		format = "json"
 	}
 
-	return api.ChatRequest{
+	return &api.ChatRequest{
 		Model:    r.Model,
 		Messages: messages,
 		Format:   format,
 		Options:  options,
 		Stream:   &r.Stream,
-	}
+	}, nil
 }
 
 func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
@@ -407,6 +503,11 @@ type RetrieveWriter struct {
 	model string
 }
 
+type EmbedWriter struct {
+	BaseWriter
+	model string
+}
+
 func (w *BaseWriter) writeError(code int, data []byte) (int, error) {
 	var serr api.StatusError
 	err := json.Unmarshal(data, &serr)
@@ -572,6 +673,33 @@ func (w *RetrieveWriter) Write(data []byte) (int, error) {
 	return w.writeResponse(data)
 }
 
+func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
+	var embedResponse api.EmbedResponse
+	err := json.Unmarshal(data, &embedResponse)
+
+	if err != nil {
+		return 0, err
+	}
+
+	w.ResponseWriter.Header().Set("Content-Type", "application/json")
+	err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse))
+
+	if err != nil {
+		return 0, err
+	}
+
+	return len(data), nil
+}
+
+func (w *EmbedWriter) Write(data []byte) (int, error) {
+	code := w.ResponseWriter.Status()
+	if code != http.StatusOK {
+		return w.writeError(code, data)
+	}
+
+	return w.writeResponse(data)
+}
+
 func ListMiddleware() gin.HandlerFunc {
 	return func(c *gin.Context) {
 		w := &ListWriter{
@@ -635,6 +763,47 @@ func CompletionsMiddleware() gin.HandlerFunc {
 			id:         fmt.Sprintf("cmpl-%d", rand.Intn(999)),
 		}
 
+		c.Writer = w
+		c.Next()
+	}
+}
+
+func EmbeddingsMiddleware() gin.HandlerFunc {
+	return func(c *gin.Context) {
+		var req EmbedRequest
+		err := c.ShouldBindJSON(&req)
+		if err != nil {
+			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
+			return
+		}
+
+		if req.Input == "" {
+			req.Input = []string{""}
+		}
+
+		if req.Input == nil {
+			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
+			return
+		}
+
+		if v, ok := req.Input.([]any); ok && len(v) == 0 {
+			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
+			return
+		}
+
+		var b bytes.Buffer
+		if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input}); err != nil {
+			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
+			return
+		}
+
+		c.Request.Body = io.NopCloser(&b)
+
+		w := &EmbedWriter{
+			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
+			model:      req.Model,
+		}
+
 		c.Writer = w
 
 		c.Next()
@@ -656,7 +825,13 @@ func ChatMiddleware() gin.HandlerFunc {
 		}
 
 		var b bytes.Buffer
-		if err := json.NewEncoder(&b).Encode(fromChatRequest(req)); err != nil {
+
+		chatReq, err := fromChatRequest(req)
+		if err != nil {
+			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
+			return
+		}
+
+		if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
 			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
 			return
 		}
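
With this middleware wired up (see server/routes.go below), OpenAI-style payloads work against /v1/embeddings. A sketch of exercising it over plain HTTP (assumes net/http, strings, and log are imported; the address uses Ollama's default port, and the model name is an assumption):

body := strings.NewReader(`{"model": "all-minilm", "input": ["Hello", "World"]}`)
resp, err := http.Post("http://localhost:11434/v1/embeddings", "application/json", body)
if err != nil {
	log.Fatal(err)
}
defer resp.Body.Close()
// the body decodes into EmbeddingList:
// {"object": "list", "data": [{"object": "embedding", "embedding": [...], "index": 0}, ...], "model": "all-minilm"}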

+ 121 - 0
openai/openai_test.go

@@ -2,6 +2,7 @@ package openai
 
 import (
 	"bytes"
+	"encoding/base64"
 	"encoding/json"
 	"io"
 	"net/http"
@@ -15,6 +16,10 @@ import (
 	"github.com/stretchr/testify/assert"
 )
 
+const prefix = `data:image/jpeg;base64,`
+const image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
+const imageURL = prefix + image
+
 func TestMiddlewareRequests(t *testing.T) {
 	type testCase struct {
 		Name     string
@@ -112,6 +117,122 @@ func TestMiddlewareRequests(t *testing.T) {
 				}
 			},
 		},
+		{
+			Name:    "chat handler with image content",
+			Method:  http.MethodPost,
+			Path:    "/api/chat",
+			Handler: ChatMiddleware,
+			Setup: func(t *testing.T, req *http.Request) {
+				body := ChatCompletionRequest{
+					Model: "test-model",
+					Messages: []Message{
+						{
+							Role: "user", Content: []map[string]any{
+								{"type": "text", "text": "Hello"},
+								{"type": "image_url", "image_url": map[string]string{"url": imageURL}},
+							},
+						},
+					},
+				}
+
+				bodyBytes, _ := json.Marshal(body)
+
+				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
+				req.Header.Set("Content-Type", "application/json")
+			},
+			Expected: func(t *testing.T, req *http.Request) {
+				var chatReq api.ChatRequest
+				if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
+					t.Fatal(err)
+				}
+
+				if chatReq.Messages[0].Role != "user" {
+					t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
+				}
+
+				if chatReq.Messages[0].Content != "Hello" {
+					t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
+				}
+
+				img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])
+
+				if !bytes.Equal(chatReq.Messages[0].Images[0], img) {
+					t.Fatalf("expected image encoding, got %s", chatReq.Messages[0].Images[0])
+				}
+			},
+		},
+		{
+			Name:    "embed handler single input",
+			Method:  http.MethodPost,
+			Path:    "/api/embed",
+			Handler: EmbeddingsMiddleware,
+			Setup: func(t *testing.T, req *http.Request) {
+				body := EmbedRequest{
+					Input: "Hello",
+					Model: "test-model",
+				}
+
+				bodyBytes, _ := json.Marshal(body)
+
+				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
+				req.Header.Set("Content-Type", "application/json")
+			},
+			Expected: func(t *testing.T, req *http.Request) {
+				var embedReq api.EmbedRequest
+				if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil {
+					t.Fatal(err)
+				}
+
+				if embedReq.Input != "Hello" {
+					t.Fatalf("expected 'Hello', got %s", embedReq.Input)
+				}
+
+				if embedReq.Model != "test-model" {
+					t.Fatalf("expected 'test-model', got %s", embedReq.Model)
+				}
+			},
+		},
+		{
+			Name:    "embed handler batch input",
+			Method:  http.MethodPost,
+			Path:    "/api/embed",
+			Handler: EmbeddingsMiddleware,
+			Setup: func(t *testing.T, req *http.Request) {
+				body := EmbedRequest{
+					Input: []string{"Hello", "World"},
+					Model: "test-model",
+				}
+
+				bodyBytes, _ := json.Marshal(body)
+
+				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
+				req.Header.Set("Content-Type", "application/json")
+			},
+			Expected: func(t *testing.T, req *http.Request) {
+				var embedReq api.EmbedRequest
+				if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil {
+					t.Fatal(err)
+				}
+
+				input, ok := embedReq.Input.([]any)
+
+				if !ok {
+					t.Fatalf("expected input to be a list")
+				}
+
+				if input[0].(string) != "Hello" {
+					t.Fatalf("expected 'Hello', got %s", input[0])
+				}
+
+				if input[1].(string) != "World" {
+					t.Fatalf("expected 'World', got %s", input[1])
+				}
+
+				if embedReq.Model != "test-model" {
+					t.Fatalf("expected 'test-model', got %s", embedReq.Model)
+				}
+			},
+		},
 	}
 
 	gin.SetMode(gin.TestMode)

+ 21 - 3
server/images.go

@@ -34,11 +34,20 @@ import (
 	"github.com/ollama/ollama/version"
 )
 
-var errCapabilityCompletion = errors.New("completion")
+var (
+	errCapabilities         = errors.New("does not support")
+	errCapabilityCompletion = errors.New("completion")
+	errCapabilityTools      = errors.New("tools")
+	errCapabilityInsert     = errors.New("insert")
+)
 
 type Capability string
 
-const CapabilityCompletion = Capability("completion")
+const (
+	CapabilityCompletion = Capability("completion")
+	CapabilityTools      = Capability("tools")
+	CapabilityInsert     = Capability("insert")
+)
 
 type registryOptions struct {
 	Insecure bool
@@ -88,6 +97,15 @@ func (m *Model) CheckCapabilities(caps ...Capability) error {
 			if _, ok := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]; ok {
 				errs = append(errs, errCapabilityCompletion)
 			}
+		case CapabilityTools:
+			if !slices.Contains(m.Template.Vars(), "tools") {
+				errs = append(errs, errCapabilityTools)
+			}
+		case CapabilityInsert:
+			vars := m.Template.Vars()
+			if !slices.Contains(vars, "suffix") {
+				errs = append(errs, errCapabilityInsert)
+			}
 		default:
 			slog.Error("unknown capability", "capability", cap)
 			return fmt.Errorf("unknown capability: %s", cap)
@@ -95,7 +113,7 @@ func (m *Model) CheckCapabilities(caps ...Capability) error {
 	}
 
 	if err := errors.Join(errs...); err != nil {
-		return fmt.Errorf("missing capabilities: %w", errors.Join(errs...))
+		return fmt.Errorf("%w %w", errCapabilities, errors.Join(errs...))
 	}
 
 	return nil

+ 88 - 0
server/model.go

@@ -4,6 +4,7 @@ import (
 	"archive/zip"
 	"bytes"
 	"context"
+	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
@@ -12,6 +13,9 @@ import (
 	"os"
 	"path/filepath"
 	"sort"
+	"slices"
+	"strings"
+	"text/template/parse"
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/convert"
@@ -322,3 +326,87 @@ func detectContentType(r io.Reader) (string, error) {
 
 	return "unknown", nil
 }
+
+// parseToolCalls attempts to parse a JSON string into a slice of ToolCalls.
+// mxyng: this only really works if the input contains tool calls in some JSON format
+func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
+	// create a subtree from the node that ranges over .ToolCalls
+	tmpl := m.Template.Subtree(func(n parse.Node) bool {
+		if t, ok := n.(*parse.RangeNode); ok {
+			return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
+		}
+
+		return false
+	})
+
+	if tmpl == nil {
+		return nil, false
+	}
+
+	var b bytes.Buffer
+	if err := tmpl.Execute(&b, map[string][]map[string]any{
+		"ToolCalls": {
+			{
+				"Function": map[string]any{
+					"Name":      "@@name@@",
+					"Arguments": "@@arguments@@",
+				},
+			},
+		},
+	}); err != nil {
+		return nil, false
+	}
+
+	var kv map[string]string
+	// execute the subtree with placeholders to identify the keys
+	// trim any commands that might exist in the template
+	if err := json.Unmarshal(bytes.TrimSuffix(b.Bytes(), []byte(",")), &kv); err != nil {
+		return nil, false
+	}
+
+	// find the keys that correspond to the name and arguments fields
+	var name, arguments string
+	for k, v := range kv {
+		switch v {
+		case "@@name@@":
+			name = k
+		case "@@arguments@@":
+			arguments = k
+		}
+	}
+
+	var objs []map[string]any
+	for offset := 0; offset < len(s); {
+		if err := json.NewDecoder(strings.NewReader(s[offset:])).Decode(&objs); errors.Is(err, io.EOF) {
+			break
+		} else if syntax := &(json.SyntaxError{}); errors.As(err, &syntax) {
+			// skip over any syntax errors
+			offset += int(syntax.Offset)
+		} else if unmarshalType := &(json.UnmarshalTypeError{}); errors.As(err, &unmarshalType) {
+			// skip over any unmarshalable types
+			offset += int(unmarshalType.Offset)
+		} else if err != nil {
+			return nil, false
+		} else {
+			// break when an object is decoded
+			break
+		}
+	}
+
+	var toolCalls []api.ToolCall
+	for _, kv := range objs {
+		var call api.ToolCall
+		for k, v := range kv {
+			switch k {
+			case name:
+				call.Function.Name = v.(string)
+			case arguments:
+				call.Function.Arguments = v.(map[string]any)
+			}
+		}
+
+		toolCalls = append(toolCalls, call)
+	}
+
+	return toolCalls, len(toolCalls) > 0
+}
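
For intuition: run against, say, a mistral-style tool template, the placeholder pass renders something like the line below, which is how the parser learns that this template emits the function name under "name" and its arguments under "arguments" (the rendering shown is illustrative):

[{"name": "@@name@@", "arguments": "@@arguments@@"}]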

+ 124 - 0
server/model_test.go

@@ -3,7 +3,9 @@ package server
 import (
 	"archive/zip"
 	"bytes"
+	"encoding/json"
 	"errors"
+	"fmt"
 	"io"
 	"os"
 	"path/filepath"
@@ -11,7 +13,9 @@ import (
 	"strings"
 	"testing"
 
+	"github.com/google/go-cmp/cmp"
 	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/template"
 )
 
 func createZipFile(t *testing.T, name string) *os.File {
@@ -110,3 +114,123 @@ func TestExtractFromZipFile(t *testing.T) {
 		})
 	}
 }
+
+type function struct {
+	Name      string         `json:"name"`
+	Arguments map[string]any `json:"arguments"`
+}
+
+func readFile(t *testing.T, base, name string) *bytes.Buffer {
+	t.Helper()
+
+	bts, err := os.ReadFile(filepath.Join(base, name))
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	return bytes.NewBuffer(bts)
+}
+
+func TestExecuteWithTools(t *testing.T) {
+	p := filepath.Join("testdata", "tools")
+	cases := []struct {
+		model  string
+		output string
+		ok     bool
+	}{
+		{"mistral", `[TOOL_CALLS]  [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
+		{"mistral", `[TOOL_CALLS]  [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
+
+The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true},
+		{"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
+
+		[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
+		{"mistral", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
+		{"command-r-plus", "Action: ```json" + `
+[
+    {
+        "tool_name": "get_current_weather",
+        "parameters": {
+            "format": "fahrenheit",
+            "location": "San Francisco, CA"
+        }
+    },
+    {
+        "tool_name": "get_current_weather",
+        "parameters": {
+            "format": "celsius",
+            "location": "Toronto, Canada"
+        }
+    }
+]
+` + "```", true},
+		{"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
+		{"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
+		{"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
+	}
+
+	var tools []api.Tool
+	if err := json.Unmarshal(readFile(t, p, "tools.json").Bytes(), &tools); err != nil {
+		t.Fatal(err)
+	}
+
+	var messages []api.Message
+	if err := json.Unmarshal(readFile(t, p, "messages.json").Bytes(), &messages); err != nil {
+		t.Fatal(err)
+	}
+
+	calls := []api.ToolCall{
+		{
+			Function: function{
+				Name: "get_current_weather",
+				Arguments: map[string]any{
+					"format":   "fahrenheit",
+					"location": "San Francisco, CA",
+				},
+			},
+		},
+		{
+			Function: function{
+				Name: "get_current_weather",
+				Arguments: map[string]any{
+					"format":   "celsius",
+					"location": "Toronto, Canada",
+				},
+			},
+		},
+	}
+
+	for _, tt := range cases {
+		t.Run(tt.model, func(t *testing.T) {
+			tmpl, err := template.Parse(readFile(t, p, fmt.Sprintf("%s.gotmpl", tt.model)).String())
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			t.Run("template", func(t *testing.T) {
+				var actual bytes.Buffer
+				if err := tmpl.Execute(&actual, template.Values{Tools: tools, Messages: messages}); err != nil {
+					t.Fatal(err)
+				}
+
+				if diff := cmp.Diff(actual.String(), readFile(t, p, fmt.Sprintf("%s.out", tt.model)).String()); diff != "" {
+					t.Errorf("mismatch (-got +want):\n%s", diff)
+				}
+			})
+
+			t.Run("parse", func(t *testing.T) {
+				m := &Model{Template: tmpl}
+				actual, ok := m.parseToolCalls(tt.output)
+				if ok != tt.ok {
+					t.Fatalf("expected %t, got %t", tt.ok, ok)
+				}
+
+				if tt.ok {
+					if diff := cmp.Diff(actual, calls); diff != "" {
+						t.Errorf("mismatch (-got +want):\n%s", diff)
+					}
+				}
+			})
+		})
+	}
+}

+ 10 - 19
server/prompt.go

@@ -4,7 +4,6 @@ import (
 	"bytes"
 	"context"
 	"log/slog"
-	"slices"
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/llm"
@@ -16,29 +15,21 @@ type tokenizeFunc func(context.Context, string) ([]int, error)
 // chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
 // chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
 // latest message and 2) system messages
-func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) {
-	// pull out any system messages which should always be included in the prompt
+func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
 	var system []api.Message
-	msgs = slices.DeleteFunc(msgs, func(m api.Message) bool {
-		if m.Role == "system" {
-			system = append(system, m)
-			return true
-		}
-
-		return false
-	})
-
-	if len(system) == 0 && m.System != "" {
-		// add model system prompt since it wasn't provided
-		system = append(system, api.Message{Role: "system", Content: m.System})
-	}
-
 	// always include the last message
 	n := len(msgs) - 1
 	// in reverse, find all messages that fit into context window
 	for i := n - 1; i >= 0; i-- {
+		system = make([]api.Message, 0)
+		for j := range i {
+			if msgs[j].Role == "system" {
+				system = append(system, msgs[j])
+			}
+		}
+
 		var b bytes.Buffer
-		if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil {
+		if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools}); err != nil {
 			return "", nil, err
 		}
 
@@ -66,7 +57,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 
 	// truncate any messages that do not fit into the context window
 	var b bytes.Buffer
-	if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...)}); err != nil {
+	if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...), Tools: tools}); err != nil {
 		return "", nil, err
 	}
 

+ 17 - 12
server/prompt_test.go

@@ -3,21 +3,13 @@ package server
 import (
 	"bytes"
 	"context"
-	"strings"
 	"testing"
 
+	"github.com/google/go-cmp/cmp"
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/template"
 )
 
-func tokenize(_ context.Context, s string) (tokens []int, err error) {
-	for range strings.Fields(s) {
-		tokens = append(tokens, len(tokens))
-	}
-
-	return
-}
-
 func TestChatPrompt(t *testing.T) {
 	type expect struct {
 		prompt string
@@ -164,6 +156,19 @@ func TestChatPrompt(t *testing.T) {
 				prompt: "You are the Test Who Lived. You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
 			},
 		},
+		{
+			name:  "out of order system",
+			limit: 2048,
+			msgs: []api.Message{
+				{Role: "user", Content: "You're a test, Harry!"},
+				{Role: "assistant", Content: "I-I'm a what?"},
+				{Role: "system", Content: "You are the Test Who Lived."},
+				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
+			},
+			expect: expect{
+				prompt: "You're a test, Harry! I-I'm a what? You are the Test Who Lived. A test. And a thumping good one at that, I'd wager. ",
+			},
+		},
 	}
 
 	tmpl, err := template.Parse(`
@@ -178,13 +183,13 @@ func TestChatPrompt(t *testing.T) {
 		t.Run(tt.name, func(t *testing.T) {
 			model := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
 			opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
-			prompt, images, err := chatPrompt(context.TODO(), &model, tokenize, &opts, tt.msgs)
+			prompt, images, err := chatPrompt(context.TODO(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil)
 			if err != nil {
 				t.Fatal(err)
 			}
 
-			if tt.prompt != prompt {
-				t.Errorf("expected %q, got %q", tt.prompt, prompt)
+			if diff := cmp.Diff(prompt, tt.prompt); diff != "" {
+				t.Errorf("mismatch (-got +want):\n%s", diff)
 			}
 
 			if len(images) != len(tt.images) {

+ 224 - 39
server/routes.go

@@ -9,6 +9,7 @@ import (
 	"fmt"
 	"io"
 	"log/slog"
+	"math"
 	"net"
 	"net/http"
 	"net/netip"
@@ -102,6 +103,7 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capabil
 }
 
 func (s *Server) GenerateHandler(c *gin.Context) {
+	checkpointStart := time.Now()
 	var req api.GenerateRequest
 	if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
@@ -120,6 +122,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 	}
 
 	caps := []Capability{CapabilityCompletion}
+	if req.Suffix != "" {
+		caps = append(caps, CapabilityInsert)
+	}
+
 	r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
 	if errors.Is(err, errCapabilityCompletion) {
 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
@@ -129,6 +135,8 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		return
 	}
 
+	checkpointLoaded := time.Now()
+
 	if req.Prompt == "" {
 		c.JSON(http.StatusOK, api.GenerateResponse{
 			Model:      req.Model,
@@ -146,19 +154,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 
 	prompt := req.Prompt
 	if !req.Raw {
-		var msgs []api.Message
-		if req.System != "" {
-			msgs = append(msgs, api.Message{Role: "system", Content: req.System})
-		} else if m.System != "" {
-			msgs = append(msgs, api.Message{Role: "system", Content: m.System})
-		}
-
-		for _, i := range images {
-			msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
-		}
-
-		msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt})
-
 		tmpl := m.Template
 		if req.Template != "" {
 			tmpl, err = template.Parse(req.Template)
@@ -179,7 +174,26 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 			b.WriteString(s)
 		}
 
-		if err := tmpl.Execute(&b, template.Values{Messages: msgs}); err != nil {
+		var values template.Values
+		if req.Suffix != "" {
+			values.Prompt = prompt
+			values.Suffix = req.Suffix
+		} else {
+			var msgs []api.Message
+			if req.System != "" {
+				msgs = append(msgs, api.Message{Role: "system", Content: req.System})
+			} else if m.System != "" {
+				msgs = append(msgs, api.Message{Role: "system", Content: m.System})
+			}
+
+			for _, i := range images {
+				msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
+			}
+
+			values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})
+		}
+
+		if err := tmpl.Execute(&b, values); err != nil {
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 			return
 		}
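
With an insert-capable template, the new Suffix field drives fill-in-the-middle generation through the existing client. A hedged sketch (assumes a client and ctx as in the earlier example; the model name is hypothetical and must define a "suffix" template variable, or the capability check above rejects the request):

stream := false
err := client.Generate(ctx, &api.GenerateRequest{
	Model:  "codegemma:code", // hypothetical insert-capable model
	Prompt: "func add(a, b int) int {",
	Suffix: "}",
	Stream: &stream,
}, func(r api.GenerateResponse) error {
	fmt.Print(r.Response) // e.g. "return a + b"
	return nil
})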
@@ -191,26 +205,48 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 
 	ch := make(chan any)
 	go func() {
+		// TODO (jmorganca): avoid building the response twice both here and below
+		var sb strings.Builder
 		defer close(ch)
 		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
 			Prompt:  prompt,
 			Images:  images,
 			Format:  req.Format,
 			Options: opts,
-		}, func(r llm.CompletionResponse) {
-			ch <- api.GenerateResponse{
+		}, func(cr llm.CompletionResponse) {
+			res := api.GenerateResponse{
 				Model:      req.Model,
 				CreatedAt:  time.Now().UTC(),
-				Response:   r.Content,
-				Done:       r.Done,
-				DoneReason: r.DoneReason,
+				Response:   cr.Content,
+				Done:       cr.Done,
+				DoneReason: cr.DoneReason,
 				Metrics: api.Metrics{
-					PromptEvalCount:    r.PromptEvalCount,
-					PromptEvalDuration: r.PromptEvalDuration,
-					EvalCount:          r.EvalCount,
-					EvalDuration:       r.EvalDuration,
+					PromptEvalCount:    cr.PromptEvalCount,
+					PromptEvalDuration: cr.PromptEvalDuration,
+					EvalCount:          cr.EvalCount,
+					EvalDuration:       cr.EvalDuration,
 				},
 			}
+
+			if _, err := sb.WriteString(cr.Content); err != nil {
+				ch <- gin.H{"error": err.Error()}
+			}
+
+			if cr.Done {
+				res.TotalDuration = time.Since(checkpointStart)
+				res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
+
+				if !req.Raw {
+					tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String())
+					if err != nil {
+						ch <- gin.H{"error": err.Error()}
+						return
+					}
+					res.Context = append(req.Context, tokens...)
+				}
+			}
+
+			ch <- res
 		}); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
@@ -239,6 +275,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		}
 
 		r.Response = sb.String()
+		if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
+			r.ToolCalls = toolCalls
+			r.Response = ""
+		}
+
 		c.JSON(http.StatusOK, r)
 		return
 	}
@@ -246,6 +287,121 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 	streamResponse(c, ch)
 }
 
+func (s *Server) EmbedHandler(c *gin.Context) {
+	var req api.EmbedRequest
+	err := c.ShouldBindJSON(&req)
+	switch {
+	case errors.Is(err, io.EOF):
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
+		return
+	case err != nil:
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
+	truncate := true
+
+	if req.Truncate != nil && !*req.Truncate {
+		truncate = false
+	}
+
+	var input []string
+
+	switch i := req.Input.(type) {
+	case string:
+		if len(i) > 0 {
+			input = append(input, i)
+		}
+	case []any:
+		for _, v := range i {
+			if _, ok := v.(string); !ok {
+				c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
+				return
+			}
+			input = append(input, v.(string))
+		}
+	default:
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
+		return
+	}
+
+	if len(input) == 0 {
+		c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
+		return
+	}
+
+	r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
+	if err != nil {
+		handleScheduleError(c, req.Model, err)
+		return
+	}
+
+	kvData, err := getKVData(m.ModelPath, false)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	for i, s := range input {
+		tokens, err := r.Tokenize(c.Request.Context(), s)
+		if err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
+
+		ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
+		if len(tokens) > ctxLen {
+			if !truncate {
+				c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"})
+				return
+			}
+
+			tokens = tokens[:ctxLen]
+			s, err = r.Detokenize(c.Request.Context(), tokens)
+			if err != nil {
+				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+				return
+			}
+		}
+
+		input[i] = s
+	}
+	embeddings, err := r.Embed(c.Request.Context(), input)
+
+	if err != nil {
+		slog.Error("embedding generation failed", "error", err)
+		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
+		return
+	}
+
+	for i, e := range embeddings {
+		embeddings[i] = normalize(e)
+	}
+
+	resp := api.EmbedResponse{
+		Model:      req.Model,
+		Embeddings: embeddings,
+	}
+	c.JSON(http.StatusOK, resp)
+}
+
+func normalize(vec []float32) []float32 {
+	var sum float32
+	for _, v := range vec {
+		sum += v * v
+	}
+
+	norm := float32(0.0)
+	if sum > 0 {
+		norm = float32(1.0 / math.Sqrt(float64(sum)))
+	}
+
+	for i := range vec {
+		vec[i] *= norm
+	}
+	return vec
+}
+
 func (s *Server) EmbeddingsHandler(c *gin.Context) {
 	var req api.EmbeddingRequest
 	if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
@@ -268,14 +424,24 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 		return
 	}
 
-	embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
+	embeddings, err := r.Embed(c.Request.Context(), []string{req.Prompt})
+
 	if err != nil {
 		slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
 		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
 		return
 	}
 
-	c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: embedding})
+	embedding := make([]float64, len(embeddings[0]))
+
+	for i, v := range embeddings[0] {
+		embedding[i] = float64(v)
+	}
+
+	resp := api.EmbeddingResponse{
+		Embedding: embedding,
+	}
+	c.JSON(http.StatusOK, resp)
 }
 
 func (s *Server) PullModelHandler(c *gin.Context) {
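
From the client side, the handler's truncation behavior can be exercised like this (a sketch mirroring the integration tests above; assumes a client and ctx as before):

truncate := false
_, err := client.Embed(ctx, &api.EmbedRequest{
	Model:    "all-minilm",
	Input:    "why is the sky blue?",
	Truncate: &truncate,
	Options:  map[string]any{"num_ctx": 1},
})
// err is non-nil here: the input exceeds num_ctx and truncation was disabled;
// with Truncate unset or true, the input is tokenized, clipped to the context
// window, detokenized, and embedded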
@@ -549,13 +715,6 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 		m.System = req.System
 	}
 
-	if req.Template != "" {
-		m.Template, err = template.Parse(req.Template)
-		if err != nil {
-			return nil, err
-		}
-	}
-
 	msgs := make([]api.Message, len(m.Messages))
 	for i, msg := range m.Messages {
 		msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
@@ -901,6 +1060,7 @@ func (s *Server) GenerateRoutes() http.Handler {
 	r.POST("/api/pull", s.PullModelHandler)
 	r.POST("/api/generate", s.GenerateHandler)
 	r.POST("/api/chat", s.ChatHandler)
+	r.POST("/api/embed", s.EmbedHandler)
 	r.POST("/api/embeddings", s.EmbeddingsHandler)
 	r.POST("/api/create", s.CreateModelHandler)
 	r.POST("/api/push", s.PushModelHandler)
@@ -914,6 +1074,7 @@ func (s *Server) GenerateRoutes() http.Handler {
 	// Compatibility endpoints
 	r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
 	r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
+	r.POST("/v1/embeddings", openai.EmbeddingsMiddleware(), s.EmbedHandler)
 	r.GET("/v1/models", openai.ListMiddleware(), s.ListModelsHandler)
 	r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowModelHandler)
 
@@ -1122,6 +1283,8 @@ func (s *Server) ProcessHandler(c *gin.Context) {
 }
 
 func (s *Server) ChatHandler(c *gin.Context) {
+	checkpointStart := time.Now()
+
 	var req api.ChatRequest
 	if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
@@ -1132,6 +1295,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	}
 
 	caps := []Capability{CapabilityCompletion}
+	if req.Tools != nil {
+		caps = append(caps, CapabilityTools)
+	}
+
 	r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
 	if errors.Is(err, errCapabilityCompletion) {
 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
@@ -1141,6 +1308,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		return
 	}
 
+	checkpointLoaded := time.Now()
+
 	if len(req.Messages) == 0 {
 		c.JSON(http.StatusOK, api.ChatResponse{
 			Model:      req.Model,
@@ -1152,7 +1321,11 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		return
 	}
 
-	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages)
+	if req.Messages[0].Role != "system" && m.System != "" {
+		req.Messages = append([]api.Message{{Role: "system", Content: m.System}}, req.Messages...)
+	}
+
+	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages, req.Tools)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
@@ -1169,7 +1342,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 			Format:  req.Format,
 			Options: opts,
 		}, func(r llm.CompletionResponse) {
-			ch <- api.ChatResponse{
+			res := api.ChatResponse{
 				Model:      req.Model,
 				CreatedAt:  time.Now().UTC(),
 				Message:    api.Message{Role: "assistant", Content: r.Content},
@@ -1182,19 +1355,26 @@ func (s *Server) ChatHandler(c *gin.Context) {
 					EvalDuration:       r.EvalDuration,
 				},
 			}
+
+			if r.Done {
+				res.TotalDuration = time.Since(checkpointStart)
+				res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
+			}
+
+			ch <- res
 		}); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()
 
 	if req.Stream != nil && !*req.Stream {
-		var r api.ChatResponse
+		var resp api.ChatResponse
 		var sb strings.Builder
 		for rr := range ch {
 			switch t := rr.(type) {
 			case api.ChatResponse:
 				sb.WriteString(t.Message.Content)
-				r = t
+				resp = t
 			case gin.H:
 				msg, ok := t["error"].(string)
 				if !ok {
@@ -1209,8 +1389,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
 			}
 		}
 
-		r.Message.Content = sb.String()
-		c.JSON(http.StatusOK, r)
+		resp.Message.Content = sb.String()
+		if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
+			resp.Message.ToolCalls = toolCalls
+			resp.Message.Content = ""
+		}
+
+		c.JSON(http.StatusOK, resp)
 		return
 	}
 
@@ -1219,7 +1404,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 
 func handleScheduleError(c *gin.Context, name string, err error) {
 	switch {
-	case errors.Is(err, errRequired):
+	case errors.Is(err, errCapabilities), errors.Is(err, errRequired):
 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 	case errors.Is(err, context.Canceled):
 		c.JSON(499, gin.H{"error": "request canceled"})

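For reviewers: the hunks above register the new batched `/api/embed` route alongside the legacy `/api/embeddings` one, with the OpenAI-compatible `/v1/embeddings` endpoint layered on top via middleware. A minimal sketch of exercising the new route over plain HTTP follows; the port is Ollama's default, and the model name and JSON keys are assumptions (they mirror the `api.EmbedRequest`/`api.EmbedResponse` shapes exercised by the route tests below):

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// /api/embed accepts a batch of inputs and returns one vector per input.
	body, _ := json.Marshal(map[string]any{
		"model": "all-minilm", // assumed: any embedding-capable model
		"input": []string{"why is the sky blue?", "why is grass green?"},
	})

	resp, err := http.Post("http://localhost:11434/api/embed", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	var out struct {
		Embeddings [][]float32 `json:"embeddings"` // assumed field name
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}
	fmt.Printf("got %d embeddings\n", len(out.Embeddings))
}
```

Unlike the legacy handler, which returns a single `[]float64`, the batched route returns one `[]float32` vector per input; the `normalize` helper added above suggests those vectors are unit-normalized before being returned.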
+ 18 - 0
server/routes_create_test.go

@@ -85,6 +85,8 @@ func checkFileExists(t *testing.T, p string, expect []string) {
 }
 
 func TestCreateFromBin(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
 	envconfig.LoadConfig()
@@ -111,6 +113,8 @@ func TestCreateFromBin(t *testing.T) {
 }
 
 func TestCreateFromModel(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
 	envconfig.LoadConfig()
@@ -152,6 +156,8 @@ func TestCreateFromModel(t *testing.T) {
 }
 
 func TestCreateRemovesLayers(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
 	envconfig.LoadConfig()
@@ -199,6 +205,8 @@ func TestCreateRemovesLayers(t *testing.T) {
 }
 
 func TestCreateUnsetsSystem(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
 	envconfig.LoadConfig()
@@ -255,6 +263,8 @@ func TestCreateUnsetsSystem(t *testing.T) {
 }
 
 func TestCreateMergeParameters(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
 	envconfig.LoadConfig()
@@ -358,6 +368,8 @@ func TestCreateMergeParameters(t *testing.T) {
 }
 
 func TestCreateReplacesMessages(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
 	envconfig.LoadConfig()
@@ -434,6 +446,8 @@ func TestCreateReplacesMessages(t *testing.T) {
 }
 
 func TestCreateTemplateSystem(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
 	envconfig.LoadConfig()
@@ -480,6 +494,8 @@ func TestCreateTemplateSystem(t *testing.T) {
 }
 
 func TestCreateLicenses(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
 	envconfig.LoadConfig()
@@ -526,6 +542,8 @@ func TestCreateLicenses(t *testing.T) {
 }
 
 func TestCreateDetectTemplate(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
 	envconfig.LoadConfig()

+ 5 - 0
server/routes_delete_test.go

@@ -8,12 +8,15 @@ import (
 	"path/filepath"
 	"testing"
 
+	"github.com/gin-gonic/gin"
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/types/model"
 )
 
 func TestDelete(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
 	envconfig.LoadConfig()
@@ -77,6 +80,8 @@ func TestDelete(t *testing.T) {
 }
 
 func TestDeleteDuplicateLayers(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
 	var s Server

+ 712 - 0
server/routes_generate_test.go

@@ -0,0 +1,712 @@
+package server
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/gin-gonic/gin"
+	"github.com/google/go-cmp/cmp"
+
+	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/gpu"
+	"github.com/ollama/ollama/llm"
+)
+
+type mockRunner struct {
+	llm.LlamaServer
+
+	// CompletionRequest is only valid until the next call to Completion
+	llm.CompletionRequest
+	llm.CompletionResponse
+}
+
+func (m *mockRunner) Completion(_ context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
+	m.CompletionRequest = r
+	fn(m.CompletionResponse)
+	return nil
+}
+
+func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error) {
+	for range strings.Fields(s) {
+		tokens = append(tokens, len(tokens))
+	}
+
+	return
+}
+
+func newMockServer(mock *mockRunner) func(gpu.GpuInfoList, string, *llm.GGML, []string, []string, api.Options, int) (llm.LlamaServer, error) {
+	return func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, projectors, system []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
+		return mock, nil
+	}
+}
+
+func TestGenerateChat(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
+	mock := mockRunner{
+		CompletionResponse: llm.CompletionResponse{
+			Done:               true,
+			DoneReason:         "stop",
+			PromptEvalCount:    1,
+			PromptEvalDuration: 1,
+			EvalCount:          1,
+			EvalDuration:       1,
+		},
+	}
+
+	s := Server{
+		sched: &Scheduler{
+			pendingReqCh:  make(chan *LlmRequest, 1),
+			finishedReqCh: make(chan *LlmRequest, 1),
+			expiredCh:     make(chan *runnerRef, 1),
+			unloadedCh:    make(chan any, 1),
+			loaded:        make(map[string]*runnerRef),
+			newServerFn:   newMockServer(&mock),
+			getGpuFn:      gpu.GetGPUInfo,
+			getCpuFn:      gpu.GetCPUInfo,
+			reschedDelay:  250 * time.Millisecond,
+			loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
+				// add 10ms delay to simulate loading
+				time.Sleep(10 * time.Millisecond)
+				req.successCh <- &runnerRef{
+					llama: &mock,
+				}
+			},
+		},
+	}
+
+	go s.sched.Run(context.TODO())
+
+	w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
+		Model: "test",
+		Modelfile: fmt.Sprintf(`FROM %s
+		TEMPLATE """
+{{- if .System }}System: {{ .System }} {{ end }}
+{{- if .Prompt }}User: {{ .Prompt }} {{ end }}
+{{- if .Response }}Assistant: {{ .Response }} {{ end }}"""
+`, createBinFile(t, llm.KV{
+			"general.architecture":          "llama",
+			"llama.block_count":             uint32(1),
+			"llama.context_length":          uint32(8192),
+			"llama.embedding_length":        uint32(4096),
+			"llama.attention.head_count":    uint32(32),
+			"llama.attention.head_count_kv": uint32(8),
+			"tokenizer.ggml.tokens":         []string{""},
+			"tokenizer.ggml.scores":         []float32{0},
+			"tokenizer.ggml.token_type":     []int32{0},
+		}, []llm.Tensor{
+			{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+		})),
+		Stream: &stream,
+	})
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("expected status 200, got %d", w.Code)
+	}
+
+	t.Run("missing body", func(t *testing.T) {
+		w := createRequest(t, s.ChatHandler, nil)
+		if w.Code != http.StatusBadRequest {
+			t.Errorf("expected status 400, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
+
+	t.Run("missing model", func(t *testing.T) {
+		w := createRequest(t, s.ChatHandler, api.ChatRequest{})
+		if w.Code != http.StatusBadRequest {
+			t.Errorf("expected status 400, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
+
+	t.Run("missing capabilities chat", func(t *testing.T) {
+		w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
+			Model: "bert",
+			Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
+				"general.architecture": "bert",
+				"bert.pooling_type":    uint32(0),
+			}, []llm.Tensor{})),
+			Stream: &stream,
+		})
+
+		if w.Code != http.StatusOK {
+			t.Fatalf("expected status 200, got %d", w.Code)
+		}
+
+		w = createRequest(t, s.ChatHandler, api.ChatRequest{
+			Model: "bert",
+		})
+
+		if w.Code != http.StatusBadRequest {
+			t.Errorf("expected status 400, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support chat"}`); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
+
+	t.Run("load model", func(t *testing.T) {
+		w := createRequest(t, s.ChatHandler, api.ChatRequest{
+			Model: "test",
+		})
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status 200, got %d", w.Code)
+		}
+
+		var actual api.ChatResponse
+		if err := json.NewDecoder(w.Body).Decode(&actual); err != nil {
+			t.Fatal(err)
+		}
+
+		if actual.Model != "test" {
+			t.Errorf("expected model test, got %s", actual.Model)
+		}
+
+		if !actual.Done {
+			t.Errorf("expected done true, got false")
+		}
+
+		if actual.DoneReason != "load" {
+			t.Errorf("expected done reason load, got %s", actual.DoneReason)
+		}
+	})
+
+	checkChatResponse := func(t *testing.T, body io.Reader, model, content string) {
+		t.Helper()
+
+		var actual api.ChatResponse
+		if err := json.NewDecoder(body).Decode(&actual); err != nil {
+			t.Fatal(err)
+		}
+
+		if actual.Model != model {
+			t.Errorf("expected model %s, got %s", model, actual.Model)
+		}
+
+		if !actual.Done {
+			t.Errorf("expected done true, got false")
+		}
+
+		if actual.DoneReason != "stop" {
+			t.Errorf("expected done reason stop, got %s", actual.DoneReason)
+		}
+
+		if diff := cmp.Diff(actual.Message, api.Message{
+			Role:    "assistant",
+			Content: content,
+		}); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+
+		if actual.PromptEvalCount == 0 {
+			t.Errorf("expected prompt eval count > 0, got 0")
+		}
+
+		if actual.PromptEvalDuration == 0 {
+			t.Errorf("expected prompt eval duration > 0, got 0")
+		}
+
+		if actual.EvalCount == 0 {
+			t.Errorf("expected eval count > 0, got 0")
+		}
+
+		if actual.EvalDuration == 0 {
+			t.Errorf("expected eval duration > 0, got 0")
+		}
+
+		if actual.LoadDuration == 0 {
+			t.Errorf("expected load duration > 0, got 0")
+		}
+
+		if actual.TotalDuration == 0 {
+			t.Errorf("expected total duration > 0, got 0")
+		}
+	}
+
+	mock.CompletionResponse.Content = "Hi!"
+	t.Run("messages", func(t *testing.T) {
+		w := createRequest(t, s.ChatHandler, api.ChatRequest{
+			Model: "test",
+			Messages: []api.Message{
+				{Role: "user", Content: "Hello!"},
+			},
+			Stream: &stream,
+		})
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status 200, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+
+		checkChatResponse(t, w.Body, "test", "Hi!")
+	})
+
+	w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
+		Model:     "test-system",
+		Modelfile: "FROM test\nSYSTEM You are a helpful assistant.",
+	})
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("expected status 200, got %d", w.Code)
+	}
+
+	t.Run("messages with model system", func(t *testing.T) {
+		w := createRequest(t, s.ChatHandler, api.ChatRequest{
+			Model: "test-system",
+			Messages: []api.Message{
+				{Role: "user", Content: "Hello!"},
+			},
+			Stream: &stream,
+		})
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status 200, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+
+		checkChatResponse(t, w.Body, "test-system", "Hi!")
+	})
+
+	mock.CompletionResponse.Content = "Abra kadabra!"
+	t.Run("messages with system", func(t *testing.T) {
+		w := createRequest(t, s.ChatHandler, api.ChatRequest{
+			Model: "test-system",
+			Messages: []api.Message{
+				{Role: "system", Content: "You can perform magic tricks."},
+				{Role: "user", Content: "Hello!"},
+			},
+			Stream: &stream,
+		})
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status 200, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+
+		checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
+	})
+
+	t.Run("messages with interleaved system", func(t *testing.T) {
+		w := createRequest(t, s.ChatHandler, api.ChatRequest{
+			Model: "test-system",
+			Messages: []api.Message{
+				{Role: "user", Content: "Hello!"},
+				{Role: "assistant", Content: "I can help you with that."},
+				{Role: "system", Content: "You can perform magic tricks."},
+				{Role: "user", Content: "Help me write tests."},
+			},
+			Stream: &stream,
+		})
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status 200, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! Assistant: I can help you with that. System: You can perform magic tricks. User: Help me write tests. "); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+
+		checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
+	})
+}
+
+func TestGenerate(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
+	mock := mockRunner{
+		CompletionResponse: llm.CompletionResponse{
+			Done:               true,
+			DoneReason:         "stop",
+			PromptEvalCount:    1,
+			PromptEvalDuration: 1,
+			EvalCount:          1,
+			EvalDuration:       1,
+		},
+	}
+
+	s := Server{
+		sched: &Scheduler{
+			pendingReqCh:  make(chan *LlmRequest, 1),
+			finishedReqCh: make(chan *LlmRequest, 1),
+			expiredCh:     make(chan *runnerRef, 1),
+			unloadedCh:    make(chan any, 1),
+			loaded:        make(map[string]*runnerRef),
+			newServerFn:   newMockServer(&mock),
+			getGpuFn:      gpu.GetGPUInfo,
+			getCpuFn:      gpu.GetCPUInfo,
+			reschedDelay:  250 * time.Millisecond,
+			loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
+				req.successCh <- &runnerRef{
+					llama: &mock,
+				}
+			},
+		},
+	}
+
+	go s.sched.Run(context.TODO())
+
+	w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
+		Model: "test",
+		Modelfile: fmt.Sprintf(`FROM %s
+		TEMPLATE """
+{{- if .System }}System: {{ .System }} {{ end }}
+{{- if .Prompt }}User: {{ .Prompt }} {{ end }}
+{{- if .Response }}Assistant: {{ .Response }} {{ end }}"""
+`, createBinFile(t, llm.KV{
+			"general.architecture":          "llama",
+			"llama.block_count":             uint32(1),
+			"llama.context_length":          uint32(8192),
+			"llama.embedding_length":        uint32(4096),
+			"llama.attention.head_count":    uint32(32),
+			"llama.attention.head_count_kv": uint32(8),
+			"tokenizer.ggml.tokens":         []string{""},
+			"tokenizer.ggml.scores":         []float32{0},
+			"tokenizer.ggml.token_type":     []int32{0},
+		}, []llm.Tensor{
+			{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+		})),
+		Stream: &stream,
+	})
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("expected status 200, got %d", w.Code)
+	}
+
+	t.Run("missing body", func(t *testing.T) {
+		w := createRequest(t, s.GenerateHandler, nil)
+		if w.Code != http.StatusBadRequest {
+			t.Errorf("expected status 400, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
+
+	t.Run("missing model", func(t *testing.T) {
+		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{})
+		if w.Code != http.StatusBadRequest {
+			t.Errorf("expected status 400, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
+
+	t.Run("missing capabilities generate", func(t *testing.T) {
+		w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
+			Model: "bert",
+			Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
+				"general.architecture": "bert",
+				"bert.pooling_type":    uint32(0),
+			}, []llm.Tensor{})),
+			Stream: &stream,
+		})
+
+		if w.Code != http.StatusOK {
+			t.Fatalf("expected status 200, got %d", w.Code)
+		}
+
+		w = createRequest(t, s.GenerateHandler, api.GenerateRequest{
+			Model: "bert",
+		})
+
+		if w.Code != http.StatusBadRequest {
+			t.Errorf("expected status 400, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support generate"}`); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
+
+	t.Run("missing capabilities suffix", func(t *testing.T) {
+		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
+			Model:  "test",
+			Prompt: "def add(",
+			Suffix: "    return c",
+		})
+
+		if w.Code != http.StatusBadRequest {
+			t.Errorf("expected status 400, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(w.Body.String(), `{"error":"test does not support insert"}`); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
+
+	t.Run("load model", func(t *testing.T) {
+		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
+			Model: "test",
+		})
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status 200, got %d", w.Code)
+		}
+
+		var actual api.GenerateResponse
+		if err := json.NewDecoder(w.Body).Decode(&actual); err != nil {
+			t.Fatal(err)
+		}
+
+		if actual.Model != "test" {
+			t.Errorf("expected model test, got %s", actual.Model)
+		}
+
+		if !actual.Done {
+			t.Errorf("expected done true, got false")
+		}
+
+		if actual.DoneReason != "load" {
+			t.Errorf("expected done reason load, got %s", actual.DoneReason)
+		}
+	})
+
+	checkGenerateResponse := func(t *testing.T, body io.Reader, model, content string) {
+		t.Helper()
+
+		var actual api.GenerateResponse
+		if err := json.NewDecoder(body).Decode(&actual); err != nil {
+			t.Fatal(err)
+		}
+
+		if actual.Model != model {
+			t.Errorf("expected model %s, got %s", model, actual.Model)
+		}
+
+		if !actual.Done {
+			t.Errorf("expected done true, got false")
+		}
+
+		if actual.DoneReason != "stop" {
+			t.Errorf("expected done reason stop, got %s", actual.DoneReason)
+		}
+
+		if actual.Response != content {
+			t.Errorf("expected response %s, got %s", content, actual.Response)
+		}
+
+		if actual.Context == nil {
+			t.Errorf("expected context not nil")
+		}
+
+		if actual.PromptEvalCount == 0 {
+			t.Errorf("expected prompt eval count > 0, got 0")
+		}
+
+		if actual.PromptEvalDuration == 0 {
+			t.Errorf("expected prompt eval duration > 0, got 0")
+		}
+
+		if actual.EvalCount == 0 {
+			t.Errorf("expected eval count > 0, got 0")
+		}
+
+		if actual.EvalDuration == 0 {
+			t.Errorf("expected eval duration > 0, got 0")
+		}
+
+		if actual.LoadDuration == 0 {
+			t.Errorf("expected load duration > 0, got 0")
+		}
+
+		if actual.TotalDuration == 0 {
+			t.Errorf("expected total duration > 0, got 0")
+		}
+	}
+
+	mock.CompletionResponse.Content = "Hi!"
+	t.Run("prompt", func(t *testing.T) {
+		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
+			Model:  "test",
+			Prompt: "Hello!",
+			Stream: &stream,
+		})
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status 200, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+
+		checkGenerateResponse(t, w.Body, "test", "Hi!")
+	})
+
+	w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
+		Model:     "test-system",
+		Modelfile: "FROM test\nSYSTEM You are a helpful assistant.",
+	})
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("expected status 200, got %d", w.Code)
+	}
+
+	t.Run("prompt with model system", func(t *testing.T) {
+		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
+			Model:  "test-system",
+			Prompt: "Hello!",
+			Stream: &stream,
+		})
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status 200, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+
+		checkGenerateResponse(t, w.Body, "test-system", "Hi!")
+	})
+
+	mock.CompletionResponse.Content = "Abra kadabra!"
+	t.Run("prompt with system", func(t *testing.T) {
+		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
+			Model:  "test-system",
+			Prompt: "Hello!",
+			System: "You can perform magic tricks.",
+			Stream: &stream,
+		})
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status 200, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+
+		checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
+	})
+
+	t.Run("prompt with template", func(t *testing.T) {
+		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
+			Model:  "test-system",
+			Prompt: "Help me write tests.",
+			System: "You can perform magic tricks.",
+			Template: `{{- if .System }}{{ .System }} {{ end }}
+{{- if .Prompt }}### USER {{ .Prompt }} {{ end }}
+{{- if .Response }}### ASSISTANT {{ .Response }} {{ end }}`,
+			Stream: &stream,
+		})
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status 200, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "You can perform magic tricks. ### USER Help me write tests. "); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+
+		checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
+	})
+
+	w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
+		Model: "test-suffix",
+		Modelfile: `FROM test
+TEMPLATE """{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
+{{- else }}{{ .Prompt }}
+{{- end }}"""`,
+	})
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("expected status 200, got %d", w.Code)
+	}
+
+	t.Run("prompt with suffix", func(t *testing.T) {
+		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
+			Model:  "test-suffix",
+			Prompt: "def add(",
+			Suffix: "    return c",
+		})
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status 200, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "<PRE> def add( <SUF>    return c <MID>"); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
+
+	t.Run("prompt without suffix", func(t *testing.T) {
+		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
+			Model:  "test-suffix",
+			Prompt: "def add(",
+		})
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status 200, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "def add("); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
+
+	t.Run("raw", func(t *testing.T) {
+		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
+			Model:  "test-system",
+			Prompt: "Help me write tests.",
+			Raw:    true,
+			Stream: &stream,
+		})
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status 200, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "Help me write tests."); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
+}

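The suffix subtests above pin down the new fill-in-the-middle path: when a model's template branches on `.Suffix`, `/api/generate` accepts both a prompt and a suffix and renders an infill prompt such as `<PRE> … <SUF> … <MID>`, while models without such a branch are rejected with "does not support insert". A hedged sketch of issuing an infill request, assuming a local server and a suffix-aware model (names and JSON keys are illustrative):

```go
package main

import (
	"bytes"
	"encoding/json"
	"io"
	"net/http"
	"os"
)

func main() {
	// Everything before the cursor goes in "prompt", everything after in "suffix";
	// the server's template turns the pair into the model's infill format.
	body, _ := json.Marshal(map[string]any{
		"model":  "codellama:code", // assumed: a model whose template handles .Suffix
		"prompt": "def add(a, b):\n",
		"suffix": "\n    return c\n",
		"stream": false,
	})

	resp, err := http.Post("http://localhost:11434/api/generate", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	io.Copy(os.Stdout, resp.Body) // prints the JSON response, including "response"
}
```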
+ 3 - 0
server/routes_list_test.go

@@ -7,11 +7,14 @@ import (
 	"slices"
 	"testing"
 
+	"github.com/gin-gonic/gin"
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/envconfig"
 )
 
 func TestList(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
 	t.Setenv("OLLAMA_MODELS", t.TempDir())
 	envconfig.LoadConfig()
 

+ 107 - 0
server/routes_test.go

@@ -7,6 +7,7 @@ import (
 	"encoding/json"
 	"fmt"
 	"io"
+	"math"
 	"net/http"
 	"net/http/httptest"
 	"os"
@@ -272,6 +273,77 @@ func Test_Routes(t *testing.T) {
 				assert.Equal(t, "library", retrieveResp.OwnedBy)
 			},
 		},
+		{
+			Name:   "Embed Handler Empty Input",
+			Method: http.MethodPost,
+			Path:   "/api/embed",
+			Setup: func(t *testing.T, req *http.Request) {
+				embedReq := api.EmbedRequest{
+					Model: "t-bone",
+					Input: "",
+				}
+				jsonData, err := json.Marshal(embedReq)
+				require.NoError(t, err)
+				req.Body = io.NopCloser(bytes.NewReader(jsonData))
+			},
+			Expected: func(t *testing.T, resp *http.Response) {
+				contentType := resp.Header.Get("Content-Type")
+				if contentType != "application/json; charset=utf-8" {
+					t.Fatalf("expected content type application/json; charset=utf-8, got %s", contentType)
+				}
+				body, err := io.ReadAll(resp.Body)
+				if err != nil {
+					t.Fatal(err)
+				}
+
+				var embedResp api.EmbedResponse
+				err = json.Unmarshal(body, &embedResp)
+				if err != nil {
+					t.Fatal(err)
+				}
+
+				if embedResp.Model != "t-bone" {
+					t.Fatalf("expected model t-bone, got %s", embedResp.Model)
+				}
+
+				if embedResp.Embeddings == nil {
+					t.Fatalf("expected embeddings to not be nil, got %v", embedResp.Embeddings)
+				}
+
+				if len(embedResp.Embeddings) != 0 {
+					t.Fatalf("expected embeddings to be empty, got %v", embedResp.Embeddings)
+				}
+			},
+		},
+		{
+			Name:   "Embed Handler Invalid Input",
+			Method: http.MethodPost,
+			Path:   "/api/embed",
+			Setup: func(t *testing.T, req *http.Request) {
+				embedReq := api.EmbedRequest{
+					Model: "t-bone",
+					Input: 2,
+				}
+				jsonData, err := json.Marshal(embedReq)
+				require.NoError(t, err)
+				req.Body = io.NopCloser(bytes.NewReader(jsonData))
+			},
+			Expected: func(t *testing.T, resp *http.Response) {
+				contentType := resp.Header.Get("Content-Type")
+				if contentType != "application/json; charset=utf-8" {
+					t.Fatalf("expected content type application/json; charset=utf-8, got %s", contentType)
+				}
+				_, err := io.ReadAll(resp.Body)
+
+				if err != nil {
+					t.Fatal(err)
+				}
+
+				if resp.StatusCode != http.StatusBadRequest {
+					t.Fatalf("expected status code 400, got %d", resp.StatusCode)
+				}
+			},
+		},
 	}
 
 	t.Setenv("OLLAMA_MODELS", t.TempDir())
@@ -420,3 +492,38 @@ func TestShow(t *testing.T) {
 		t.Fatal("Expected projector architecture to be 'clip', but got", resp.ProjectorInfo["general.architecture"])
 	}
 }
+
+func TestNormalize(t *testing.T) {
+	type testCase struct {
+		input []float32
+	}
+
+	testCases := []testCase{
+		{input: []float32{1}},
+		{input: []float32{0, 1, 2, 3}},
+		{input: []float32{0.1, 0.2, 0.3}},
+		{input: []float32{-0.1, 0.2, 0.3, -0.4}},
+		{input: []float32{0, 0, 0}},
+	}
+
+	isNormalized := func(vec []float32) bool {
+		sum := 0.0
+		for _, v := range vec {
+			sum += float64(v * v)
+		}
+		// a vector is normalized when its squared length is within 1e-6 of 1;
+		// the zero vector passes through normalize unchanged, so it is also valid
+		return math.Abs(sum-1) <= 1e-6 || sum == 0
+	}
+
+	for _, tc := range testCases {
+		t.Run("", func(t *testing.T) {
+			normalized := normalize(tc.input)
+			if !isNormalized(normalized) {
+				t.Errorf("Vector %v is not normalized", tc.input)
+			}
+		})
+	}
+}

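For intuition, `normalize` rescales a vector to unit L2 length (each component divided by the square root of the sum of squares) and passes all-zero vectors through untouched, which is exactly what the table-driven test above asserts. An illustrative value-level check that could sit next to `TestNormalize` (not part of this change) might look like:

```go
func TestNormalizeValues(t *testing.T) {
	// 3² + 4² = 25, so the scale factor is 1/5.
	got := normalize([]float32{3, 4})
	want := []float32{0.6, 0.8}
	for i := range want {
		if math.Abs(float64(got[i]-want[i])) > 1e-6 {
			t.Errorf("index %d: got %f, want %f", i, got[i], want[i])
		}
	}
}
```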
+ 4 - 4
server/sched_test.go

@@ -642,8 +642,8 @@ type mockLlm struct {
 	pingResp           error
 	waitResp           error
 	completionResp     error
-	embeddingResp      []float64
-	embeddingRespErr   error
+	embedResp          [][]float32
+	embedRespErr       error
 	tokenizeResp       []int
 	tokenizeRespErr    error
 	detokenizeResp     string
@@ -660,8 +660,8 @@ func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitRes
 func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
 	return s.completionResp
 }
-func (s *mockLlm) Embedding(ctx context.Context, prompt string) ([]float64, error) {
-	return s.embeddingResp, s.embeddingRespErr
+func (s *mockLlm) Embed(ctx context.Context, input []string) ([][]float32, error) {
+	return s.embedResp, s.embedRespErr
 }
 func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
 	return s.tokenizeResp, s.tokenizeRespErr

+ 67 - 0
server/testdata/tools/command-r-plus.gotmpl

@@ -0,0 +1,67 @@
+{{- if or .Tools .System }}<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>
+{{- if .Tools }}# Safety Preamble
+The instructions in this section override those in the task description and style guide sections. Don't answer questions that are harmful or immoral.
+
+# System Preamble
+## Basic Rules
+You are a powerful conversational AI trained by Cohere to help people. You are augmented by a number of tools, and your job is to use and consume the output of these tools to best help the user. You will see a conversation history between yourself and a user, ending with an utterance from the user. You will then see a specific instruction instructing you what kind of response to generate. When you answer the user's requests, you cite your sources in your answers, according to those instructions.
+
+{{ if .System }}# User Preamble
+{{ .System }}
+{{- end }}
+
+## Available Tools
+Here is a list of tools that you have available to you:
+{{- range .Tools }}
+
+```python
+def {{ .Function.Name }}(
+{{- range $name, $property := .Function.Parameters.Properties }}{{ $name }}: {{ $property.Type }}, {{ end }}) -> List[Dict]:
+    """{{ .Function.Description }}
+
+{{- if .Function.Parameters.Properties }}
+
+    Args:
+{{- range $name, $property := .Function.Parameters.Properties }}
+        {{ $name }} ({{ $property.Type }}): {{ $property.Description }}
+{{- end }}
+{{- end }}
+    """
+    pass
+```
+{{- end }}
+{{- else if .System }}{{ .System }}
+{{- end }}<|END_OF_TURN_TOKEN|>
+{{- end }}
+{{- range .Messages }}
+{{- if eq .Role "system" }}
+{{- continue }}
+{{- end }}<|START_OF_TURN_TOKEN|>
+{{- if eq .Role "user" }}<|USER_TOKEN|>{{ .Content }}
+{{- else if eq .Role "assistant" }}<|CHATBOT_TOKEN|>
+{{- if .Content }}{{ .Content }}
+{{- else if .ToolCalls }}
+Action: ```json
+[
+{{- range .ToolCalls }}
+    {
+        "tool_name": "{{ .Function.Name }}",
+        "parameters": {{ json .Function.Arguments }}
+    }
+{{- end }}
+]```
+{{ continue }}
+{{ end }}
+{{- else if eq .Role "tool" }}<|SYSTEM_TOKEN|><results>
+{{ .Content }}</results>
+{{- end }}<|END_OF_TURN_TOKEN|>
+{{- end }}
+{{- if .Tools }}<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>Write 'Action:' followed by a json-formatted list of actions that you want to perform in order to produce a good response to the user's last input. You can use any of the supplied tools any number of times, but you should aim to execute the minimum number of necessary actions for the input. You should use the `directly-answer` tool if calling the other tools is unnecessary. The list of actions you want to call should be formatted as a list of json objects, for example:
+```json
+[
+    {
+        "tool_name": title of the tool in the specification,
+        "parameters": a dict of parameters to input into the tool as they are defined in the specs, or {} if it takes no parameters
+    }
+]```
+{{- end }}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>

+ 39 - 0
server/testdata/tools/command-r-plus.out

@@ -0,0 +1,39 @@
+<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># Safety Preamble
+The instructions in this section override those in the task description and style guide sections. Don't answer questions that are harmful or immoral.
+
+# System Preamble
+## Basic Rules
+You are a powerful conversational AI trained by Cohere to help people. You are augmented by a number of tools, and your job is to use and consume the output of these tools to best help the user. You will see a conversation history between yourself and a user, ending with an utterance from the user. You will then see a specific instruction instructing you what kind of response to generate. When you answer the user's requests, you cite your sources in your answers, according to those instructions.
+
+# User Preamble
+You are a knowledgeable assistant. You can answer questions and perform tasks.
+
+## Available Tools
+Here is a list of tools that you have available to you:
+
+```python
+def get_current_weather(format: string, location: string, ) -> List[Dict]:
+    """Get the current weather
+
+    Args:
+        format (string): The temperature unit to use. Infer this from the user's location.
+        location (string): The city and state, e.g. San Francisco, CA
+    """
+    pass
+```<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>What's the weather like today in Paris?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
+Action: ```json
+[
+    {
+        "tool_name": "get_current_weather",
+        "parameters": {"format":"celsius","location":"Paris, France"}
+    }
+]```
+<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><results>
+22</results><|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>The current temperature in Paris, France is 22 degrees Celsius.<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>What's the weather like today in San Francisco and Toronto?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>Write 'Action:' followed by a json-formatted list of actions that you want to perform in order to produce a good response to the user's last input. You can use any of the supplied tools any number of times, but you should aim to execute the minimum number of necessary actions for the input. You should use the `directly-answer` tool if calling the other tools is unnecessary. The list of actions you want to call should be formatted as a list of json objects, for example:
+```json
+[
+    {
+        "tool_name": title of the tool in the specification,
+        "parameters": a dict of parameters to input into the tool as they are defined in the specs, or {} if it takes no parameters
+    }
+]```<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>

+ 31 - 0
server/testdata/tools/firefunction.gotmpl

@@ -0,0 +1,31 @@
+{{- if or .System .Tools }}<|start_header_id|>system<|end_header_id|>
+{{- if .System }}
+{{ .System }}
+{{- end }}
+In addition to plain text responses, you can choose to call one or more of the provided functions.
+
+Use the following rule to decide when to call a function:
+  * if the response can be generated from your internal knowledge (e.g., as in the case of queries like "What is the capital of Poland?"), do so
+  * if you need external information that can be obtained by calling one or more of the provided functions, generate function calls
+
+If you decide to call functions:
+  * prefix function calls with functools marker (no closing marker required)
+  * all function calls should be generated in a single JSON list formatted as functools[{"name": [function name], "arguments": [function arguments as JSON]}, ...]
+  * follow the provided JSON schema. Do not hallucinate arguments or values. Do not blindly copy values from the provided samples
+  * respect the argument type formatting. E.g., if the type is number and the format is float, write value 7 as 7.0
+  * make sure you pick the right functions that match the user intent
+
+Available functions as JSON spec:
+{{- if .Tools }}
+{{ json .Tools }}
+{{- end }}<|eot_id|>
+{{- end }}
+{{- range .Messages }}<|start_header_id|>
+{{- if or (eq .Role "user") (eq .Role "assistant") (eq .Role "tool") }}{{ .Role }}
+{{- end }}<|end_header_id|>
+{{- if .Content }}{{ .Content }}
+{{- else if .ToolCalls }} functools[
+{{- range .ToolCalls }}{{ "{" }}"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}{{ "}" }}
+{{- end }}]
+{{- end }}<|eot_id|>
+{{- end }}<|start_header_id|>assistant<|end_header_id|>

+ 17 - 0
server/testdata/tools/firefunction.out

@@ -0,0 +1,17 @@
+<|start_header_id|>system<|end_header_id|>
+You are a knowledgeable assistant. You can answer questions and perform tasks.
+In addition to plain text responses, you can choose to call one or more of the provided functions.
+
+Use the following rule to decide when to call a function:
+  * if the response can be generated from your internal knowledge (e.g., as in the case of queries like "What is the capital of Poland?"), do so
+  * if you need external information that can be obtained by calling one or more of the provided functions, generate function calls
+
+If you decide to call functions:
+  * prefix function calls with functools marker (no closing marker required)
+  * all function calls should be generated in a single JSON list formatted as functools[{"name": [function name], "arguments": [function arguments as JSON]}, ...]
+  * follow the provided JSON schema. Do not hallucinate arguments or values. Do not blindly copy values from the provided samples
+  * respect the argument type formatting. E.g., if the type is number and the format is float, write value 7 as 7.0
+  * make sure you pick the right functions that match the user intent
+
+Available functions as JSON spec:
+[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}]<|eot_id|><|start_header_id|><|end_header_id|>You are a knowledgeable assistant. You can answer questions and perform tasks.<|eot_id|><|start_header_id|>user<|end_header_id|>What's the weather like today in Paris?<|eot_id|><|start_header_id|>assistant<|end_header_id|> functools[{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}]<|eot_id|><|start_header_id|>tool<|end_header_id|>22<|eot_id|><|start_header_id|>assistant<|end_header_id|>The current temperature in Paris, France is 22 degrees Celsius.<|eot_id|><|start_header_id|>user<|end_header_id|>What's the weather like today in San Francisco and Toronto?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

+ 39 - 0
server/testdata/tools/messages.json

@@ -0,0 +1,39 @@
+[
+  {
+    "role": "system",
+    "content": "You are a knowledgeable assistant. You can answer questions and perform tasks."
+  },
+  {
+    "role": "user",
+    "content": "What's the weather like today in Paris?"
+  },
+  {
+    "role": "assistant",
+    "tool_calls": [
+      {
+        "id": "89a1e453-0bce-4de3-a456-c54bed09c520",
+        "type": "function",
+        "function": {
+          "name": "get_current_weather",
+          "arguments": {
+            "location": "Paris, France",
+            "format": "celsius"
+          }
+        }
+      }
+    ]
+  },
+  {
+    "role": "tool",
+    "tool_call_id": "89a1e453-0bce-4de3-a456-c54bed09c520",
+    "content": "22"
+  },
+  {
+    "role": "assistant",
+    "content": "The current temperature in Paris, France is 22 degrees Celsius."
+  },
+  {
+    "role": "user",
+    "content": "What's the weather like today in San Francisco and Toronto?"
+  }
+]

+ 15 - 0
server/testdata/tools/mistral.gotmpl

@@ -0,0 +1,15 @@
+{{- range $index, $_ := .Messages }}
+{{- if eq .Role "user" }}
+{{- if and (eq (len (slice $.Messages $index)) 1) $.Tools }}[AVAILABLE_TOOLS] {{ json $.Tools }}[/AVAILABLE_TOOLS]
+{{- end }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}
+
+{{ end }}{{ .Content }}[/INST]
+{{- else if eq .Role "assistant" }}
+{{- if .Content }} {{ .Content }}</s>
+{{- else if .ToolCalls }}[TOOL_CALLS] [
+{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}}
+{{- end }}]</s>
+{{- end }}
+{{- else if eq .Role "tool" }}[TOOL_RESULTS] {"content": {{ .Content }}}[/TOOL_RESULTS]
+{{- end }}
+{{- end }}

+ 3 - 0
server/testdata/tools/mistral.out

@@ -0,0 +1,3 @@
+[INST] What's the weather like today in Paris?[/INST][TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}]</s>[TOOL_RESULTS] {"content": 22}[/TOOL_RESULTS] The current temperature in Paris, France is 22 degrees Celsius.</s>[AVAILABLE_TOOLS] [{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}][/AVAILABLE_TOOLS][INST] You are a knowledgeable assistant. You can answer questions and perform tasks.
+
+What's the weather like today in San Francisco and Toronto?[/INST]

+ 30 - 0
server/testdata/tools/tools.json

@@ -0,0 +1,30 @@
+[
+  {
+    "type": "function",
+    "function": {
+      "name": "get_current_weather",
+      "description": "Get the current weather",
+      "parameters": {
+        "type": "object",
+        "properties": {
+          "location": {
+            "type": "string",
+            "description": "The city and state, e.g. San Francisco, CA"
+          },
+          "format": {
+            "type": "string",
+            "enum": [
+              "celsius",
+              "fahrenheit"
+            ],
+            "description": "The temperature unit to use. Infer this from the user's location."
+          }
+        },
+        "required": [
+          "location",
+          "format"
+        ]
+      }
+    }
+  }
+]

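Taken together, messages.json and tools.json are the shared fixtures for the tool-calling templates above: the tool spec renders into each model's native syntax (Python stubs for Command R+, a JSON spec for FireFunction, `[AVAILABLE_TOOLS]` for Mistral), and the assistant's `tool_calls` plus the `tool` result round-trip through the conversation. On the client side, a parsed tool call might be dispatched roughly as sketched below; the weather logic is hypothetical and the snippet assumes imports of `fmt` and `github.com/ollama/ollama/api`:

```go
// Sketch: routing an api.ToolCall returned by /api/chat to local code.
func dispatch(call api.ToolCall) (string, error) {
	args := call.Function.Arguments // already unmarshaled into map[string]any
	switch call.Function.Name {
	case "get_current_weather":
		location, _ := args["location"].(string)
		format, _ := args["format"].(string)
		// hypothetical stand-in; a real client would query a weather service
		return fmt.Sprintf("22 degrees %s in %s", format, location), nil
	default:
		return "", fmt.Errorf("unknown tool %q", call.Function.Name)
	}
}
```

The returned string would then be sent back as a `{"role": "tool", "content": ...}` message, as in messages.json.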
+ 127 - 67
template/template.go

@@ -103,16 +103,9 @@ var response = parse.ActionNode{
 }
 
 var funcs = template.FuncMap{
-	// contents returns the contents of messages with an optional role filter
-	"contents": func(v []*api.Message, role ...string) string {
-		var parts []string
-		for _, m := range v {
-			if len(role) == 0 || role[0] == "" || m.Role == role[0] {
-				parts = append(parts, m.Content)
-			}
-		}
-
-		return strings.Join(parts, "\n\n")
+	"json": func(v any) string {
+		b, _ := json.Marshal(v)
+		return string(b)
 	},
 }
 
@@ -141,7 +134,7 @@ func (t *Template) Vars() []string {
 	var vars []string
 	for _, tt := range t.Templates() {
 		for _, n := range tt.Root.Nodes {
-			vars = append(vars, parseNode(n)...)
+			vars = append(vars, Identifiers(n)...)
 		}
 	}
 
@@ -157,32 +150,81 @@ func (t *Template) Vars() []string {
 
 type Values struct {
 	Messages []api.Message
+	Tools    []api.Tool
+	Prompt   string
+	Suffix   string
 
 	// forceLegacy is a flag used to test compatibility with legacy templates
 	forceLegacy bool
 }
 
+func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template {
+	var walk func(parse.Node) parse.Node
+	walk = func(n parse.Node) parse.Node {
+		if fn(n) {
+			return n
+		}
+
+		switch t := n.(type) {
+		case *parse.ListNode:
+			for _, c := range t.Nodes {
+				if n := walk(c); n != nil {
+					return n
+				}
+			}
+		case *parse.BranchNode:
+			for _, n := range []*parse.ListNode{t.List, t.ElseList} {
+				if n != nil {
+					if n := walk(n); n != nil {
+						return n
+					}
+				}
+			}
+		case *parse.IfNode:
+			return walk(&t.BranchNode)
+		case *parse.WithNode:
+			return walk(&t.BranchNode)
+		case *parse.RangeNode:
+			return walk(&t.BranchNode)
+		}
+
+		return nil
+	}
+
+	if n := walk(t.Tree.Root); n != nil {
+		return (&template.Template{
+			Tree: &parse.Tree{
+				Root: &parse.ListNode{
+					Nodes: []parse.Node{n},
+				},
+			},
+		}).Funcs(funcs)
+	}
+
+	return nil
+}
+
 func (t *Template) Execute(w io.Writer, v Values) error {
-	collated := collate(v.Messages)
-	if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
+	system, messages := collate(v.Messages)
+	if v.Prompt != "" && v.Suffix != "" {
 		return t.Template.Execute(w, map[string]any{
-			"Messages": collated,
+			"Prompt":   v.Prompt,
+			"Suffix":   v.Suffix,
+			"Response": "",
+		})
+	} else if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
+		return t.Template.Execute(w, map[string]any{
+			"System":   system,
+			"Messages": messages,
+			"Tools":    v.Tools,
 		})
 	}
 
+	system = ""
 	var b bytes.Buffer
-	var system, prompt, response string
-	for i, m := range collated {
-		switch m.Role {
-		case "system":
-			system = m.Content
-		case "user":
-			prompt = m.Content
-		case "assistant":
-			response = m.Content
-		}
-
-		if i != len(collated)-1 && prompt != "" && response != "" {
+	var prompt, response string
+	for _, m := range messages {
+		execute := func() error {
 			if err := t.Template.Execute(&b, map[string]any{
 				"System":   system,
 				"Prompt":   prompt,
@@ -194,17 +236,33 @@ func (t *Template) Execute(w io.Writer, v Values) error {
 			system = ""
 			prompt = ""
 			response = ""
+			return nil
+		}
+
+		switch m.Role {
+		case "system":
+			if prompt != "" || response != "" {
+				if err := execute(); err != nil {
+					return err
+				}
+			}
+			system = m.Content
+		case "user":
+			if response != "" {
+				if err := execute(); err != nil {
+					return err
+				}
+			}
+			prompt = m.Content
+		case "assistant":
+			response = m.Content
 		}
 	}
 
 	var cut bool
 	nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool {
-		switch t := n.(type) {
-		case *parse.ActionNode:
-		case *parse.FieldNode:
-			if slices.Contains(t.Ident, "Response") {
-				cut = true
-			}
+		if field, ok := n.(*parse.FieldNode); ok && slices.Contains(field.Ident, "Response") {
+			cut = true
 		}
 
 		return cut
@@ -212,7 +270,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
 
 	tree := parse.Tree{Root: nodes.(*parse.ListNode)}
 	if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{
-		"System": "",
+		"System": system,
 		"Prompt": prompt,
 	}); err != nil {
 		return err
@@ -223,11 +281,13 @@ func (t *Template) Execute(w io.Writer, v Values) error {
 }
 
 // collate messages based on role. consecutive messages of the same role are merged
-// into a single message. collate also pulls out and merges messages with Role == "system"
-// which are templated separately. As a side effect, it mangles message content adding image
-// tags ([img-%d]) as needed
-func collate(msgs []api.Message) (collated []*api.Message) {
+// into a single message. collate also collects and returns all system messages.
+// collate mutates message content, adding image tags ([img-%d]) as needed.
+func collate(msgs []api.Message) (string, []*api.Message) {
 	var n int
+
+	var system []string
+	var collated []*api.Message
 	for i := range msgs {
 		msg := msgs[i]
 		for range msg.Images {
@@ -240,6 +300,10 @@ func collate(msgs []api.Message) (collated []*api.Message) {
 			n++
 		}
 
+		if msg.Role == "system" {
+			system = append(system, msg.Content)
+		}
+
 		if len(collated) > 0 && collated[len(collated)-1].Role == msg.Role {
 			collated[len(collated)-1].Content += "\n\n" + msg.Content
 		} else {
@@ -247,53 +311,49 @@ func collate(msgs []api.Message) (collated []*api.Message) {
 		}
 	}
 
-	return
+	return strings.Join(system, "\n\n"), collated
 }
 
-func parseNode(n parse.Node) []string {
+// Identifiers walks the node tree returning any identifiers it finds along the way
+func Identifiers(n parse.Node) []string {
 	switch n := n.(type) {
-	case *parse.ActionNode:
-		return parseNode(n.Pipe)
-	case *parse.IfNode:
-		names := parseNode(n.Pipe)
-		names = append(names, parseNode(n.List)...)
-		if n.ElseList != nil {
-			names = append(names, parseNode(n.ElseList)...)
+	case *parse.ListNode:
+		var names []string
+		for _, n := range n.Nodes {
+			names = append(names, Identifiers(n)...)
 		}
+
 		return names
-	case *parse.RangeNode:
-		names := parseNode(n.Pipe)
-		names = append(names, parseNode(n.List)...)
-		if n.ElseList != nil {
-			names = append(names, parseNode(n.ElseList)...)
+	case *parse.TemplateNode:
+		return Identifiers(n.Pipe)
+	case *parse.ActionNode:
+		return Identifiers(n.Pipe)
+	case *parse.BranchNode:
+		names := Identifiers(n.Pipe)
+		for _, n := range []*parse.ListNode{n.List, n.ElseList} {
+			if n != nil {
+				names = append(names, Identifiers(n)...)
+			}
 		}
 		return names
+	case *parse.IfNode:
+		return Identifiers(&n.BranchNode)
+	case *parse.RangeNode:
+		return Identifiers(&n.BranchNode)
 	case *parse.WithNode:
-		names := parseNode(n.Pipe)
-		names = append(names, parseNode(n.List)...)
-		if n.ElseList != nil {
-			names = append(names, parseNode(n.ElseList)...)
-		}
-		return names
+		return Identifiers(&n.BranchNode)
 	case *parse.PipeNode:
 		var names []string
 		for _, c := range n.Cmds {
 			for _, a := range c.Args {
-				names = append(names, parseNode(a)...)
+				names = append(names, Identifiers(a)...)
 			}
 		}
-		return names
-	case *parse.ListNode:
-		var names []string
-		for _, n := range n.Nodes {
-			names = append(names, parseNode(n)...)
-		}
-
 		return names
 	case *parse.FieldNode:
 		return n.Ident
-	case *parse.TemplateNode:
-		return parseNode(n.Pipe)
+	case *parse.VariableNode:
+		return n.Ident
 	}
 
 	return nil

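The reworked `Identifiers` (formerly `parseNode`) and the new `Subtree` helper give callers a way to inspect and carve up a template's parse tree. For instance, a caller could pull out just the branch that handles fill-in-the-middle prompts; a hedged sketch written as if it lived in this package (the helper name is illustrative):

```go
// suffixBranch returns a sub-template containing the first if-branch whose
// condition references .Suffix, or nil when the template has no such branch.
func suffixBranch(t *Template) *template.Template {
	return t.Subtree(func(n parse.Node) bool {
		if ifNode, ok := n.(*parse.IfNode); ok {
			return slices.Contains(Identifiers(ifNode.Pipe), "Suffix")
		}
		return false
	})
}
```

Note that `Subtree` only descends through list and branch nodes, so predicates should match at the `if`/`range`/`with` level (as here) rather than on leaf field nodes, which the walk never visits.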
+ 38 - 40
template/template_test.go

@@ -216,13 +216,11 @@ func TestExecuteWithMessages(t *testing.T) {
 				{"response", `[INST] {{ if .System }}{{ .System }}
 
 {{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
-				{"messages", `{{- $system := contents .Messages "system" -}}
-{{- range $index, $_ := .Messages }}
-{{- if eq .Role "user" }}[INST] {{ if $system }}{{ $system }}
-{{- $system = "" }}
+				{"messages", `[INST] {{ if .System }}{{ .System }}
 
-{{ end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
-{{- end }}
+{{ end }}
+{{- range .Messages }}
+{{- if eq .Role "user" }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}[INST] {{ end }}
 {{- end }}`},
 			},
 			Values{
@@ -243,13 +241,11 @@ func TestExecuteWithMessages(t *testing.T) {
 				{"response", `[INST] {{ if .System }}{{ .System }}
 
 {{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
-				{"messages", `{{- $system := contents .Messages "system" -}}
-{{- range $index, $_ := .Messages }}
-{{- if eq .Role "user" }}[INST] {{ if $system }}{{ $system }}
-{{- $system = "" }}
+				{"messages", `[INST] {{ if .System }}{{ .System }}
 
-{{ end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
-{{- end }}
+{{ end }}
+{{- range .Messages }}
+{{- if eq .Role "user" }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}[INST] {{ end }}
 {{- end }}`},
 			},
 			Values{
@@ -364,35 +360,37 @@ Answer: `,
 	}
 }
 
-func TestFuncs(t *testing.T) {
-	t.Run("contents", func(t *testing.T) {
-		cases := map[string]string{
-			"":          "A\n\nB\n\nC\n\nD\n\nE\n\nF",
-			"system":    "A\n\nF",
-			"user":      "B\n\nE",
-			"assistant": "C\n\nD",
-		}
+func TestExecuteWithSuffix(t *testing.T) {
+	tmpl, err := Parse(`{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
+{{- else }}{{ .Prompt }}
+{{- end }}`)
+	if err != nil {
+		t.Fatal(err)
+	}
 
-		s := []*api.Message{
-			{Role: "system", Content: "A"},
-			{Role: "user", Content: "B"},
-			{Role: "assistant", Content: "C"},
-			{Role: "assistant", Content: "D"},
-			{Role: "user", Content: "E"},
-			{Role: "system", Content: "F"},
-		}
+	cases := []struct {
+		name   string
+		values Values
+		expect string
+	}{
+		{
+			"message", Values{Messages: []api.Message{{Role: "user", Content: "hello"}}}, "hello",
+		},
+		{
+			"prompt suffix", Values{Prompt: "def add(", Suffix: "return x"}, "<PRE> def add( <SUF>return x <MID>",
+		},
+	}
 
-		fn, ok := funcs["contents"].(func([]*api.Message, ...string) string)
-		if !ok {
-			t.Fatal("contents is not a function")
-		}
+	for _, tt := range cases {
+		t.Run(tt.name, func(t *testing.T) {
+			var b bytes.Buffer
+			if err := tmpl.Execute(&b, tt.values); err != nil {
+				t.Fatal(err)
+			}
 
-		for k, v := range cases {
-			t.Run(k, func(t *testing.T) {
-				if diff := cmp.Diff(fn(s, k), v); diff != "" {
-					t.Errorf("mismatch (-got +want):\n%s", diff)
-				}
-			})
-		}
-	})
+			if diff := cmp.Diff(b.String(), tt.expect); diff != "" {
+				t.Errorf("mismatch (-got +want):\n%s", diff)
+			}
+		})
+	}
 }