process_text_test.go

package model

import (
	"bufio"
	"encoding/json"
	"math"
	"os"
	"path/filepath"
	"slices"
	"strconv"
	"strings"
	"testing"

	"github.com/google/go-cmp/cmp"
)
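
// llama loads the Llama 3.2 BPE tokenizer (encoder.json and vocab.bpe) from
// testdata for use by the tests and benchmarks below.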
func llama(t testing.TB) BytePairEncoding {
	t.Helper()

	f, err := os.Open(filepath.Join("testdata", "llama3.2", "encoder.json"))
	if err != nil {
		t.Fatal(err)
	}
	defer f.Close()

	vocab := make(map[string]int32)
	if err := json.NewDecoder(f).Decode(&vocab); err != nil {
		t.Fatal(err)
	}

	types := make([]uint32, len(vocab))
	tokens := make([]string, len(vocab))
	for token, id := range vocab {
		tokens[id] = token
		types[id] = 1
	}

	for _, token := range []string{"<|begin_of_text|>", "<|end_of_text|>"} {
		if _, ok := vocab[token]; !ok {
			tokens = append(tokens, token) //nolint:makezero
			types = append(types, 3)       //nolint:makezero
			vocab[token] = int32(len(vocab))
		}
	}

	f, err = os.Open(filepath.Join("testdata", "llama3.2", "vocab.bpe"))
	if err != nil {
		t.Fatal(err)
	}
	defer f.Close()

	merges := make([]string, 0, 50000)
	scanner := bufio.NewScanner(f)
	for scanner.Scan() {
		if !strings.HasPrefix(scanner.Text(), "#") {
			merges = append(merges, scanner.Text())
		}
	}

	return NewBytePairEncoding(
		`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
		&Vocabulary{
			Values: tokens,
			Types:  types,
			Merges: merges,
		},
	)
}
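
// TestLlama checks encoding, decoding, round-trips, special-token handling,
// and pretokenizer splits for the Llama 3.2 tokenizer loaded by llama.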
func TestLlama(t *testing.T) {
	tokenizer := llama(t)

	t.Run("simple", func(t *testing.T) {
		t.Parallel()

		ids, err := tokenizer.Encode("hello world", true)
		if err != nil {
			t.Error(err)
		}

		if diff := cmp.Diff([]int32{15339, 1917}, ids); diff != "" {
			t.Errorf("no match (-theirs +ours):\n%s", diff)
		}

		s, err := tokenizer.Decode([]int32{15339, 1917})
		if err != nil {
			t.Fatal(err)
		}

		if s != "hello world" {
			t.Errorf("got %q, want hello world", s)
		}

		ids, err = tokenizer.Encode("hello <|end_of_text|>", true)
		if err != nil {
			t.Error(err)
		}

		if diff := cmp.Diff([]int32{15339, 220, 128001}, ids); diff != "" {
			t.Errorf("no match (-theirs +ours):\n%s", diff)
		}
	})

	t.Run("simple repeated", func(t *testing.T) {
		t.Parallel()

		cases := map[string][]int32{
			strings.Repeat("0", 1):  {15},
			strings.Repeat("0", 2):  {410},
			strings.Repeat("0", 3):  {931},
			strings.Repeat("0", 4):  {931, 15},
			strings.Repeat("0", 5):  {931, 410},
			strings.Repeat("0", 6):  {931, 931},
			strings.Repeat("0", 7):  {931, 931, 15},
			strings.Repeat("0", 8):  {931, 931, 410},
			strings.Repeat("0", 9):  {931, 931, 931},
			strings.Repeat("0", 10): {931, 931, 931, 15},
			strings.Repeat("0", 11): {931, 931, 931, 410},
			strings.Repeat("0", 12): {931, 931, 931, 931},
			strings.Repeat("0", 13): {931, 931, 931, 931, 15},
			strings.Repeat("0", 14): {931, 931, 931, 931, 410},
			strings.Repeat("0", 15): {931, 931, 931, 931, 931},
			strings.Repeat("0", 16): {931, 931, 931, 931, 931, 15},
			strings.Repeat("0", 17): {931, 931, 931, 931, 931, 410},
		}

		for s, want := range cases {
			ids, err := tokenizer.Encode(s, true)
			if err != nil {
				t.Error(err)
			}

			if diff := cmp.Diff(want, ids); diff != "" {
				t.Errorf("%q no match (-theirs +ours):\n%s", s, diff)
			}
		}
	})
  116. t.Run("basic roundtrip", func(t *testing.T) {
  117. t.Parallel()
  118. cases := []string{
  119. "hello",
  120. "hello ",
  121. "hello ",
  122. " hello",
  123. " hello ",
  124. " hello ",
  125. "hello world",
  126. "请考试我的软件!12345",
  127. }
  128. for _, want := range cases {
  129. ids, err := tokenizer.Encode(want, true)
  130. if err != nil {
  131. t.Error(err)
  132. }
  133. if got, err := tokenizer.Decode(ids); err != nil {
  134. t.Fatal(err)
  135. } else if got != want {
  136. t.Errorf("got %q, want %q", got, want)
  137. }
  138. }
  139. })
  140. t.Run("special", func(t *testing.T) {
  141. t.Parallel()
  142. cases := map[string][]int32{
  143. "<|begin_of_text|>A B!": {128000, 32, 426, 0},
  144. "<|begin_of_text|>A<|end_of_text|>B!": {128000, 32, 128001, 33, 0},
  145. "<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!": {128000, 32, 128001, 33, 128000, 0},
  146. "<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!<|end_of_text|>": {128000, 32, 128001, 33, 128000, 0, 128001},
  147. }
  148. for s, want := range cases {
  149. ids, err := tokenizer.Encode(s, true)
  150. if err != nil {
  151. t.Fatal(err)
  152. }
  153. if diff := cmp.Diff(want, ids); diff != "" {
  154. t.Errorf("no match (-theirs +ours):\n%s", diff)
  155. }
  156. }
  157. })
  158. t.Run("split", func(t *testing.T) {
  159. t.Parallel()
  160. cases := map[string][]string{
  161. "Hello World!": {"Hello", " World", "!"},
  162. "I'm don't won't": {"I", "'m", " don", "'t", " won", "'t"},
  163. "In 2024 there are 366 days": {"In", " ", "202", "4", " there", " are", " ", "366", " days"},
  164. "Hello!! ...world": {"Hello", "!!", " ...", "world"},
  165. "Hello World": {"Hello", " ", " World"},
  166. "Hello\nWorld": {"Hello", "\n", "World"},
  167. "Hello, WORLD!! How's it going?": {"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?"},
  168. }
  169. for s, want := range cases {
  170. got := slices.Collect(tokenizer.split(s))
  171. if diff := cmp.Diff(want, got); diff != "" {
  172. t.Errorf("no match (-theirs +ours):\n%s", diff)
  173. }
  174. }
  175. })
  176. }

// tekken loads the Tekken tokenizer for testing
func tekken(t testing.TB) TextProcessor {
	t.Helper()

	// Load tokenizer config from mistral-small
	tokenizerConfigPath := filepath.Join("testdata", "mistral-small", "tokenizer_config.json")
	configFile, err := os.Open(tokenizerConfigPath)
	if err != nil {
		t.Fatal(err)
	}
	defer configFile.Close()

	var config struct {
		AddBosToken bool `json:"add_bos_token"`
		AddEosToken bool `json:"add_eos_token"`
		BosToken    struct {
			Content string `json:"content"`
		} `json:"bos_token"`
		EosToken struct {
			Content string `json:"content"`
		} `json:"eos_token"`
	}
	if err := json.NewDecoder(configFile).Decode(&config); err != nil {
		t.Fatal(err)
	}

	// Load tokenizer.json, which contains the vocabulary and other settings
	tokenizerJSONPath := filepath.Join("testdata", "mistral-small", "tokenizer.json")
	tokenizerFile, err := os.Open(tokenizerJSONPath)
	if err != nil {
		t.Fatal(err)
	}
	defer tokenizerFile.Close()

	var tokenizerData struct {
		Model struct {
			Type   string           `json:"type"`
			Vocab  map[string]int32 `json:"vocab"`
			Merges []string         `json:"merges"`
		} `json:"model"`
		AddedTokens []struct {
			Id      int32  `json:"id"`
			Content string `json:"content"`
			Special bool   `json:"special"`
		} `json:"added_tokens"`
		PreTokenizer struct {
			Type          string `json:"type"`
			Pretokenizers []struct {
				Type    string `json:"type"`
				Pattern struct {
					String string `json:"String"`
				} `json:"pattern"`
				Behavior string `json:"behavior"`
			} `json:"pretokenizers"`
		} `json:"pre_tokenizer"`
	}
	if err := json.NewDecoder(tokenizerFile).Decode(&tokenizerData); err != nil {
		t.Fatal(err)
	}

	// Extract the pattern from pre_tokenizer if available
	var pattern string
	if tokenizerData.PreTokenizer.Type == "Sequence" && len(tokenizerData.PreTokenizer.Pretokenizers) > 0 {
		pattern = tokenizerData.PreTokenizer.Pretokenizers[0].Pattern.String
	}

	// Combine the regular vocab with the special tokens from added_tokens
	vocab := tokenizerData.Model.Vocab
	for _, token := range tokenizerData.AddedTokens {
		vocab[token.Content] = token.Id
	}

	// Create vocabulary arrays
	maxId := int32(-1)
	for _, id := range vocab {
		if id > maxId {
			maxId = id
		}
	}
	vocabSize := int(maxId + 1)

	types := make([]uint32, vocabSize)
	tokens := make([]string, vocabSize)
	scores := make([]float32, vocabSize)
	for token, id := range vocab {
		tokens[id] = token
		types[id] = TOKEN_TYPE_NORMAL

		// Assign control type to special tokens
		if token == "<s>" || token == "</s>" || token == "[INST]" || token == "[/INST]" {
			types[id] = TOKEN_TYPE_CONTROL
		}
	}

	// In Tekken, merges don't need to be loaded separately as they're part of the model
	var merges []string

	// Create vocabulary object
	vocabObj := &Vocabulary{
		Values: tokens,
		Types:  types,
		Scores: scores,
		Merges: merges,
		BOS:    vocab[config.BosToken.Content],
		EOS:    vocab[config.EosToken.Content],
		AddBOS: config.AddBosToken,
		AddEOS: config.AddEosToken,
	}

	// Use the pattern from tokenizer.json if available
	if pattern != "" {
		// Ensure the pattern has proper escaping for Go regexp
		pattern = strings.ReplaceAll(pattern, "p{", "\\p{")
		return NewBytePairEncoding(pattern, vocabObj)
	}

	// Fallback pattern if none was found
	return NewBytePairEncoding(
		`\p{L}+|\p{N}+|[^\s\p{L}\p{N}]+|\s+`,
		vocabObj,
	)
}
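
// TestTekken exercises whitespace handling, chat-template formatting, special
// tokens, vocabulary coverage, splitting behavior, and a full chat sequence
// for the Tekken tokenizer loaded by tekken.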
func TestTekken(t *testing.T) {
	// Skip if the test data isn't available
	if _, err := os.Stat(filepath.Join("testdata", "mistral-small")); os.IsNotExist(err) {
		t.Skip("Mistral-small test data not available")
	}

	tokenizer := tekken(t)

	t.Run("whitespace_handling", func(t *testing.T) {
		t.Parallel()

		// The key difference from SentencePiece is that Tekken doesn't prepend whitespace
		cases := []struct {
			input    string
			expected string
		}{
			{" hello", " hello"},
			{"hello ", "hello "},
			{"hello world", "hello world"},
			{" hello world ", " hello world "},
		}

		for _, tc := range cases {
			ids, err := tokenizer.Encode(tc.input, false)
			if err != nil {
				t.Errorf("Failed to encode %q: %v", tc.input, err)
				continue
			}

			decoded, err := tokenizer.Decode(ids)
			if err != nil {
				t.Errorf("Failed to decode tokens for %q: %v", tc.input, err)
				continue
			}

			if decoded != tc.expected {
				t.Errorf("Whitespace handling: got %q, want %q", decoded, tc.expected)
			}
		}
	})

	t.Run("chat_templates", func(t *testing.T) {
		t.Parallel()

		// Test the Tekken chat template format, which doesn't have spaces after special tokens
		templates := []struct {
			input       string
			expectSpace bool // whether we expect a space after special tokens
		}{
			{"<s>[INST]user message[/INST]", false},
			{"<s>[INST] user message[/INST]", true},
			{"<s>[INST]user message [/INST]", true},
		}

		for _, tc := range templates {
			ids, err := tokenizer.Encode(tc.input, false)
			if err != nil {
				t.Errorf("Failed to encode %q: %v", tc.input, err)
				continue
			}

			decoded, err := tokenizer.Decode(ids)
			if err != nil {
				t.Errorf("Failed to decode tokens for %q: %v", tc.input, err)
				continue
			}

			// Check whether there's a space after special tokens
			hasSpaceAfterINST := strings.Contains(decoded, "[INST] ")
			if hasSpaceAfterINST != tc.expectSpace {
				t.Errorf("Chat template space handling: got space=%v, want space=%v for %q",
					hasSpaceAfterINST, tc.expectSpace, tc.input)
			}
		}
	})

	t.Run("special_tokens", func(t *testing.T) {
		t.Parallel()

		// Test how Tekken handles special tokens
		cases := []struct {
			input    string
			expected []string // tokens that must appear in the decoded output
		}{
			{"<s>[INST]hello[/INST]", []string{"<s>", "[INST]", "hello", "[/INST]"}},
			{"[INST]hello[/INST]</s>", []string{"[INST]", "hello", "[/INST]", "</s>"}},
			{"<s>[INST]hello[/INST]</s>[INST]again[/INST]", []string{"<s>", "[INST]", "hello", "[/INST]", "</s>", "[INST]", "again", "[/INST]"}},
		}

		for _, tc := range cases {
			ids, err := tokenizer.Encode(tc.input, false)
			if err != nil {
				t.Errorf("Failed to encode %q: %v", tc.input, err)
				continue
			}

			decoded, err := tokenizer.Decode(ids)
			if err != nil {
				t.Errorf("Failed to decode tokens for %q: %v", tc.input, err)
				continue
			}

			for _, expected := range tc.expected {
				if !strings.Contains(decoded, expected) {
					t.Errorf("Special token handling: %q missing in decoded output %q", expected, decoded)
				}
			}
		}
	})

	t.Run("vocabulary_coverage", func(t *testing.T) {
		t.Parallel()

		// Tekken has a larger vocabulary, so test coverage of various token types
		samples := []string{
			"Hello world!",
			"This is a test of the Tekken tokenizer.",
			"It has a considerably larger vocabulary size.",
			"Special characters: !@#$%^&*()",
			"Numbers: 1234567890",
			"Multiple languages: こんにちは 你好 안녕하세요",
			"Code snippets: def function(): return True",
		}

		for _, sample := range samples {
			ids, err := tokenizer.Encode(sample, false)
			if err != nil {
				t.Errorf("Failed to encode %q: %v", sample, err)
				continue
			}

			decoded, err := tokenizer.Decode(ids)
			if err != nil {
				t.Errorf("Failed to decode tokens for %q: %v", sample, err)
				continue
			}

			if decoded != sample {
				t.Errorf("Vocabulary coverage: got %q, want %q", decoded, sample)
			}
		}
	})

	t.Run("splitting_behavior", func(t *testing.T) {
		t.Parallel()

		// Test the splitting behavior, which might differ from SentencePiece
		cases := map[string][]string{
			"Hello World!": {"Hello", " World", "!"},
			"user message": {"user", " message"},
			"[INST]hello":  {"[INST]", "hello"},
			"hello[/INST]": {"hello", "[/INST]"},
		}

		for s, want := range cases {
			// tekken returns a TextProcessor; assert the concrete BytePairEncoding
			// (the value type returned by NewBytePairEncoding) to reach split
			got := slices.Collect(tokenizer.(BytePairEncoding).split(s))
			if diff := cmp.Diff(want, got); diff != "" {
				t.Errorf("Splitting behavior no match (-want +got):\n%s", diff)
			}
		}
	})

	t.Run("full_chat_sequence", func(t *testing.T) {
		t.Parallel()

		// Test a complete chat sequence with Tekken's format
		chatSequence := "<s>[INST]user message[/INST]assistant message</s>[INST]new user message[/INST]"

		ids, err := tokenizer.Encode(chatSequence, false)
		if err != nil {
			t.Fatalf("Failed to encode chat sequence: %v", err)
		}

		decoded, err := tokenizer.Decode(ids)
		if err != nil {
			t.Fatalf("Failed to decode chat sequence tokens: %v", err)
		}

		// In Tekken, whitespace shouldn't be added after special tokens
		if strings.Contains(decoded, "[INST] ") {
			t.Errorf("Tekken chat sequence has unexpected space after [INST]: %q", decoded)
		}

		if strings.Contains(decoded, "[/INST] ") {
			t.Errorf("Tekken chat sequence has unexpected space after [/INST]: %q", decoded)
		}
	})
}
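
// BenchmarkBytePairEncoding measures Encode, Decode, and split on progressively
// larger prefixes of war-and-peace.txt (powers of ten, capped at the file length).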
func BenchmarkBytePairEncoding(b *testing.B) {
	tokenizer := llama(b)
	bts, err := os.ReadFile(filepath.Join("testdata", "war-and-peace.txt"))
	if err != nil {
		b.Fatal(err)
	}

	for i := range 8 {
		n := min(int(math.Pow10(i)), len(bts))
		bts := bts[:n]

		b.Run("encode"+strconv.Itoa(n), func(b *testing.B) {
			b.ResetTimer()
			for range b.N {
				_, err := tokenizer.Encode(string(bts), true)
				if err != nil {
					b.Fatal(err)
				}
			}
		})

		b.Run("decode"+strconv.Itoa(n), func(b *testing.B) {
			ids, err := tokenizer.Encode(string(bts), true)
			if err != nil {
				b.Fatal(err)
			}

			b.ResetTimer()
			for range b.N {
				_, err := tokenizer.Decode(ids)
				if err != nil {
					b.Fatal(err)
				}
			}
		})

		b.Run("split"+strconv.Itoa(n), func(b *testing.B) {
			b.ResetTimer()
			for range b.N {
				slices.Collect(tokenizer.split(string(bts)))
			}
		})
	}
}