Jelajahi Sumber

server: add tool parsing support for nemotron-mini (#6849)

Jeffrey Morgan 7 bulan lalu
induk
melakukan
d05da29912
4 mengubah file dengan 144 tambahan dan 39 penghapusan
  1. 50 39
      server/model.go
  2. 43 0
      server/model_test.go
  3. 33 0
      server/testdata/tools/nemotron.gotmpl
  4. 18 0
      server/testdata/tools/nemotron.out

+ 50 - 39
server/model.go

@@ -272,6 +272,30 @@ func detectContentType(r io.Reader) (string, error) {
 	return "unknown", nil
 	return "unknown", nil
 }
 }
 
 
+func parseObjects(s string) []map[string]any {
+	var objs []map[string]any
+	for offset := 0; offset < len(s); {
+		var obj map[string]any
+		decoder := json.NewDecoder(strings.NewReader(s[offset:]))
+		if err := decoder.Decode(&obj); errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
+			break
+		} else if syntax := &(json.SyntaxError{}); errors.As(err, &syntax) {
+			// skip over any syntax errors
+			offset += int(syntax.Offset)
+		} else if unmarshalType := &(json.UnmarshalTypeError{}); errors.As(err, &unmarshalType) {
+			// skip over any unmarshalable types
+			offset += int(unmarshalType.Offset)
+		} else if err != nil {
+			return nil
+		} else {
+			offset += int(decoder.InputOffset())
+			objs = append(objs, obj)
+		}
+	}
+
+	return objs
+}
+
 // parseToolCalls attempts to parse a JSON string into a slice of ToolCalls.
 // parseToolCalls attempts to parse a JSON string into a slice of ToolCalls.
 // mxyng: this only really works if the input contains tool calls in some JSON format
 // mxyng: this only really works if the input contains tool calls in some JSON format
 func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
 func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
@@ -304,16 +328,14 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
 		return nil, false
 		return nil, false
 	}
 	}
 
 
-	var kv map[string]any
-	// execute the subtree with placeholders to identify the keys
-	// trim any commands that might exist in the template
-	if err := json.Unmarshal(bytes.TrimSuffix(b.Bytes(), []byte(",")), &kv); err != nil {
+	templateObjects := parseObjects(b.String())
+	if len(templateObjects) == 0 {
 		return nil, false
 		return nil, false
 	}
 	}
 
 
 	// find the keys that correspond to the name and arguments fields
 	// find the keys that correspond to the name and arguments fields
 	var name, arguments string
 	var name, arguments string
