浏览代码

Merge pull request #5284 from ollama/mxyng/tools

tools
Michael Yang 9 月之前
父节点
当前提交
64039df6d7

+ 36 - 3
api/types.go

@@ -97,6 +97,9 @@ type ChatRequest struct {
 	// followin the request.
 	KeepAlive *Duration `json:"keep_alive,omitempty"`
 
+	// Tools is an optional list of tools the model has access to.
+	Tools []Tool `json:"tools,omitempty"`
+
 	// Options lists model-specific options.
 	Options map[string]interface{} `json:"options"`
 }
@@ -105,9 +108,36 @@ type ChatRequest struct {
 // role ("system", "user", or "assistant"), the content and an optional list
 // of images.
 type Message struct {
-	Role    string      `json:"role"`
-	Content string      `json:"content"`
-	Images  []ImageData `json:"images,omitempty"`
+	Role      string      `json:"role"`
+	Content   string      `json:"content,omitempty"`
+	Images    []ImageData `json:"images,omitempty"`
+	ToolCalls []ToolCall  `json:"tool_calls,omitempty"`
+}
+
+type ToolCall struct {
+	ID       string `json:"id"`
+	Type     string `json:"type"`
+	Function struct {
+		Name      string         `json:"name"`
+		Arguments map[string]any `json:"arguments"`
+	} `json:"function"`
+}
+
+type Tool struct {
+	Type     string `json:"type"`
+	Function struct {
+		Name        string `json:"name"`
+		Description string `json:"description"`
+		Parameters  struct {
+			Type       string   `json:"type"`
+			Required   []string `json:"required"`
+			Properties map[string]struct {
+				Type        string   `json:"type"`
+				Description string   `json:"description"`
+				Enum        []string `json:"enum,omitempty"`
+			} `json:"properties"`
+		} `json:"parameters"`
+	} `json:"function"`
 }
 
 func (m *Message) UnmarshalJSON(b []byte) error {
@@ -374,6 +404,9 @@ type GenerateResponse struct {
 	// Response is the textual response itself.
 	Response string `json:"response"`
 
+	// ToolCalls is the list of tools the model wants to call
+	ToolCalls []ToolCall `json:"tool_calls,omitempty"`
+
 	// Done specifies if the response is complete.
 	Done bool `json:"done"`
 

+ 9 - 2
server/images.go

@@ -38,7 +38,10 @@ var errCapabilityCompletion = errors.New("completion")
 
 type Capability string
 
-const CapabilityCompletion = Capability("completion")
+const (
+	CapabilityCompletion = Capability("completion")
+	CapabilityTools      = Capability("tools")
+)
 
 type registryOptions struct {
 	Insecure bool
@@ -88,6 +91,10 @@ func (m *Model) CheckCapabilities(caps ...Capability) error {
 			if _, ok := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]; ok {
 				errs = append(errs, errCapabilityCompletion)
 			}
+		case CapabilityTools:
+			if !slices.Contains(m.Template.Vars(), "tools") {
+				errs = append(errs, errors.New("tools"))
+			}
 		default:
 			slog.Error("unknown capability", "capability", cap)
 			return fmt.Errorf("unknown capability: %s", cap)
@@ -95,7 +102,7 @@ func (m *Model) CheckCapabilities(caps ...Capability) error {
 	}
 
 	if err := errors.Join(errs...); err != nil {
-		return fmt.Errorf("missing capabilities: %w", errors.Join(errs...))
+		return fmt.Errorf("does not support %w", errors.Join(errs...))
 	}
 
 	return nil

+ 105 - 0
server/model.go

@@ -4,6 +4,7 @@ import (
 	"archive/zip"
 	"bytes"
 	"context"
+	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
@@ -11,7 +12,11 @@ import (
 	"net/http"
 	"os"
 	"path/filepath"
+	"slices"
+	"strings"
+	"text/template/parse"
 
+	"github.com/google/uuid"
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/convert"
 	"github.com/ollama/ollama/llm"
@@ -289,3 +294,103 @@ func detectContentType(r io.Reader) (string, error) {
 
 	return "unknown", nil
 }
+
+// parseToolCalls attempts to parse a JSON string into a slice of ToolCalls.
+// mxyng: this only really works if the input contains tool calls in some JSON format
+func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
+	// create a subtree from the node that ranges over .ToolCalls
+	tmpl := m.Template.Subtree(func(n parse.Node) bool {
+		if t, ok := n.(*parse.RangeNode); ok {
+			return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
+		}
+
+		return false
+	})
+
+	if tmpl == nil {
+		return nil, false
+	}
+
+	var b bytes.Buffer
+	if err := tmpl.Execute(&b, map[string][]map[string]any{
+		"ToolCalls": {
+			{
+				"Function": map[string]any{
+					"Name":      "@@name@@",
+					"Arguments": "@@arguments@@",
+				},
+			},
+		},
+	}); err != nil {
+		return nil, false
+	}
+
+	var kv map[string]string
+	// execute the subtree with placeholders to identify the keys
+	if err := json.Unmarshal(b.Bytes(), &kv); err != nil {
+		return nil, false
+	}
+
+	// find the keys that correspond to the name and arguments fields
+	var name, arguments string
+	for k, v := range kv {
+		switch v {
+		case "@@name@@":
+			name = k
+		case "@@arguments@@":
+			arguments = k
+		}
+	}
+
+	var sm []map[string]any
+	decoder := json.NewDecoder(strings.NewReader(s))
+	for {
+		// incrementally decode the JSON into a list of JSON objects
+		// skipping over any invalid tokens
+		if err := decoder.Decode(&sm); err != nil {
+			if errors.Is(err, io.EOF) {
+				break
+			}
+
+			if errors.As(err, new(*json.SyntaxError)) {
+				r := decoder.Buffered()
+				if _, err := r.Read(make([]byte, decoder.InputOffset()+1)); err != nil {
+					break
+				}
+
+				decoder = json.NewDecoder(r)
+				continue
+			}
+
+			return nil, false
+		}
+
+		// break as soon as a valid object is decoded
+		break
+	}
+
+	var toolCalls []api.ToolCall
+	for _, kv := range sm {
+		call := api.ToolCall{
+			ID:   uuid.New().String(),
+			Type: "function",
+		}
+
+		for k, v := range kv {
+			switch k {
+			case name:
+				call.Function.Name = v.(string)
+			case arguments:
+				call.Function.Arguments = v.(map[string]any)
+			}
+		}
+
+		toolCalls = append(toolCalls, call)
+	}
+
+	if len(toolCalls) > 0 {
+		return toolCalls, true
+	}
+
+	return nil, false
+}

+ 122 - 0
server/model_test.go

@@ -3,7 +3,9 @@ package server
 import (
 	"archive/zip"
 	"bytes"
+	"encoding/json"
 	"errors"
+	"fmt"
 	"io"
 	"os"
 	"path/filepath"
@@ -11,7 +13,9 @@ import (
 	"strings"
 	"testing"
 
+	"github.com/google/go-cmp/cmp"
 	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/template"
 )
 
 func createZipFile(t *testing.T, name string) *os.File {
@@ -110,3 +114,121 @@ func TestExtractFromZipFile(t *testing.T) {
 		})
 	}
 }
+
+type function struct {
+	Name      string         `json:"name"`
+	Arguments map[string]any `json:"arguments"`
+}
+
+func readFile(t *testing.T, base, name string) *bytes.Buffer {
+	t.Helper()
+
+	bts, err := os.ReadFile(filepath.Join(base, name))
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	return bytes.NewBuffer(bts)
+}
+
+func TestExecuteWithTools(t *testing.T) {
+	p := filepath.Join("testdata", "tools")
+	cases := []struct {
+		model  string
+		output string
+	}{
+		{"mistral", `[TOOL_CALLS]  [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`},
+		{"mistral", `[TOOL_CALLS]  [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
+
+The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`},
+		{"command-r-plus", "Action: ```json" + `
+[
+    {
+        "tool_name": "get_current_weather",
+        "parameters": {
+            "format": "fahrenheit",
+            "location": "San Francisco, CA"
+        }
+    },
+    {
+        "tool_name": "get_current_weather",
+        "parameters": {
+            "format": "celsius",
+            "location": "Toronto, Canada"
+        }
+    }
+]
+` + "```"},
+		{"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`},
+	}
+
+	var tools []api.Tool
+	if err := json.Unmarshal(readFile(t, p, "tools.json").Bytes(), &tools); err != nil {
+		t.Fatal(err)
+	}
+
+	var messages []api.Message
+	if err := json.Unmarshal(readFile(t, p, "messages.json").Bytes(), &messages); err != nil {
+		t.Fatal(err)
+	}
+
+	calls := []api.ToolCall{
+		{
+			Type: "function",
+			Function: function{
+				Name: "get_current_weather",
+				Arguments: map[string]any{
+					"format":   "fahrenheit",
+					"location": "San Francisco, CA",
+				},
+			},
+		},
+		{
+			Type: "function",
+			Function: function{
+				Name: "get_current_weather",
+				Arguments: map[string]any{
+					"format":   "celsius",
+					"location": "Toronto, Canada",
+				},
+			},
+		},
+	}
+
+	for _, tt := range cases {
+		t.Run(tt.model, func(t *testing.T) {
+			tmpl, err := template.Parse(readFile(t, p, fmt.Sprintf("%s.gotmpl", tt.model)).String())
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			t.Run("template", func(t *testing.T) {
+				var actual bytes.Buffer
+				if err := tmpl.Execute(&actual, template.Values{Tools: tools, Messages: messages}); err != nil {
+					t.Fatal(err)
+				}
+
+				if diff := cmp.Diff(actual.String(), readFile(t, p, fmt.Sprintf("%s.out", tt.model)).String()); diff != "" {
+					t.Errorf("mismatch (-got +want):\n%s", diff)
+				}
+			})
+
+			t.Run("parse", func(t *testing.T) {
+				m := &Model{Template: tmpl}
+				actual, ok := m.parseToolCalls(tt.output)
+				if !ok {
+					t.Fatal("failed to parse tool calls")
+				}
+
+				for i := range actual {
+					// ID is randomly generated so clear it for comparison
+					actual[i].ID = ""
+				}
+
+				if diff := cmp.Diff(actual, calls); diff != "" {
+					t.Errorf("mismatch (-got +want):\n%s", diff)
+				}
+			})
+		})
+	}
+}

+ 3 - 3
server/prompt.go

@@ -15,7 +15,7 @@ type tokenizeFunc func(context.Context, string) ([]int, error)
 // chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
 // chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
 // latest message and 2) system messages
-func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) {
+func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
 	var system []api.Message
 	// always include the last message
 	n := len(msgs) - 1
@@ -29,7 +29,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 		}
 
 		var b bytes.Buffer
-		if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil {
+		if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools}); err != nil {
 			return "", nil, err
 		}
 
@@ -57,7 +57,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 
 	// truncate any messages that do not fit into the context window
 	var b bytes.Buffer
-	if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...)}); err != nil {
+	if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...), Tools: tools}); err != nil {
 		return "", nil, err
 	}
 

