Michael Yang hace 10 meses
padre
commit
d02bbebb11
Se han modificado 7 ficheros con 262 adiciones y 52 borrados
  1. 36 3
      api/types.go
  2. 9 2
      server/images.go
  3. 105 0
      server/model.go
  4. 3 3
      server/prompt.go
  5. 1 1
      server/prompt_test.go
  6. 19 5
      server/routes.go
  7. 89 38
      template/template.go

+ 36 - 3
api/types.go

@@ -97,6 +97,9 @@ type ChatRequest struct {
 	// followin the request.
 	KeepAlive *Duration `json:"keep_alive,omitempty"`
 
+	// Tools is an optional list of tools the model has access to.
+	Tools []Tool `json:"tools,omitempty"`
+
 	// Options lists model-specific options.
 	Options map[string]interface{} `json:"options"`
 }
@@ -105,9 +108,36 @@ type ChatRequest struct {
 // role ("system", "user", or "assistant"), the content and an optional list
 // of images.
 type Message struct {
-	Role    string      `json:"role"`
-	Content string      `json:"content"`
-	Images  []ImageData `json:"images,omitempty"`
+	Role      string      `json:"role"`
+	Content   string      `json:"content,omitempty"`
+	Images    []ImageData `json:"images,omitempty"`
+	ToolCalls []ToolCall  `json:"tool_calls,omitempty"`
+}
+
+type ToolCall struct {
+	ID       string `json:"id"`
+	Type     string `json:"type"`
+	Function struct {
+		Name      string         `json:"name"`
+		Arguments map[string]any `json:"arguments"`
+	} `json:"function"`
+}
+
+type Tool struct {
+	Type     string `json:"type"`
+	Function struct {
+		Name        string `json:"name"`
+		Description string `json:"description"`
+		Parameters  struct {
+			Type       string   `json:"type"`
+			Required   []string `json:"required"`
+			Properties map[string]struct {
+				Type        string   `json:"type"`
+				Description string   `json:"description"`
+				Enum        []string `json:"enum,omitempty"`
+			} `json:"properties"`
+		} `json:"parameters"`
+	} `json:"function"`
 }
 
 func (m *Message) UnmarshalJSON(b []byte) error {
@@ -374,6 +404,9 @@ type GenerateResponse struct {
 	// Response is the textual response itself.
 	Response string `json:"response"`
 
+	// ToolCalls is the list of tools the model wants to call
+	ToolCalls []ToolCall `json:"tool_calls,omitempty"`
+
 	// Done specifies if the response is complete.
 	Done bool `json:"done"`
 

+ 9 - 2
server/images.go

@@ -38,7 +38,10 @@ var errCapabilityCompletion = errors.New("completion")
 
 type Capability string
 
-const CapabilityCompletion = Capability("completion")
+const (
+	CapabilityCompletion = Capability("completion")
+	CapabilityTools      = Capability("tools")
+)
 
 type registryOptions struct {
 	Insecure bool
@@ -88,6 +91,10 @@ func (m *Model) CheckCapabilities(caps ...Capability) error {
 			if _, ok := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]; ok {
 				errs = append(errs, errCapabilityCompletion)
 			}
+		case CapabilityTools:
+			if !slices.Contains(m.Template.Vars(), "tools") {
+				errs = append(errs, errors.New("tools"))
+			}
 		default:
 			slog.Error("unknown capability", "capability", cap)
 			return fmt.Errorf("unknown capability: %s", cap)
@@ -95,7 +102,7 @@ func (m *Model) CheckCapabilities(caps ...Capability) error {
 	}
 
 	if err := errors.Join(errs...); err != nil {
-		return fmt.Errorf("missing capabilities: %w", errors.Join(errs...))
+		return fmt.Errorf("does not support %w", errors.Join(errs...))
 	}
 
 	return nil

+ 105 - 0
server/model.go

@@ -4,6 +4,7 @@ import (
 	"archive/zip"
 	"bytes"
 	"context"
+	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
@@ -11,7 +12,11 @@ import (
 	"net/http"
 	"os"
 	"path/filepath"
+	"slices"
+	"strings"
+	"text/template/parse"
 
+	"github.com/google/uuid"
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/convert"
 	"github.com/ollama/ollama/llm"
@@ -289,3 +294,103 @@ func detectContentType(r io.Reader) (string, error) {
 
 	return "unknown", nil
 }
+
+// parseToolCalls attempts to parse a JSON string into a slice of ToolCalls.
+// mxyng: this only really works if the input contains tool calls in some JSON format
+func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
+	// create a subtree from the node that ranges over .ToolCalls
+	tmpl := m.Template.Subtree(func(n parse.Node) bool {
+		if t, ok := n.(*parse.RangeNode); ok {
+			return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
+		}
+
+		return false
+	})
+
+	if tmpl == nil {
+		return nil, false
+	}
+
+	var b bytes.Buffer
+	if err := tmpl.Execute(&b, map[string][]map[string]any{
+		"ToolCalls": {
+			{
+				"Function": map[string]any{
+					"Name":      "@@name@@",
+					"Arguments": "@@arguments@@",
+				},
+			},
+		},
+	}); err != nil {
+		return nil, false
+	}
+
+	var kv map[string]string
+	// execute the subtree with placeholders to identify the keys
+	if err := json.Unmarshal(b.Bytes(), &kv); err != nil {
+		return nil, false
+	}
+
+	// find the keys that correspond to the name and arguments fields
+	var name, arguments string
+	for k, v := range kv {
+		switch v {
+		case "@@name@@":
+			name = k
+		case "@@arguments@@":
+			arguments = k
+		}
+	}
+
+	var sm []map[string]any
+	decoder := json.NewDecoder(strings.NewReader(s))
+	for {
+		// incrementally decode the JSON into a list of JSON objects
+		// skipping over any invalid tokens
+		if err := decoder.Decode(&sm); err != nil {
+			if errors.Is(err, io.EOF) {
+				break
+			}
+
+			if errors.As(err, new(*json.SyntaxError)) {
+				r := decoder.Buffered()
+				if _, err := r.Read(make([]byte, decoder.InputOffset()+1)); err != nil {
+					break
+				}
+
+				decoder = json.NewDecoder(r)
+				continue
+			}
+
+			return nil, false
+		}
+
+		// break as soon as a valid object is decoded
+		break
+	}
+
+	var toolCalls []api.ToolCall
+	for _, kv := range sm {
+		call := api.ToolCall{
+			ID:   uuid.New().String(),
+			Type: "function",
+		}
+
+		for k, v := range kv {
+			switch k {
+			case name:
+				call.Function.Name = v.(string)
+			case arguments:
+				call.Function.Arguments = v.(map[string]any)
+			}
+		}
+
+		toolCalls = append(toolCalls, call)
+	}
+
+	if len(toolCalls) > 0 {
+		return toolCalls, true
+	}
+
+	return nil, false
+}

+ 3 - 3
server/prompt.go

@@ -15,7 +15,7 @@ type tokenizeFunc func(context.Context, string) ([]int, error)
 // chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
 // chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
 // latest message and 2) system messages
-func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) {
+func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
 	var system []api.Message
 	// always include the last message
 	n := len(msgs) - 1
@@ -29,7 +29,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 		}
 
 		var b bytes.Buffer
-		if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil {
+		if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools}); err != nil {
 			return "", nil, err
 		}
 
@@ -57,7 +57,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 
 	// truncate any messages that do not fit into the context window
 	var b bytes.Buffer
-	if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...)}); err != nil {
+	if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...), Tools: tools}); err != nil {
 		return "", nil, err
 	}
 

