Browse Source

marshal json automatically for some template values (#5758)

Michael Yang 9 months ago
parent
commit
b255445557

+ 48 - 25
api/types.go

@@ -101,12 +101,19 @@ type ChatRequest struct {
 	KeepAlive *Duration `json:"keep_alive,omitempty"`
 
 	// Tools is an optional list of tools the model has access to.
-	Tools []Tool `json:"tools,omitempty"`
+	Tools `json:"tools,omitempty"`
 
 	// Options lists model-specific options.
 	Options map[string]interface{} `json:"options"`
 }
 
+type Tools []Tool
+
+func (t Tools) String() string {
+	bts, _ := json.Marshal(t)
+	return string(bts)
+}
+
 // Message is a single message in a chat sequence. The message contains the
 // role ("system", "user", or "assistant"), the content and an optional list
 // of images.
@@ -117,30 +124,6 @@ type Message struct {
 	ToolCalls []ToolCall  `json:"tool_calls,omitempty"`
 }
 
-type ToolCall struct {
-	Function struct {
-		Name      string         `json:"name"`
-		Arguments map[string]any `json:"arguments"`
-	} `json:"function"`
-}
-
-type Tool struct {
-	Type     string `json:"type"`
-	Function struct {
-		Name        string `json:"name"`
-		Description string `json:"description"`
-		Parameters  struct {
-			Type       string   `json:"type"`
-			Required   []string `json:"required"`
-			Properties map[string]struct {
-				Type        string   `json:"type"`
-				Description string   `json:"description"`
-				Enum        []string `json:"enum,omitempty"`
-			} `json:"properties"`
-		} `json:"parameters"`
-	} `json:"function"`
-}
-
 func (m *Message) UnmarshalJSON(b []byte) error {
 	type Alias Message
 	var a Alias
@@ -153,6 +136,46 @@ func (m *Message) UnmarshalJSON(b []byte) error {
 	return nil
 }
 
+type ToolCall struct {
+	Function ToolCallFunction `json:"function"`
+}
+
+type ToolCallFunction struct {
+	Name      string                    `json:"name"`
+	Arguments ToolCallFunctionArguments `json:"arguments"`
+}
+
+type ToolCallFunctionArguments map[string]any
+
+func (t *ToolCallFunctionArguments) String() string {
+	bts, _ := json.Marshal(t)
+	return string(bts)
+}
+
+type Tool struct {
+	Type     string       `json:"type"`
+	Function ToolFunction `json:"function"`
+}
+
+type ToolFunction struct {
+	Name        string `json:"name"`
+	Description string `json:"description"`
+	Parameters  struct {
+		Type       string   `json:"type"`
+		Required   []string `json:"required"`
+		Properties map[string]struct {
+			Type        string   `json:"type"`
+			Description string   `json:"description"`
+			Enum        []string `json:"enum,omitempty"`
+		} `json:"properties"`
+	} `json:"parameters"`
+}
+
+func (t *ToolFunction) String() string {
+	bts, _ := json.Marshal(t)
+	return string(bts)
+}
+
 // ChatResponse is the response returned by [Client.Chat]. Its fields are
 // similar to [GenerateResponse].
 type ChatResponse struct {

+ 10 - 8
server/model.go

@@ -311,12 +311,14 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
 	}
 
 	var b bytes.Buffer
-	if err := tmpl.Execute(&b, map[string][]map[string]any{
+	if err := tmpl.Execute(&b, map[string][]api.ToolCall{
 		"ToolCalls": {
 			{
-				"Function": map[string]any{
-					"Name":      "@@name@@",
-					"Arguments": "@@arguments@@",
+				Function: api.ToolCallFunction{
+					Name: "@@name@@",
+					Arguments: api.ToolCallFunctionArguments{
+						"@@argument@@": 1,
+					},
 				},
 			},
 		},
@@ -324,7 +326,7 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
 		return nil, false
 	}
 
-	var kv map[string]string
+	var kv map[string]any
 	// execute the subtree with placeholders to identify the keys
 	// trim any commands that might exist in the template
 	if err := json.Unmarshal(bytes.TrimSuffix(b.Bytes(), []byte(",")), &kv); err != nil {
@@ -334,10 +336,10 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
 	// find the keys that correspond to the name and arguments fields
 	var name, arguments string
 	for k, v := range kv {
-		switch v {
-		case "@@name@@":
+		switch v.(type) {
+		case string:
 			name = k
-		case "@@arguments@@":
+		case map[string]any:
 			arguments = k
 		}
 	}

+ 4 - 9
server/model_test.go

@@ -115,11 +115,6 @@ func TestExtractFromZipFile(t *testing.T) {
 	}
 }
 
-type function struct {
-	Name      string         `json:"name"`
-	Arguments map[string]any `json:"arguments"`
-}
-
 func readFile(t *testing.T, base, name string) *bytes.Buffer {
 	t.Helper()
 
@@ -185,18 +180,18 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`,
 
 	calls := []api.ToolCall{
 		{
-			Function: function{
+			Function: api.ToolCallFunction{
 				Name: "get_current_weather",
-				Arguments: map[string]any{
+				Arguments: api.ToolCallFunctionArguments{
 					"format":   "fahrenheit",
 					"location": "San Francisco, CA",
 				},
 			},
 		},
 		{
-			Function: function{
+			Function: api.ToolCallFunction{
 				Name: "get_current_weather",
-				Arguments: map[string]any{
+				Arguments: api.ToolCallFunctionArguments{
 					"format":   "celsius",
 					"location": "Toronto, Canada",
 				},

+ 1 - 1
server/testdata/tools/command-r-plus.gotmpl

@@ -46,7 +46,7 @@ Action: ```json
 {{- range .ToolCalls }}
     {
         "tool_name": "{{ .Function.Name }}",
-        "parameters": {{ json .Function.Arguments }}
+        "parameters": {{ .Function.Arguments }}
     }
 {{- end }}
 ]```

+ 2 - 2
server/testdata/tools/firefunction.gotmpl

@@ -17,7 +17,7 @@ If you decide to call functions:
 
 Available functions as JSON spec:
 {{- if .Tools }}
-{{ json .Tools }}
+{{ .Tools }}
 {{- end }}<|eot_id|>
 {{- end }}
 {{- range .Messages }}<|start_header_id|>
@@ -25,7 +25,7 @@ Available functions as JSON spec:
 {{- end }}<|end_header_id|>
 {{- if .Content }}{{ .Content }}
 {{- else if .ToolCalls }} functools[
-{{- range .ToolCalls }}{{ "{" }}"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}{{ "}" }}
+{{- range .ToolCalls }}{{ "{" }}"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}{{ "}" }}
 {{- end }}]
 {{- end }}<|eot_id|>
 {{- end }}<|start_header_id|>assistant<|end_header_id|>

+ 2 - 2
server/testdata/tools/llama3-groq-tool-use.gotmpl

@@ -9,7 +9,7 @@
 
 Here are the available tools:
 <tools>
-{{- range .Tools }} {{ json .Function }}
+{{- range .Tools }} {{ .Function }}
 {{- end }} </tools>
 {{- end }}
 {{- end }}<|eot_id|>
@@ -20,7 +20,7 @@ Here are the available tools:
 {{- else if eq .Role "assistant" }}
 {{- if .Content }}{{ .Content }}
 {{- else if .ToolCalls }}<tool_call>
-{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}}
+{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
 {{- end }}
 </tool_call>
 {{- end }}

+ 2 - 2
server/testdata/tools/mistral.gotmpl

@@ -1,13 +1,13 @@
 {{- range $index, $_ := .Messages }}
 {{- if eq .Role "user" }}
-{{- if and (eq (len (slice $.Messages $index)) 1) $.Tools }}[AVAILABLE_TOOLS] {{ json $.Tools }}[/AVAILABLE_TOOLS]
+{{- if and (eq (len (slice $.Messages $index)) 1) $.Tools }}[AVAILABLE_TOOLS] {{ $.Tools }}[/AVAILABLE_TOOLS]
 {{- end }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}
 
 {{ end }}{{ .Content }}[/INST]
 {{- else if eq .Role "assistant" }}
 {{- if .Content }} {{ .Content }}</s>
 {{- else if .ToolCalls }}[TOOL_CALLS] [
-{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}}
+{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
 {{- end }}]</s>
 {{- end }}
 {{- else if eq .Role "tool" }}[TOOL_RESULTS] {"content": {{ .Content }}}[/TOOL_RESULTS]

+ 3 - 3
template/template.go

@@ -150,9 +150,9 @@ func (t *Template) Vars() []string {
 
 type Values struct {
 	Messages []api.Message
-	Tools    []api.Tool
-	Prompt   string
-	Suffix   string
+	api.Tools
+	Prompt string
+	Suffix string
 
 	// forceLegacy is a flag used to test compatibility with legacy templates
 	forceLegacy bool