+ 1 - 1
server/prompt_test.go

@@ -192,7 +192,7 @@ func TestChatPrompt(t *testing.T) {
 		t.Run(tt.name, func(t *testing.T) {
 			model := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
 			opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
-			prompt, images, err := chatPrompt(context.TODO(), &model, tokenize, &opts, tt.msgs)
+			prompt, images, err := chatPrompt(context.TODO(), &model, tokenize, &opts, tt.msgs, nil)
 			if err != nil {
 				t.Fatal(err)
 			}

+ 19 - 5
server/routes.go

@@ -265,6 +265,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		}
 
 		r.Response = sb.String()
+		if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
+			r.ToolCalls = toolCalls
+			r.Response = ""
+		}
+
 		c.JSON(http.StatusOK, r)
 		return
 	}
@@ -1279,6 +1284,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	}
 
 	caps := []Capability{CapabilityCompletion}
+	if req.Tools != nil {
+		caps = append(caps, CapabilityTools)
+	}
+
 	r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
 	if errors.Is(err, errCapabilityCompletion) {
 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
@@ -1305,7 +1314,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		req.Messages = append([]api.Message{{Role: "system", Content: m.System}}, req.Messages...)
 	}
 
-	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages)
+	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages, req.Tools)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
@@ -1348,13 +1357,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	}()
 
 	if req.Stream != nil && !*req.Stream {
-		var r api.ChatResponse
+		var resp api.ChatResponse
 		var sb strings.Builder
 		for rr := range ch {
 			switch t := rr.(type) {
 			case api.ChatResponse:
 				sb.WriteString(t.Message.Content)
-				r = t
+				resp = t
 			case gin.H:
 				msg, ok := t["error"].(string)
 				if !ok {
@@ -1369,8 +1378,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
 			}
 		}
 
-		r.Message.Content = sb.String()
-		c.JSON(http.StatusOK, r)
+		resp.Message.Content = sb.String()
+		if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
+			resp.Message.ToolCalls = toolCalls
+			resp.Message.Content = ""
+		}
+
+		c.JSON(http.StatusOK, resp)
 		return
 	}
 

