Browse Source

Merge pull request #5726 from ollama/mxyng/tools-templates

fix unmarshal type errors
Michael Yang 9 months ago
parent
commit
a8388beb94
2 changed files with 39 additions and 42 deletions
  1. 18 30
      server/model.go
  2. 21 12
      server/model_test.go

+ 18 - 30
server/model.go

@@ -327,7 +327,8 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
 
 
 	var kv map[string]string
 	var kv map[string]string
 	// execute the subtree with placeholders to identify the keys
 	// execute the subtree with placeholders to identify the keys
-	if err := json.Unmarshal(b.Bytes(), &kv); err != nil {
+	// trim any commands that might exist in the template
+	if err := json.Unmarshal(bytes.TrimSuffix(b.Bytes(), []byte(",")), &kv); err != nil {
 		return nil, false
 		return nil, false
 	}
 	}
 
 
@@ -342,35 +343,26 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
 		}
 		}
 	}
 	}
 
 
-	var sm []map[string]any
-	decoder := json.NewDecoder(strings.NewReader(s))
-	for {
-		// incrementally decode the JSON into a list of JSON objects
-		// skipping over any invalid tokens
-		if err := decoder.Decode(&sm); err != nil {
-			if errors.Is(err, io.EOF) {
-				break
-			}
-
-			if errors.As(err, new(*json.SyntaxError)) {
-				r := decoder.Buffered()
-				if _, err := r.Read(make([]byte, decoder.InputOffset()+1)); err != nil {
-					break
-				}
-
-				decoder = json.NewDecoder(r)
-				continue
-			}
-
+	var objs []map[string]any
+	for offset := 0; offset < len(s); {
+		if err := json.NewDecoder(strings.NewReader(s[offset:])).Decode(&objs); errors.Is(err, io.EOF) {
+			break
+		} else if syntax := &(json.SyntaxError{}); errors.As(err, &syntax) {
+			// skip over any syntax errors
+			offset += int(syntax.Offset)
+		} else if unmarshalType := &(json.UnmarshalTypeError{}); errors.As(err, &unmarshalType) {
+			// skip over any unmarshalable types
+			offset += int(unmarshalType.Offset)
+		} else if err != nil {
 			return nil, false
 			return nil, false
+		} else {
+			// break when an object is decoded
+			break
 		}
 		}
-
-		// break as soon as a valid object is decoded
-		break
 	}
 	}
 
 
 	var toolCalls []api.ToolCall
 	var toolCalls []api.ToolCall
-	for _, kv := range sm {
+	for _, kv := range objs {
 		call := api.ToolCall{
 		call := api.ToolCall{
 			ID:   uuid.New().String(),
 			ID:   uuid.New().String(),
 			Type: "function",
 			Type: "function",
@@ -388,9 +380,5 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
 		toolCalls = append(toolCalls, call)
 		toolCalls = append(toolCalls, call)
 	}
 	}
 
 
-	if len(toolCalls) > 0 {
-		return toolCalls, true
-	}
-
-	return nil, false
+	return toolCalls, len(toolCalls) > 0
 }
 }

+ 21 - 12
server/model_test.go

@@ -136,11 +136,16 @@ func TestExecuteWithTools(t *testing.T) {
 	cases := []struct {
 	cases := []struct {
 		model  string
 		model  string
 		output string
 		output string
+		ok     bool
 	}{
 	}{
-		{"mistral", `[TOOL_CALLS]  [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`},
+		{"mistral", `[TOOL_CALLS]  [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
 		{"mistral", `[TOOL_CALLS]  [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
 		{"mistral", `[TOOL_CALLS]  [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
 
 
-The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`},
+The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true},
+		{"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
+
+		[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
+		{"mistral", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
 		{"command-r-plus", "Action: ```json" + `
 		{"command-r-plus", "Action: ```json" + `
 [
 [
     {
     {
@@ -158,8 +163,10 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`}
         }
         }
     }
     }
 ]
 ]
-` + "```"},
-		{"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`},
+` + "```", true},
+		{"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
+		{"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
+		{"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
 	}
 	}
 
 
 	var tools []api.Tool
 	var tools []api.Tool
@@ -216,17 +223,19 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`}
 			t.Run("parse", func(t *testing.T) {
 			t.Run("parse", func(t *testing.T) {
 				m := &Model{Template: tmpl}
 				m := &Model{Template: tmpl}
 				actual, ok := m.parseToolCalls(tt.output)
 				actual, ok := m.parseToolCalls(tt.output)
-				if !ok {
-					t.Fatal("failed to parse tool calls")
+				if ok != tt.ok {
+					t.Fatalf("expected %t, got %t", tt.ok, ok)
 				}
 				}
 
 
-				for i := range actual {
-					// ID is randomly generated so clear it for comparison
-					actual[i].ID = ""
-				}
+				if tt.ok {
+					for i := range actual {
+						// ID is randomly generated so clear it for comparison
+						actual[i].ID = ""
+					}
 
 
-				if diff := cmp.Diff(actual, calls); diff != "" {
-					t.Errorf("mismatch (-got +want):\n%s", diff)
+					if diff := cmp.Diff(actual, calls); diff != "" {
+						t.Errorf("mismatch (-got +want):\n%s", diff)
+					}
 				}
 				}
 			})
 			})
 		})
 		})