// tokenizer_test.go
  1. package convert
  2. import (
  3. "io"
  4. "io/fs"
  5. "os"
  6. "path/filepath"
  7. "strings"
  8. "testing"
  9. "github.com/google/go-cmp/cmp"
  10. )
  11. func createTokenizerFS(t *testing.T, dir string, files map[string]io.Reader) fs.FS {
  12. t.Helper()
  13. for k, v := range files {
  14. if err := func() error {
  15. f, err := os.Create(filepath.Join(dir, k))
  16. if err != nil {
  17. return err
  18. }
  19. defer f.Close()
  20. if _, err := io.Copy(f, v); err != nil {
  21. return err
  22. }
  23. return nil
  24. }(); err != nil {
  25. t.Fatalf("unexpected error: %v", err)
  26. }
  27. }
  28. return os.DirFS(dir)
  29. }
  30. func TestParseTokenizer(t *testing.T) {
  31. cases := []struct {
  32. name string
  33. fsys fs.FS
  34. specialTokenTypes []string
  35. want *Tokenizer
  36. }{
  37. {
  38. name: "string chat template",
  39. fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
  40. "tokenizer.json": strings.NewReader(`{}`),
  41. "tokenizer_config.json": strings.NewReader(`{
  42. "chat_template": "<default template>"
  43. }`),
  44. }),
  45. want: &Tokenizer{
  46. Vocabulary: &Vocabulary{Model: "gpt2"},
  47. Pre: "default",
  48. Template: "<default template>",
  49. },
  50. },
  51. {
  52. name: "list chat template",
  53. fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
  54. "tokenizer.json": strings.NewReader(`{}`),
  55. "tokenizer_config.json": strings.NewReader(`{
  56. "chat_template": [
  57. {
  58. "name": "default",
  59. "template": "<default template>"
  60. },
  61. {
  62. "name": "tools",
  63. "template": "<tools template>"
  64. }
  65. ]
  66. }`),
  67. }),
  68. want: &Tokenizer{
  69. Vocabulary: &Vocabulary{Model: "gpt2"},
  70. Pre: "default",
  71. Template: "<default template>",
  72. },
  73. },
  74. }
  75. for _, tt := range cases {
  76. t.Run(tt.name, func(t *testing.T) {
  77. tokenizer, err := parseTokenizer(tt.fsys, tt.specialTokenTypes)
  78. if err != nil {
  79. t.Fatalf("unexpected error: %v", err)
  80. }
  81. if diff := cmp.Diff(tt.want, tokenizer); diff != "" {
  82. t.Errorf("unexpected tokenizer (-want +got):\n%s", diff)
  83. }
  84. })
  85. }
  86. }