model_test.go 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. package server
  2. import (
  3. "archive/zip"
  4. "bytes"
  5. "encoding/json"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "os"
  10. "path/filepath"
  11. "slices"
  12. "strings"
  13. "testing"
  14. "github.com/google/go-cmp/cmp"
  15. "github.com/ollama/ollama/api"
  16. "github.com/ollama/ollama/template"
  17. )
  18. func createZipFile(t *testing.T, name string) *os.File {
  19. t.Helper()
  20. f, err := os.CreateTemp(t.TempDir(), "")
  21. if err != nil {
  22. t.Fatal(err)
  23. }
  24. zf := zip.NewWriter(f)
  25. defer zf.Close()
  26. zh, err := zf.CreateHeader(&zip.FileHeader{Name: name})
  27. if err != nil {
  28. t.Fatal(err)
  29. }
  30. if _, err := io.Copy(zh, bytes.NewReader([]byte(""))); err != nil {
  31. t.Fatal(err)
  32. }
  33. return f
  34. }
  35. func TestExtractFromZipFile(t *testing.T) {
  36. cases := []struct {
  37. name string
  38. expect []string
  39. err error
  40. }{
  41. {
  42. name: "good",
  43. expect: []string{"good"},
  44. },
  45. {
  46. name: strings.Join([]string{"path", "..", "to", "good"}, string(os.PathSeparator)),
  47. expect: []string{filepath.Join("to", "good")},
  48. },
  49. {
  50. name: strings.Join([]string{"path", "..", "to", "..", "good"}, string(os.PathSeparator)),
  51. expect: []string{"good"},
  52. },
  53. {
  54. name: strings.Join([]string{"path", "to", "..", "..", "good"}, string(os.PathSeparator)),
  55. expect: []string{"good"},
  56. },
  57. {
  58. name: strings.Join([]string{"..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "bad"}, string(os.PathSeparator)),
  59. err: zip.ErrInsecurePath,
  60. },
  61. {
  62. name: strings.Join([]string{"path", "..", "..", "to", "bad"}, string(os.PathSeparator)),
  63. err: zip.ErrInsecurePath,
  64. },
  65. }
  66. for _, tt := range cases {
  67. t.Run(tt.name, func(t *testing.T) {
  68. f := createZipFile(t, tt.name)
  69. defer f.Close()
  70. tempDir := t.TempDir()
  71. if err := extractFromZipFile(tempDir, f, func(api.ProgressResponse) {}); !errors.Is(err, tt.err) {
  72. t.Fatal(err)
  73. }
  74. var matches []string
  75. if err := filepath.Walk(tempDir, func(p string, fi os.FileInfo, err error) error {
  76. if err != nil {
  77. return err
  78. }
  79. if !fi.IsDir() {
  80. matches = append(matches, p)
  81. }
  82. return nil
  83. }); err != nil {
  84. t.Fatal(err)
  85. }
  86. var actual []string
  87. for _, match := range matches {
  88. rel, err := filepath.Rel(tempDir, match)
  89. if err != nil {
  90. t.Error(err)
  91. }
  92. actual = append(actual, rel)
  93. }
  94. if !slices.Equal(actual, tt.expect) {
  95. t.Fatalf("expected %d files, got %d", len(tt.expect), len(matches))
  96. }
  97. })
  98. }
  99. }
  100. func readFile(t *testing.T, base, name string) *bytes.Buffer {
  101. t.Helper()
  102. bts, err := os.ReadFile(filepath.Join(base, name))
  103. if err != nil {
  104. t.Fatal(err)
  105. }
  106. return bytes.NewBuffer(bts)
  107. }
  108. func TestExecuteWithTools(t *testing.T) {
  109. p := filepath.Join("testdata", "tools")
  110. cases := []struct {
  111. model string
  112. output string
  113. ok bool
  114. }{
  115. {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
  116. {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
  117. The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true},
  118. {"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
  119. [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
  120. {"mistral", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
  121. {"command-r-plus", "Action: ```json" + `
  122. [
  123. {
  124. "tool_name": "get_current_weather",
  125. "parameters": {
  126. "format": "fahrenheit",
  127. "location": "San Francisco, CA"
  128. }
  129. },
  130. {
  131. "tool_name": "get_current_weather",
  132. "parameters": {
  133. "format": "celsius",
  134. "location": "Toronto, Canada"
  135. }
  136. }
  137. ]
  138. ` + "```", true},
  139. {"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
  140. {"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
  141. {"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
  142. {"llama3-groq-tool-use", `<tool_call>
  143. {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}
  144. {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}
  145. </tool_call>`, true},
  146. {"xlam", `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, true},
  147. }
  148. var tools []api.Tool
  149. if err := json.Unmarshal(readFile(t, p, "tools.json").Bytes(), &tools); err != nil {
  150. t.Fatal(err)
  151. }
  152. var messages []api.Message
  153. if err := json.Unmarshal(readFile(t, p, "messages.json").Bytes(), &messages); err != nil {
  154. t.Fatal(err)
  155. }
  156. calls := []api.ToolCall{
  157. {
  158. Function: api.ToolCallFunction{
  159. Name: "get_current_weather",
  160. Arguments: api.ToolCallFunctionArguments{
  161. "format": "fahrenheit",
  162. "location": "San Francisco, CA",
  163. },
  164. },
  165. },
  166. {
  167. Function: api.ToolCallFunction{
  168. Name: "get_current_weather",
  169. Arguments: api.ToolCallFunctionArguments{
  170. "format": "celsius",
  171. "location": "Toronto, Canada",
  172. },
  173. },
  174. },
  175. }
  176. for _, tt := range cases {
  177. t.Run(tt.model, func(t *testing.T) {
  178. tmpl, err := template.Parse(readFile(t, p, fmt.Sprintf("%s.gotmpl", tt.model)).String())
  179. if err != nil {
  180. t.Fatal(err)
  181. }
  182. t.Run("template", func(t *testing.T) {
  183. var actual bytes.Buffer
  184. if err := tmpl.Execute(&actual, template.Values{Tools: tools, Messages: messages}); err != nil {
  185. t.Fatal(err)
  186. }
  187. if diff := cmp.Diff(actual.String(), readFile(t, p, fmt.Sprintf("%s.out", tt.model)).String()); diff != "" {
  188. t.Errorf("mismatch (-got +want):\n%s", diff)
  189. }
  190. })
  191. t.Run("parse", func(t *testing.T) {
  192. m := &Model{Template: tmpl}
  193. actual, ok := m.parseToolCalls(tt.output)
  194. if ok != tt.ok {
  195. t.Fatalf("expected %t, got %t", tt.ok, ok)
  196. }
  197. if tt.ok {
  198. if diff := cmp.Diff(actual, calls); diff != "" {
  199. t.Errorf("mismatch (-got +want):\n%s", diff)
  200. }
  201. }
  202. })
  203. })
  204. }
  205. }