model_test.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. package server
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "fmt"
  6. "os"
  7. "path/filepath"
  8. "testing"
  9. "github.com/google/go-cmp/cmp"
  10. "github.com/ollama/ollama/api"
  11. "github.com/ollama/ollama/template"
  12. )
  // readFile loads the fixture at filepath.Join(base, name) and hands its
// contents back wrapped in a new bytes.Buffer. A read failure aborts the
// calling test via t.Fatal; t.Helper makes that failure point at the caller.
func readFile(t *testing.T, base, name string) *bytes.Buffer {
	t.Helper()

	path := filepath.Join(base, name)
	data, err := os.ReadFile(path)
	if err != nil {
		t.Fatal(err)
	}

	return bytes.NewBuffer(data)
}
  21. func TestExecuteWithTools(t *testing.T) {
  22. p := filepath.Join("testdata", "tools")
  23. cases := []struct {
  24. model string
  25. output string
  26. ok bool
  27. }{
  28. {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
  29. {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
  30. The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true},
  31. {"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
  32. [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
  33. {"mistral", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
  34. {"command-r-plus", "Action: ```json" + `
  35. [
  36. {
  37. "tool_name": "get_current_weather",
  38. "parameters": {
  39. "format": "fahrenheit",
  40. "location": "San Francisco, CA"
  41. }
  42. },
  43. {
  44. "tool_name": "get_current_weather",
  45. "parameters": {
  46. "format": "celsius",
  47. "location": "Toronto, Canada"
  48. }
  49. }
  50. ]
  51. ` + "```", true},
  52. {"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
  53. {"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
  54. {"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
  55. {"llama3-groq-tool-use", `<tool_call>
  56. {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}
  57. {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}
  58. </tool_call>`, true},
  59. {"xlam", `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, true},
  60. }
  61. var tools []api.Tool
  62. if err := json.Unmarshal(readFile(t, p, "tools.json").Bytes(), &tools); err != nil {
  63. t.Fatal(err)
  64. }
  65. var messages []api.Message
  66. if err := json.Unmarshal(readFile(t, p, "messages.json").Bytes(), &messages); err != nil {
  67. t.Fatal(err)
  68. }
  69. calls := []api.ToolCall{
  70. {
  71. Function: api.ToolCallFunction{
  72. Name: "get_current_weather",
  73. Arguments: api.ToolCallFunctionArguments{
  74. "format": "fahrenheit",
  75. "location": "San Francisco, CA",
  76. },
  77. },
  78. },
  79. {
  80. Function: api.ToolCallFunction{
  81. Name: "get_current_weather",
  82. Arguments: api.ToolCallFunctionArguments{
  83. "format": "celsius",
  84. "location": "Toronto, Canada",
  85. },
  86. },
  87. },
  88. }
  89. for _, tt := range cases {
  90. t.Run(tt.model, func(t *testing.T) {
  91. tmpl, err := template.Parse(readFile(t, p, fmt.Sprintf("%s.gotmpl", tt.model)).String())
  92. if err != nil {
  93. t.Fatal(err)
  94. }
  95. t.Run("template", func(t *testing.T) {
  96. var actual bytes.Buffer
  97. if err := tmpl.Execute(&actual, template.Values{Tools: tools, Messages: messages}); err != nil {
  98. t.Fatal(err)
  99. }
  100. if diff := cmp.Diff(actual.String(), readFile(t, p, fmt.Sprintf("%s.out", tt.model)).String()); diff != "" {
  101. t.Errorf("mismatch (-got +want):\n%s", diff)
  102. }
  103. })
  104. t.Run("parse", func(t *testing.T) {
  105. m := &Model{Template: tmpl}
  106. actual, ok := m.parseToolCalls(tt.output)
  107. if ok != tt.ok {
  108. t.Fatalf("expected %t, got %t", tt.ok, ok)
  109. }
  110. if tt.ok {
  111. if diff := cmp.Diff(actual, calls); diff != "" {
  112. t.Errorf("mismatch (-got +want):\n%s", diff)
  113. }
  114. }
  115. })
  116. })
  117. }
  118. }