grammar_test.go 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. package grammar
  2. import (
  3. "bufio"
  4. "cmp"
  5. "iter"
  6. "strings"
  7. "testing"
  8. _ "embed"
  9. "github.com/ollama/ollama/grammar/internal/diff"
  10. )
  11. func TestFromSchema(t *testing.T) {
  12. for tt := range testCases(t) {
  13. t.Run(tt.name, func(t *testing.T) {
  14. g, err := FromSchema(nil, []byte(tt.schema))
  15. if err != nil {
  16. t.Fatalf("FromSchema: %v", err)
  17. }
  18. got := string(g)
  19. got = strings.TrimPrefix(got, jsonTerms)
  20. if got != tt.want {
  21. t.Logf("schema:\n%s", tt.schema)
  22. t.Fatal(string(diff.Diff("got", []byte(got), "want", []byte(tt.want))))
  23. }
  24. })
  25. }
  26. }
  // testCase is a single entry parsed from testdata/schemas.txt.
// Keep the field order as is: values may be built with positional literals.
type testCase struct {
	name         string // subtest name, taken from a '#' comment line
	schema, want string // schema passed to FromSchema, and the expected grammar
}
  // tests holds the raw text of testdata/schemas.txt, embedded at build time.
// testCases parses it into individual test cases.
//
//go:embed testdata/schemas.txt
var tests string
  34. func testCases(t testing.TB) iter.Seq[testCase] {
  35. t.Helper()
  36. return func(yield func(testCase) bool) {
  37. t.Helper()
  38. sc := bufio.NewScanner(strings.NewReader(tests))
  39. name := ""
  40. for sc.Scan() {
  41. line := strings.TrimSpace(sc.Text())
  42. if line == "" {
  43. name = ""
  44. continue
  45. }
  46. if line[0] == '#' {
  47. name = cmp.Or(name, strings.TrimSpace(line[1:]))
  48. continue
  49. }
  50. s := sc.Text()
  51. g := ""
  52. for sc.Scan() {
  53. line = strings.TrimSpace(sc.Text())
  54. if line == "" || line[0] == '#' {
  55. break
  56. }
  57. g += sc.Text() + "\n"
  58. }
  59. if !yield(testCase{name, s, g}) {
  60. return
  61. }
  62. name = strings.TrimSpace(strings.TrimPrefix(line, "#"))
  63. }
  64. if err := sc.Err(); err != nil {
  65. t.Fatalf("error reading tests: %v", err)
  66. }
  67. }
  68. }