+ 1 - 1
server/prompt_test.go

@@ -192,7 +192,7 @@ func TestChatPrompt(t *testing.T) {
 		t.Run(tt.name, func(t *testing.T) {
 			model := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
 			opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
-			prompt, images, err := chatPrompt(context.TODO(), &model, tokenize, &opts, tt.msgs)
+			prompt, images, err := chatPrompt(context.TODO(), &model, tokenize, &opts, tt.msgs, nil)
 			if err != nil {
 				t.Fatal(err)
 			}

+ 19 - 5
server/routes.go

@@ -265,6 +265,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		}
 
 		r.Response = sb.String()
+		if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
+			r.ToolCalls = toolCalls
+			r.Response = ""
+		}
+
 		c.JSON(http.StatusOK, r)
 		return
 	}
@@ -1279,6 +1284,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	}
 
 	caps := []Capability{CapabilityCompletion}
+	if req.Tools != nil {
+		caps = append(caps, CapabilityTools)
+	}
+
 	r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
 	if errors.Is(err, errCapabilityCompletion) {
 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
@@ -1305,7 +1314,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		req.Messages = append([]api.Message{{Role: "system", Content: m.System}}, req.Messages...)
 	}
 
-	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages)
+	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages, req.Tools)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
@@ -1348,13 +1357,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	}()
 
 	if req.Stream != nil && !*req.Stream {
-		var r api.ChatResponse
+		var resp api.ChatResponse
 		var sb strings.Builder
 		for rr := range ch {
 			switch t := rr.(type) {
 			case api.ChatResponse:
 				sb.WriteString(t.Message.Content)
-				r = t
+				resp = t
 			case gin.H:
 				msg, ok := t["error"].(string)
 				if !ok {
@@ -1369,8 +1378,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
 			}
 		}
 
-		r.Message.Content = sb.String()
-		c.JSON(http.StatusOK, r)
+		resp.Message.Content = sb.String()
+		if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
+			resp.Message.ToolCalls = toolCalls
+			resp.Message.Content = ""
+		}
+
+		c.JSON(http.StatusOK, resp)
 		return
 	}
 

