tokenizer_test.go

package convert

import (
	"io"
	"io/fs"
	"os"
	"path/filepath"
	"strings"
	"testing"

	"github.com/google/go-cmp/cmp"
)

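// createTokenizerFS writes each reader in files to a file of the same name
// under dir and returns the directory as an fs.FS.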
func createTokenizerFS(t *testing.T, dir string, files map[string]io.Reader) fs.FS {
	t.Helper()

	for k, v := range files {
		if err := func() error {
			f, err := os.Create(filepath.Join(dir, k))
			if err != nil {
				return err
			}
			defer f.Close()

			if _, err := io.Copy(f, v); err != nil {
				return err
			}

			return nil
		}(); err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
	}

	return os.DirFS(dir)
}

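// TestParseTokenizer builds minimal tokenizer.json / tokenizer_config.json
// fixtures and verifies the Tokenizer that parseTokenizer produces from them.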
func TestParseTokenizer(t *testing.T) {
	cases := []struct {
		name              string
		fsys              fs.FS
		specialTokenTypes []string
		want              *Tokenizer
	}{
		{
			name: "string chat template",
			fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
				"tokenizer.json": strings.NewReader(`{}`),
				"tokenizer_config.json": strings.NewReader(`{
					"chat_template": "<default template>"
				}`),
			}),
			want: &Tokenizer{
				Vocabulary: &Vocabulary{Model: "gpt2"},
				Pre:        "default",
				Template:   "<default template>",
			},
		},
		{
			name: "list chat template",
			fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
				"tokenizer.json": strings.NewReader(`{}`),
				"tokenizer_config.json": strings.NewReader(`{
					"chat_template": [
						{
							"name": "default",
							"template": "<default template>"
						},
						{
							"name": "tools",
							"template": "<tools template>"
						}
					]
				}`),
			}),
			want: &Tokenizer{
				Vocabulary: &Vocabulary{Model: "gpt2"},
				Pre:        "default",
				Template:   "<default template>",
			},
		},
		{
			name: "added tokens",
			fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
				"tokenizer.json": strings.NewReader(`{
					"added_tokens": [
						{
							"id": 999,
							"content": "<unused999>",
							"special": false
						}
					]
				}`),
			}),
			want: &Tokenizer{
				Vocabulary: &Vocabulary{
					Model:  "gpt2",
					Tokens: []string{"<unused999>"},
					Scores: []float32{999},
					Types:  []int32{4},
				},
				Pre: "default",
			},
		},
		{
			name: "added tokens overlap vocab",
			fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
				"tokenizer.json": strings.NewReader(`{
					"added_tokens": [
						{
							"id": 0,
							"content": "<pad>",
							"special": true
						}
					],
					"model": {
						"vocab": {
							"<pad>": 0
						}
					}
				}`),
			}),
			want: &Tokenizer{
				Vocabulary: &Vocabulary{
					Model:  "gpt2",
					Tokens: []string{"<pad>"},
					Scores: []float32{0},
					Types:  []int32{3},
				},
				Pre: "default",
			},
		},
		{
			name: "special token types",
			fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
				"tokenizer.json": strings.NewReader(`{
					"added_tokens": [
						{
							"id": 0,
							"content": "<pad>",
							"special": true
						},
						{
							"id": 1,
							"content": "<eos>",
							"special": true
						},
						{
							"id": 2,
							"content": "<bos>",
							"special": true
						},
						{
							"id": 3,
							"content": "<unk>",
							"special": true
						}
					],
					"model": {
						"vocab": {
							"<pad>": 0,
							"<eos>": 1,
							"<bos>": 2,
							"<unk>": 3
						}
					}
				}`),
				"tokenizer_config.json": strings.NewReader(`{
					"add_bos_token": true,
					"add_eos_token": false,
					"bos_token": "<bos>",
					"eos_token": "<eos>",
					"pad_token": "<pad>",
					"unk_token": "<unk>"
				}`),
			}),
			specialTokenTypes: []string{"pad", "eos", "bos", "unk"},
			want: &Tokenizer{
				Vocabulary: &Vocabulary{
					Model:  "gpt2",
					Tokens: []string{"<pad>", "<eos>", "<bos>", "<unk>"},
					Scores: []float32{0, 1, 2, 3},
					Types:  []int32{3, 3, 3, 3},
				},
				SpecialVocabulary: []*SpecialVocabulary{
					{Type: "pad", Content: "<pad>", ID: 0, AddToken: false},
					{Type: "eos", Content: "<eos>", ID: 1, AddToken: false},
					{Type: "bos", Content: "<bos>", ID: 2, AddToken: true},
					{Type: "unk", Content: "<unk>", ID: 3, AddToken: false},
				},
				Pre: "default",
			},
		},
		{
			name: "list string merges",
			fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
				"tokenizer.json": strings.NewReader(`{
					"model": {
						"merges": [
							"a b",
							"c d",
							"e f"
						]
					}
				}`),
			}),
			want: &Tokenizer{
				Vocabulary: &Vocabulary{
					Model: "gpt2",
				},
				Merges: []string{
					"a b",
					"c d",
					"e f",
				},
				Pre: "default",
			},
		},
		{
			name: "list list string merges",
			fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
				"tokenizer.json": strings.NewReader(`{
					"model": {
						"merges": [
							[
								"a", "b"
							],
							[
								"c", "d"
							],
							[
								"e", "f"
							]
						]
					}
				}`),
			}),
			want: &Tokenizer{
				Vocabulary: &Vocabulary{
					Model: "gpt2",
				},
				Merges: []string{
					"a b",
					"c d",
					"e f",
				},
				Pre: "default",
			},
		},
	}

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			tokenizer, err := parseTokenizer(tt.fsys, tt.specialTokenTypes)
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}

			if diff := cmp.Diff(tt.want, tokenizer); diff != "" {
				t.Errorf("unexpected tokenizer (-want +got):\n%s", diff)
			}
		})
	}
}