+ 67 - 0
server/testdata/tools/command-r-plus.gotmpl

@@ -0,0 +1,67 @@
+{{- if or .Tools .System }}<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>
+{{- if .Tools }}# Safety Preamble
+The instructions in this section override those in the task description and style guide sections. Don't answer questions that are harmful or immoral.
+
+# System Preamble
+## Basic Rules
+You are a powerful conversational AI trained by Cohere to help people. You are augmented by a number of tools, and your job is to use and consume the output of these tools to best help the user. You will see a conversation history between yourself and a user, ending with an utterance from the user. You will then see a specific instruction instructing you what kind of response to generate. When you answer the user's requests, you cite your sources in your answers, according to those instructions.
+
+{{ if .System }}# User Preamble
+{{ .System }}
+{{- end }}
+
+## Available Tools
+Here is a list of tools that you have available to you:
+{{- range .Tools }}
+
+```python
+def {{ .Function.Name }}(
+{{- range $name, $property := .Function.Parameters.Properties }}{{ $name }}: {{ $property.Type }}, {{ end }}) -> List[Dict]:
+    """{{ .Function.Description }}
+
+{{- if .Function.Parameters.Properties }}
+
+    Args:
+{{- range $name, $property := .Function.Parameters.Properties }}
+        {{ $name }} ({{ $property.Type }}): {{ $property.Description }}
+{{- end }}
+{{- end }}
+    """
+    pass
+```
+{{- end }}
+{{- else if .System }}{{ .System }}
+{{- end }}<|END_OF_TURN_TOKEN|>
+{{- end }}
+{{- range .Messages }}
+{{- if eq .Role "system" }}
+{{- continue }}
+{{- end }}<|START_OF_TURN_TOKEN|>
+{{- if eq .Role "user" }}<|USER_TOKEN|>{{ .Content }}
+{{- else if eq .Role "assistant" }}<|CHATBOT_TOKEN|>
+{{- if .Content }}{{ .Content }}
+{{- else if .ToolCalls }}
+Action: ```json
+[
+{{- range .ToolCalls }}
+    {
+        "tool_name": "{{ .Function.Name }}",
+        "parameters": {{ json .Function.Arguments }}
+    }
+{{- end }}
+]```
+{{ continue }}
+{{ end }}
+{{- else if eq .Role "tool" }}<|SYSTEM_TOKEN|><results>
+{{ .Content }}</results>
+{{- end }}<|END_OF_TURN_TOKEN|>
+{{- end }}
+{{- if .Tools }}<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>Write 'Action:' followed by a json-formatted list of actions that you want to perform in order to produce a good response to the user's last input. You can use any of the supplied tools any number of times, but you should aim to execute the minimum number of necessary actions for the input. You should use the `directly-answer` tool if calling the other tools is unnecessary. The list of actions you want to call should be formatted as a list of json objects, for example:
+```json
+[
+    {
+        "tool_name": title of the tool in the specification,
+        "parameters": a dict of parameters to input into the tool as they are defined in the specs, or {} if it takes no parameters
+    }
+]```
+{{- end }}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>