+ 89 - 38
template/template.go

@@ -13,6 +13,7 @@ import (
 	"sync"
 	"text/template"
 	"text/template/parse"
+	"time"
 
 	"github.com/agnivade/levenshtein"
 	"github.com/ollama/ollama/api"
@@ -102,8 +103,18 @@ var response = parse.ActionNode{
 	},
 }
 
+var funcs = template.FuncMap{
+	"json": func(v any) string {
+		b, _ := json.Marshal(v)
+		return string(b)
+	},
+	"now": func() string {
+		return time.Now().Format("2006-01-02 15:04:05")
+	},
+}
+
 func Parse(s string) (*Template, error) {
-	tmpl := template.New("").Option("missingkey=zero")
+	tmpl := template.New("").Option("missingkey=zero").Funcs(funcs)
 
 	tmpl, err := tmpl.Parse(s)
 	if err != nil {
@@ -127,7 +138,7 @@ func (t *Template) Vars() []string {
 	var vars []string
 	for _, tt := range t.Templates() {
 		for _, n := range tt.Root.Nodes {
-			vars = append(vars, parseNode(n)...)
+			vars = append(vars, Identifiers(n)...)
 		}
 	}
 
@@ -143,17 +154,65 @@ func (t *Template) Vars() []string {
 
 type Values struct {
 	Messages []api.Message
+	Tools    []api.Tool
 
 	// forceLegacy is a flag used to test compatibility with legacy templates
 	forceLegacy bool
 }
 
+func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template {
+	var walk func(parse.Node) parse.Node
+	walk = func(n parse.Node) parse.Node {
+		if fn(n) {
+			return n
+		}
+
+		switch t := n.(type) {
+		case *parse.ListNode:
+			for _, c := range t.Nodes {
+				if n := walk(c); n != nil {
+					return n
+				}
+			}
+		case *parse.BranchNode:
+			for _, n := range []*parse.ListNode{t.List, t.ElseList} {
+				if n != nil {
+					if n := walk(n); n != nil {
+						return n
+					}
+				}
+			}
+		case *parse.IfNode:
+			return walk(&t.BranchNode)
+		case *parse.WithNode:
+			return walk(&t.BranchNode)
+		case *parse.RangeNode:
+			return walk(&t.BranchNode)
+		}
+
+		return nil
+	}
+
+	if n := walk(t.Tree.Root); n != nil {
+		return (&template.Template{
+			Tree: &parse.Tree{
+				Root: &parse.ListNode{
+					Nodes: []parse.Node{n},
+				},
+			},
+		}).Funcs(funcs)
+	}
+
+	return nil
+}
+
 func (t *Template) Execute(w io.Writer, v Values) error {
 	system, messages := collate(v.Messages)
 	if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
 		return t.Template.Execute(w, map[string]any{
 			"System":   system,
 			"Messages": messages,
+			"Tools":    v.Tools,
 		})
 	}
 
@@ -161,7 +220,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
 	var b bytes.Buffer
 	var prompt, response string
 	for _, m := range messages {
-		execute := func () error {
+		execute := func() error {
 			if err := t.Template.Execute(&b, map[string]any{
 				"System":   system,
 				"Prompt":   prompt,
@@ -198,12 +257,8 @@ func (t *Template) Execute(w io.Writer, v Values) error {
 
 	var cut bool
 	nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool {
-		switch t := n.(type) {
-		case *parse.ActionNode:
-		case *parse.FieldNode:
-			if slices.Contains(t.Ident, "Response") {
-				cut = true
-			}
+		if field, ok := n.(*parse.FieldNode); ok && slices.Contains(field.Ident, "Response") {
+			cut = true
 		}
 
 		return cut
@@ -255,50 +310,46 @@ func collate(msgs []api.Message) (string, []*api.Message) {
 	return strings.Join(system, "\n\n"), collated
 }
 
-func parseNode(n parse.Node) []string {
+// Identifiers walks the node tree returning any identifiers it finds along the way
+func Identifiers(n parse.Node) []string {
 	switch n := n.(type) {
-	case *parse.ActionNode:
-		return parseNode(n.Pipe)
-	case *parse.IfNode:
-		names := parseNode(n.Pipe)
-		names = append(names, parseNode(n.List)...)
-		if n.ElseList != nil {
-			names = append(names, parseNode(n.ElseList)...)
+	case *parse.ListNode:
+		var names []string
+		for _, n := range n.Nodes {
+			names = append(names, Identifiers(n)...)
 		}
+
 		return names
-	case *parse.RangeNode:
-		names := parseNode(n.Pipe)
-		names = append(names, parseNode(n.List)...)
-		if n.ElseList != nil {
-			names = append(names, parseNode(n.ElseList)...)
+	case *parse.TemplateNode:
+		return Identifiers(n.Pipe)
+	case *parse.ActionNode:
+		return Identifiers(n.Pipe)
+	case *parse.BranchNode:
+		names := Identifiers(n.Pipe)
+		for _, n := range []*parse.ListNode{n.List, n.ElseList} {
+			if n != nil {
+				names = append(names, Identifiers(n)...)
+			}
 		}
 		return names
+	case *parse.IfNode:
+		return Identifiers(&n.BranchNode)
+	case *parse.RangeNode:
+		return Identifiers(&n.BranchNode)
 	case *parse.WithNode:
-		names := parseNode(n.Pipe)
-		names = append(names, parseNode(n.List)...)
-		if n.ElseList != nil {
-			names = append(names, parseNode(n.ElseList)...)
-		}
-		return names
+		return Identifiers(&n.BranchNode)
 	case *parse.PipeNode:
 		var names []string
 		for _, c := range n.Cmds {
 			for _, a := range c.Args {
-				names = append(names, parseNode(a)...)
+				names = append(names, Identifiers(a)...)
 			}
 		}
-		return names
-	case *parse.ListNode:
-		var names []string
-		for _, n := range n.Nodes {
-			names = append(names, parseNode(n)...)
-		}
-
 		return names
 	case *parse.FieldNode:
 		return n.Ident
-	case *parse.TemplateNode:
-		return parseNode(n.Pipe)
+	case *parse.VariableNode:
+		return n.Ident
 	}
 
 	return nil