model_test.go 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. package server
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "fmt"
  6. "os"
  7. "path/filepath"
  8. "testing"
  9. "github.com/google/go-cmp/cmp"
  10. "github.com/ollama/ollama/api"
  11. "github.com/ollama/ollama/template"
  12. )
  13. func readFile(t *testing.T, base, name string) *bytes.Buffer {
  14. t.Helper()
  15. bts, err := os.ReadFile(filepath.Join(base, name))
  16. if err != nil {
  17. t.Fatal(err)
  18. }
  19. return bytes.NewBuffer(bts)
  20. }
  21. func TestExecuteWithTools(t *testing.T) {
  22. p := filepath.Join("testdata", "tools")
  23. cases := []struct {
  24. model string
  25. output string
  26. ok bool
  27. }{
  28. {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
  29. {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
  30. The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true},
  31. {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"To }]`, false},
  32. {"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
  33. [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
  34. {"mistral", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
  35. {"command-r-plus", "Action: ```json" + `
  36. [
  37. {
  38. "tool_name": "get_current_weather",
  39. "parameters": {
  40. "format": "fahrenheit",
  41. "location": "San Francisco, CA"
  42. }
  43. },
  44. {
  45. "tool_name": "get_current_weather",
  46. "parameters": {
  47. "format": "celsius",
  48. "location": "Toronto, Canada"
  49. }
  50. }
  51. ]
  52. ` + "```", true},
  53. {"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
  54. {"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
  55. {"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
  56. {"llama3-groq-tool-use", `<tool_call>
  57. {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}
  58. {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}
  59. </tool_call>`, true},
  60. {"xlam", `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, true},
  61. {"nemotron", `<toolcall>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]} </toolcall>`, true},
  62. }
  63. var tools []api.Tool
  64. if err := json.Unmarshal(readFile(t, p, "tools.json").Bytes(), &tools); err != nil {
  65. t.Fatal(err)
  66. }
  67. var messages []api.Message
  68. if err := json.Unmarshal(readFile(t, p, "messages.json").Bytes(), &messages); err != nil {
  69. t.Fatal(err)
  70. }
  71. calls := []api.ToolCall{
  72. {
  73. Function: api.ToolCallFunction{
  74. Name: "get_current_weather",
  75. Arguments: api.ToolCallFunctionArguments{
  76. "format": "fahrenheit",
  77. "location": "San Francisco, CA",
  78. },
  79. },
  80. },
  81. {
  82. Function: api.ToolCallFunction{
  83. Name: "get_current_weather",
  84. Arguments: api.ToolCallFunctionArguments{
  85. "format": "celsius",
  86. "location": "Toronto, Canada",
  87. },
  88. },
  89. },
  90. }
  91. for _, tt := range cases {
  92. t.Run(tt.model, func(t *testing.T) {
  93. tmpl, err := template.Parse(readFile(t, p, fmt.Sprintf("%s.gotmpl", tt.model)).String())
  94. if err != nil {
  95. t.Fatal(err)
  96. }
  97. t.Run("template", func(t *testing.T) {
  98. var actual bytes.Buffer
  99. if err := tmpl.Execute(&actual, template.Values{Tools: tools, Messages: messages}); err != nil {
  100. t.Fatal(err)
  101. }
  102. if diff := cmp.Diff(actual.String(), readFile(t, p, fmt.Sprintf("%s.out", tt.model)).String()); diff != "" {
  103. t.Errorf("mismatch (-got +want):\n%s", diff)
  104. }
  105. })
  106. t.Run("parse", func(t *testing.T) {
  107. m := &Model{Template: tmpl}
  108. actual, ok := m.parseToolCalls(tt.output)
  109. if ok != tt.ok {
  110. t.Fatalf("expected %t, got %t", tt.ok, ok)
  111. }
  112. if tt.ok {
  113. if diff := cmp.Diff(actual, calls); diff != "" {
  114. t.Errorf("mismatch (-got +want):\n%s", diff)
  115. }
  116. }
  117. })
  118. })
  119. }
  120. }
  121. func TestParseObjects(t *testing.T) {
  122. tests := []struct {
  123. input string
  124. want []map[string]any
  125. }{
  126. {
  127. input: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
  128. want: []map[string]any{
  129. {"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}},
  130. {"name": "get_current_weather", "arguments": map[string]any{"format": "celsius", "location": "Toronto, Canada"}},
  131. },
  132. },
  133. {
  134. input: `<toolcall>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </toolcall>`,
  135. want: []map[string]any{
  136. {"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}},
  137. },
  138. },
  139. {
  140. input: `<toolcall>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </toolcall> <toolcall>{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, ON"}} </toolcall>`,
  141. want: []map[string]any{
  142. {"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}},
  143. {"name": "get_current_weather", "arguments": map[string]any{"format": "celsius", "location": "Toronto, ON"}},
  144. },
  145. },
  146. {
  147. input: `{"name": "get_current_weather", "arguments": `,
  148. want: nil,
  149. },
  150. }
  151. for _, tc := range tests {
  152. t.Run(tc.input, func(t *testing.T) {
  153. got := parseObjects(tc.input)
  154. if diff := cmp.Diff(got, tc.want); diff != "" {
  155. t.Errorf("mismatch (-got +want):\n%s", diff)
  156. }
  157. })
  158. }
  159. }