+ 39 - 0
server/testdata/tools/command-r-plus.out

@@ -0,0 +1,39 @@
+<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># Safety Preamble
+The instructions in this section override those in the task description and style guide sections. Don't answer questions that are harmful or immoral.
+
+# System Preamble
+## Basic Rules
+You are a powerful conversational AI trained by Cohere to help people. You are augmented by a number of tools, and your job is to use and consume the output of these tools to best help the user. You will see a conversation history between yourself and a user, ending with an utterance from the user. You will then see a specific instruction instructing you what kind of response to generate. When you answer the user's requests, you cite your sources in your answers, according to those instructions.
+
+# User Preamble
+You are a knowledgable assistant. You can answer questions and perform tasks.
+
+## Available Tools
+Here is a list of tools that you have available to you:
+
+```python
+def get_current_weather(format: string, location: string, ) -> List[Dict]:
+    """Get the current weather
+
+    Args:
+        format (string): The temperature unit to use. Infer this from the users location.
+        location (string): The city and state, e.g. San Francisco, CA
+    """
+    pass
+```<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>What's the weather like today in Paris?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
+Action: ```json
+[
+    {
+        "tool_name": "get_current_weather",
+        "parameters": {"format":"celsius","location":"Paris, France"}
+    }
+]```
+<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><results>
+22</results><|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>The current temperature in Paris, France is 22 degrees Celsius.<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>What's the weather like today in San Francisco and Toronto?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>Write 'Action:' followed by a json-formatted list of actions that you want to perform in order to produce a good response to the user's last input. You can use any of the supplied tools any number of times, but you should aim to execute the minimum number of necessary actions for the input. You should use the `directly-answer` tool if calling the other tools is unnecessary. The list of actions you want to call should be formatted as a list of json objects, for example:
+```json
+[
+    {
+        "tool_name": title of the tool in the specification,
+        "parameters": a dict of parameters to input into the tool as they are defined in the specs, or {} if it takes no parameters
+    }
+]```<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>