-	for k, v := range kv {
+	for k, v := range templateObjects[0] {
 		switch v.(type) {
 		switch v.(type) {
 		case string:
 		case string:
 			name = k
 			name = k
@@ -326,43 +348,32 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
 		return nil, false
 		return nil, false
 	}
 	}
 
 
-	var objs []map[string]any
-	for offset := 0; offset < len(s); {
-		var obj map[string]any
-		decoder := json.NewDecoder(strings.NewReader(s[offset:]))
-		if err := decoder.Decode(&obj); errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
-			break
-		} else if syntax := &(json.SyntaxError{}); errors.As(err, &syntax) {
-			// skip over any syntax errors
-			offset += int(syntax.Offset)
-		} else if unmarshalType := &(json.UnmarshalTypeError{}); errors.As(err, &unmarshalType) {
-			// skip over any unmarshalable types
-			offset += int(unmarshalType.Offset)
-		} else if err != nil {
-			slog.Error("parseToolCalls", "error", err)
-			return nil, false
-		} else {
-			offset += int(decoder.InputOffset())
-
-			// collect all nested objects
-			var collect func(any) []map[string]any
-			collect = func(obj any) (all []map[string]any) {
-				switch o := obj.(type) {
-				case map[string]any:
-					all = append(all, o)
-					for _, v := range o {
-						all = append(all, collect(v)...)
-					}
-				case []any:
-					for _, v := range o {
-						all = append(all, collect(v)...)
-					}
-				}
+	responseObjects := parseObjects(s)
+	if len(responseObjects) == 0 {
+		return nil, false
+	}
 
 
-				return all
+	// collect all nested objects
+	var collect func(any) []map[string]any
+	collect = func(obj any) (all []map[string]any) {
+		switch o := obj.(type) {
+		case map[string]any:
+			all = append(all, o)
+			for _, v := range o {
+				all = append(all, collect(v)...)
+			}
+		case []any:
+			for _, v := range o {
+				all = append(all, collect(v)...)
 			}
 			}
-			objs = append(objs, collect(obj)...)
 		}
 		}
+
+		return all
+	}
+
+	var objs []map[string]any
+	for _, p := range responseObjects {
+		objs = append(objs, collect(p)...)
 	}
 	}
 
 
 	var toolCalls []api.ToolCall
 	var toolCalls []api.ToolCall

+ 43 - 0
server/model_test.go

@@ -69,6 +69,7 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`,
 {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}
 {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}
 </tool_call>`, true},
 </tool_call>`, true},
 		{"xlam", `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, true},
 		{"xlam", `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, true},
+		{"nemotron", `<toolcall>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]} </toolcall>`, true},
 	}
 	}
 
 
 	var tools []api.Tool
 	var tools []api.Tool
@@ -217,3 +218,45 @@ func TestParseLayerFromCopy(t *testing.T) {
 		t.Fatalf("got %d != want 5", len(layers))
 		t.Fatalf("got %d != want 5", len(layers))
 	}
 	}
 }
 }
+
+func TestParseObjects(t *testing.T) {
+	tests := []struct {
+		input string
+		want  []map[string]any
+	}{
+		{
+			input: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
+			want: []map[string]any{
+				{"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}},
+				{"name": "get_current_weather", "arguments": map[string]any{"format": "celsius", "location": "Toronto, Canada"}},
+			},
+		},
+		{
+			input: `<toolcall>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </toolcall>`,
+			want: []map[string]any{
+				{"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}},
+			},
+		},
+		{
+			input: `<toolcall>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </toolcall> <toolcall>{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, ON"}} </toolcall>`,
+			want: []map[string]any{
+				{"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}},
+				{"name": "get_current_weather", "arguments": map[string]any{"format": "celsius", "location": "Toronto, ON"}},
+			},
+		},
+		{
+			input: `{"name": "get_current_weather", "arguments": `,
+			want:  nil,
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.input, func(t *testing.T) {
+			got := parseObjects(tc.input)
+
+			if diff := cmp.Diff(got, tc.want); diff != "" {
+				t.Errorf("mismatch (-got +want):\n%s", diff)
+			}
+		})
+	}
+}

+ 33 - 0
server/testdata/tools/nemotron.gotmpl

@@ -0,0 +1,33 @@
+{{- if (or .Tools .System) }}<extra_id_0>System
+{{ if .System }}{{ .System }}
+
+
+{{ end }}
+{{- if .Tools }}
+{{- range .Tools }}<tool> {{ . }} </tool>{{ end }}
+
+
+{{ end }}
+{{- end }}
+{{- range $i, $m := .Messages }}
+{{- $last := eq (len (slice $.Messages $i)) 1 -}}
+{{- if eq .Role "user" }}<extra_id_1>User
+{{ .Content }}
+{{- if $last }}
+<extra_id_1>Assistant
+{{- end }}
+{{ else if eq .Role "tool" }}<extra_id_1>Tool
+{{ .Content }}
+{{- if $last }}
+<extra_id_1>Assistant
+{{- end }}
+{{ else if eq .Role "assistant" }}<extra_id_1>Assistant
+{{- if .ToolCalls }}
+{{ range .ToolCalls }}<toolcall> {"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}} </toolcall> {{ end }}
+{{ else }}
+{{ .Content }}
+{{- if not $last }}
+{{ end }}
+{{- end }}
+{{- end }}
+{{- end }}

+ 18 - 0
server/testdata/tools/nemotron.out

@@ -0,0 +1,18 @@
+<extra_id_0>System
+You are a knowledgable assistant. You can answer questions and perform tasks.
+
+
+<tool> {"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the users location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}} </tool>
+
+
+<extra_id_1>User
+What's the weather like today in Paris?
+<extra_id_1>Assistant
+<toolcall> {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}} </toolcall> 
+<extra_id_1>Tool
+22
+<extra_id_1>Assistant
+The current temperature in Paris, France is 22 degrees Celsius.
+<extra_id_1>User
+What's the weather like today in San Francisco and Toronto?
+<extra_id_1>Assistant