+ 31 - 0
server/testdata/tools/firefunction.gotmpl

@@ -0,0 +1,31 @@
+{{- if or .System .Tools }}<|start_header_id|>system<|end_header_id|>
+{{- if .System }}
+{{ .System }}
+{{- end }}
+In addition to plain text responses, you can chose to call one or more of the provided functions.
+
+Use the following rule to decide when to call a function:
+  * if the response can be generated from your internal knowledge (e.g., as in the case of queries like "What is the capital of Poland?"), do so
+  * if you need external information that can be obtained by calling one or more of the provided functions, generate a function calls
+
+If you decide to call functions:
+  * prefix function calls with functools marker (no closing marker required)
+  * all function calls should be generated in a single JSON list formatted as functools[{"name": [function name], "arguments": [function arguments as JSON]}, ...]
+  * follow the provided JSON schema. Do not hallucinate arguments or values. Do to blindly copy values from the provided samples
+  * respect the argument type formatting. E.g., if the type if number and format is float, write value 7 as 7.0
+  * make sure you pick the right functions that match the user intent
+
+Available functions as JSON spec:
+{{- if .Tools }}
+{{ json .Tools }}
+{{- end }}<|eot_id|>
+{{- end }}
+{{- range .Messages }}<|start_header_id|>
+{{- if or (eq .Role "user") (eq .Role "assistant") (eq .Role "tool") }}{{ .Role }}
+{{- end }}<|end_header_id|>
+{{- if .Content }}{{ .Content }}
+{{- else if .ToolCalls }} functools[
+{{- range .ToolCalls }}{{ "{" }}"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}{{ "}" }}
+{{- end }}]
+{{- end }}<|eot_id|>
+{{- end }}<|start_header_id|>assistant<|end_header_id|>

+ 17 - 0
server/testdata/tools/firefunction.out

@@ -0,0 +1,17 @@
+<|start_header_id|>system<|end_header_id|>
+You are a knowledgable assistant. You can answer questions and perform tasks.
+In addition to plain text responses, you can chose to call one or more of the provided functions.
+
+Use the following rule to decide when to call a function:
+  * if the response can be generated from your internal knowledge (e.g., as in the case of queries like "What is the capital of Poland?"), do so
+  * if you need external information that can be obtained by calling one or more of the provided functions, generate a function calls
+
+If you decide to call functions:
+  * prefix function calls with functools marker (no closing marker required)
+  * all function calls should be generated in a single JSON list formatted as functools[{"name": [function name], "arguments": [function arguments as JSON]}, ...]
+  * follow the provided JSON schema. Do not hallucinate arguments or values. Do to blindly copy values from the provided samples
+  * respect the argument type formatting. E.g., if the type if number and format is float, write value 7 as 7.0
+  * make sure you pick the right functions that match the user intent
+
+Available functions as JSON spec:
+[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the users location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}]<|eot_id|><|start_header_id|><|end_header_id|>You are a knowledgable assistant. You can answer questions and perform tasks.<|eot_id|><|start_header_id|>user<|end_header_id|>What's the weather like today in Paris?<|eot_id|><|start_header_id|>assistant<|end_header_id|> functools[{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}]<|eot_id|><|start_header_id|>tool<|end_header_id|>22<|eot_id|><|start_header_id|>assistant<|end_header_id|>The current temperature in Paris, France is 22 degrees Celsius.<|eot_id|><|start_header_id|>user<|end_header_id|>What's the weather like today in San Francisco and Toronto?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

+ 39 - 0
server/testdata/tools/messages.json

@@ -0,0 +1,39 @@
+[
+  {
+    "role": "system",
+    "content": "You are a knowledgable assistant. You can answer questions and perform tasks."
+  },
+  {
+    "role": "user",
+    "content": "What's the weather like today in Paris?"
+  },
+  {
+    "role": "assistant",
+    "tool_calls": [
+      {
+        "id": "89a1e453-0bce-4de3-a456-c54bed09c520",
+        "type": "function",
+        "function": {
+          "name": "get_current_weather",
+          "arguments": {
+            "location": "Paris, France",
+            "format": "celsius"
+          }
+        }
+      }
+    ]
+  },
+  {
+    "role": "tool",
+    "tool_call_id": "89a1e453-0bce-4de3-a456-c54bed09c520",
+    "content": "22"
+  },
+  {
+    "role": "assistant",
+    "content": "The current temperature in Paris, France is 22 degrees Celsius."
+  },
+  {
+    "role": "user",
+    "content": "What's the weather like today in San Francisco and Toronto?"
+  }
+]

+ 15 - 0
server/testdata/tools/mistral.gotmpl

@@ -0,0 +1,15 @@
+{{- range $index, $_ := .Messages }}
+{{- if eq .Role "user" }}
+{{- if and (eq (len (slice $.Messages $index)) 1) $.Tools }}[AVAILABLE_TOOLS] {{ json $.Tools }}[/AVAILABLE_TOOLS]
+{{- end }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}
+
+{{ end }}{{ .Content }}[/INST]
+{{- else if eq .Role "assistant" }}
+{{- if .Content }} {{ .Content }}</s>
+{{- else if .ToolCalls }}[TOOL_CALLS] [
+{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}}
+{{- end }}]</s>
+{{- end }}
+{{- else if eq .Role "tool" }}[TOOL_RESULTS] {"content": {{ .Content }}}[/TOOL_RESULTS]
+{{- end }}
+{{- end }}

+ 3 - 0
server/testdata/tools/mistral.out

@@ -0,0 +1,3 @@
+[INST] What's the weather like today in Paris?[/INST][TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}]</s>[TOOL_RESULTS] {"content": 22}[/TOOL_RESULTS] The current temperature in Paris, France is 22 degrees Celsius.</s>[AVAILABLE_TOOLS] [{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the users location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}][/AVAILABLE_TOOLS][INST] You are a knowledgable assistant. You can answer questions and perform tasks.
+
+What's the weather like today in San Francisco and Toronto?[/INST]

+ 30 - 0
server/testdata/tools/tools.json

@@ -0,0 +1,30 @@
+[
+  {
+    "type": "function",
+    "function": {
+      "name": "get_current_weather",
+      "description": "Get the current weather",
+      "parameters": {
+        "type": "object",
+        "properties": {
+          "location": {
+            "type": "string",
+            "description": "The city and state, e.g. San Francisco, CA"
+          },
+          "format": {
+            "type": "string",
+            "enum": [
+              "celsius",
+              "fahrenheit"
+            ],
+            "description": "The temperature unit to use. Infer this from the users location."
+          }
+        },
+        "required": [
+          "location",
+          "format"
+        ]
+      }
+    }
+  }
+]

+ 85 - 38
template/template.go

@@ -102,8 +102,15 @@ var response = parse.ActionNode{
 	},
 }
 
+var funcs = template.FuncMap{
+	"json": func(v any) string {
+		b, _ := json.Marshal(v)
+		return string(b)
+	},
+}
+
 func Parse(s string) (*Template, error) {
-	tmpl := template.New("").Option("missingkey=zero")
+	tmpl := template.New("").Option("missingkey=zero").Funcs(funcs)
 
 	tmpl, err := tmpl.Parse(s)
 	if err != nil {
@@ -127,7 +134,7 @@ func (t *Template) Vars() []string {
 	var vars []string
 	for _, tt := range t.Templates() {
 		for _, n := range tt.Root.Nodes {
-			vars = append(vars, parseNode(n)...)
+			vars = append(vars, Identifiers(n)...)
 		}
 	}
 
@@ -143,17 +150,65 @@ func (t *Template) Vars() []string {
 
 type Values struct {
 	Messages []api.Message
+	Tools    []api.Tool
 
 	// forceLegacy is a flag used to test compatibility with legacy templates
 	forceLegacy bool
 }
 
+func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template {
+	var walk func(parse.Node) parse.Node
+	walk = func(n parse.Node) parse.Node {
+		if fn(n) {
+			return n
+		}
+
+		switch t := n.(type) {
+		case *parse.ListNode:
+			for _, c := range t.Nodes {
+				if n := walk(c); n != nil {
+					return n
+				}
+			}
+		case *parse.BranchNode:
+			for _, n := range []*parse.ListNode{t.List, t.ElseList} {
+				if n != nil {
+					if n := walk(n); n != nil {
+						return n
+					}
+				}
+			}
+		case *parse.IfNode:
+			return walk(&t.BranchNode)
+		case *parse.WithNode:
+			return walk(&t.BranchNode)
+		case *parse.RangeNode:
+			return walk(&t.BranchNode)
+		}
+
+		return nil
+	}
+
+	if n := walk(t.Tree.Root); n != nil {
+		return (&template.Template{
+			Tree: &parse.Tree{
+				Root: &parse.ListNode{
+					Nodes: []parse.Node{n},
+				},
+			},
+		}).Funcs(funcs)
+	}
+
+	return nil
+}
+
 func (t *Template) Execute(w io.Writer, v Values) error {
 	system, messages := collate(v.Messages)
 	if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
 		return t.Template.Execute(w, map[string]any{
 			"System":   system,
 			"Messages": messages,
+			"Tools":    v.Tools,
 		})
 	}
 
@@ -161,7 +216,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
 	var b bytes.Buffer
 	var prompt, response string
 	for _, m := range messages {
-		execute := func () error {
+		execute := func() error {
 			if err := t.Template.Execute(&b, map[string]any{
 				"System":   system,
 				"Prompt":   prompt,
@@ -198,12 +253,8 @@ func (t *Template) Execute(w io.Writer, v Values) error {
 
 	var cut bool
 	nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool {
-		switch t := n.(type) {
-		case *parse.ActionNode:
-		case *parse.FieldNode:
-			if slices.Contains(t.Ident, "Response") {
-				cut = true
-			}
+		if field, ok := n.(*parse.FieldNode); ok && slices.Contains(field.Ident, "Response") {
+			cut = true
 		}
 
 		return cut
@@ -255,50 +306,46 @@ func collate(msgs []api.Message) (string, []*api.Message) {
 	return strings.Join(system, "\n\n"), collated
 }
 
-func parseNode(n parse.Node) []string {
+// Identifiers walks the node tree returning any identifiers it finds along the way
+func Identifiers(n parse.Node) []string {
 	switch n := n.(type) {
-	case *parse.ActionNode:
-		return parseNode(n.Pipe)
-	case *parse.IfNode:
-		names := parseNode(n.Pipe)
-		names = append(names, parseNode(n.List)...)
-		if n.ElseList != nil {
-			names = append(names, parseNode(n.ElseList)...)
+	case *parse.ListNode:
+		var names []string
+		for _, n := range n.Nodes {
+			names = append(names, Identifiers(n)...)
 		}
+
 		return names
-	case *parse.RangeNode:
-		names := parseNode(n.Pipe)
-		names = append(names, parseNode(n.List)...)
-		if n.ElseList != nil {
-			names = append(names, parseNode(n.ElseList)...)
+	case *parse.TemplateNode:
+		return Identifiers(n.Pipe)
+	case *parse.ActionNode:
+		return Identifiers(n.Pipe)
+	case *parse.BranchNode:
+		names := Identifiers(n.Pipe)
+		for _, n := range []*parse.ListNode{n.List, n.ElseList} {
+			if n != nil {
+				names = append(names, Identifiers(n)...)
+			}
 		}
 		return names
+	case *parse.IfNode:
+		return Identifiers(&n.BranchNode)
+	case *parse.RangeNode:
+		return Identifiers(&n.BranchNode)
 	case *parse.WithNode:
-		names := parseNode(n.Pipe)
-		names = append(names, parseNode(n.List)...)
-		if n.ElseList != nil {
-			names = append(names, parseNode(n.ElseList)...)
-		}
-		return names
+		return Identifiers(&n.BranchNode)
 	case *parse.PipeNode:
 		var names []string
 		for _, c := range n.Cmds {
 			for _, a := range c.Args {
-				names = append(names, parseNode(a)...)
+				names = append(names, Identifiers(a)...)
 			}
 		}
-		return names
-	case *parse.ListNode:
-		var names []string
-		for _, n := range n.Nodes {
-			names = append(names, parseNode(n)...)
-		}
-
 		return names
 	case *parse.FieldNode:
 		return n.Ident
-	case *parse.TemplateNode:
-		return parseNode(n.Pipe)
+	case *parse.VariableNode:
+		return n.Ident
 	}
 
 	return nil