process_text_test.go 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. package model
  2. import (
  3. "bufio"
  4. "encoding/json"
  5. "math"
  6. "os"
  7. "path/filepath"
  8. "slices"
  9. "strconv"
  10. "strings"
  11. "testing"
  12. "github.com/google/go-cmp/cmp"
  13. )
  14. func llama(t testing.TB) BytePairEncoding {
  15. t.Helper()
  16. f, err := os.Open(filepath.Join("testdata", "llama3.2", "encoder.json"))
  17. if err != nil {
  18. t.Fatal(err)
  19. }
  20. defer f.Close()
  21. vocab := make(map[string]int32)
  22. if err := json.NewDecoder(f).Decode(&vocab); err != nil {
  23. t.Fatal(err)
  24. }
  25. types := make([]uint32, len(vocab))
  26. tokens := make([]string, len(vocab))
  27. for token, id := range vocab {
  28. tokens[id] = token
  29. types[id] = 1
  30. }
  31. for _, token := range []string{"<|begin_of_text|>", "<|end_of_text|>"} {
  32. if _, ok := vocab[token]; !ok {
  33. tokens = append(tokens, token) //nolint:makezero
  34. types = append(types, 3) //nolint:makezero
  35. vocab[token] = int32(len(vocab))
  36. }
  37. }
  38. f, err = os.Open(filepath.Join("testdata", "llama3.2", "vocab.bpe"))
  39. if err != nil {
  40. t.Fatal(err)
  41. }
  42. defer f.Close()
  43. merges := make([]string, 0, 50000)
  44. scanner := bufio.NewScanner(f)
  45. for scanner.Scan() {
  46. if !strings.HasPrefix(scanner.Text(), "#") {
  47. merges = append(merges, scanner.Text())
  48. }
  49. }
  50. return NewBytePairEncoding(
  51. `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
  52. &Vocabulary{
  53. Values: tokens,
  54. Types: types,
  55. Merges: merges,
  56. },
  57. )
  58. }
// TestLlama exercises the Llama 3.2 BPE tokenizer built by llama: encode,
// decode, roundtripping, special-token handling, and pretokenizer splitting.
// Expected token ids are fixed vectors for the llama3.2 testdata vocabulary.
func TestLlama(t *testing.T) {
	tokenizer := llama(t)
	t.Run("simple", func(t *testing.T) {
		t.Parallel()
		// Basic encode with special-token processing enabled (second arg).
		ids, err := tokenizer.Encode("hello world", true)
		if err != nil {
			t.Error(err)
		}
		if diff := cmp.Diff([]int32{15339, 1917}, ids); diff != "" {
			t.Errorf("no match (-theirs +ours):\n%s", diff)
		}
		// Decode must invert the encoding exactly.
		s, err := tokenizer.Decode([]int32{15339, 1917})
		if err != nil {
			t.Fatal(err)
		}
		if s != "hello world" {
			t.Errorf("got %q, want hello world", s)
		}
		// A special token embedded in text maps to its single id (128001).
		ids, err = tokenizer.Encode("hello <|end_of_text|>", true)
		if err != nil {
			t.Error(err)
		}
		if diff := cmp.Diff([]int32{15339, 220, 128001}, ids); diff != "" {
			t.Errorf("no match (-theirs +ours):\n%s", diff)
		}
	})
	t.Run("simple repeated", func(t *testing.T) {
		t.Parallel()
		// Runs of "0" tokenize greedily into "000" (931), "00" (410), and
		// "0" (15) chunks, exercising the \p{N}{1,3} pretokenizer rule.
		cases := map[string][]int32{
			strings.Repeat("0", 1):  {15},
			strings.Repeat("0", 2):  {410},
			strings.Repeat("0", 3):  {931},
			strings.Repeat("0", 4):  {931, 15},
			strings.Repeat("0", 5):  {931, 410},
			strings.Repeat("0", 6):  {931, 931},
			strings.Repeat("0", 7):  {931, 931, 15},
			strings.Repeat("0", 8):  {931, 931, 410},
			strings.Repeat("0", 9):  {931, 931, 931},
			strings.Repeat("0", 10): {931, 931, 931, 15},
			strings.Repeat("0", 11): {931, 931, 931, 410},
			strings.Repeat("0", 12): {931, 931, 931, 931},
			strings.Repeat("0", 13): {931, 931, 931, 931, 15},
			strings.Repeat("0", 14): {931, 931, 931, 931, 410},
			strings.Repeat("0", 15): {931, 931, 931, 931, 931},
			strings.Repeat("0", 16): {931, 931, 931, 931, 931, 15},
			strings.Repeat("0", 17): {931, 931, 931, 931, 931, 410},
		}
		for s, want := range cases {
			ids, err := tokenizer.Encode(s, true)
			if err != nil {
				t.Error(err)
			}
			if diff := cmp.Diff(want, ids); diff != "" {
				t.Errorf("%q no match (-theirs +ours):\n%s", s, diff)
			}
		}
	})
	t.Run("basic roundtrip", func(t *testing.T) {
		t.Parallel()
		// Decode(Encode(s)) must reproduce s exactly, including leading and
		// trailing whitespace and non-ASCII text.
		// NOTE(review): adjacent cases render identically here ("hello " twice,
		// " hello " twice) — presumably the originals differ in the number of
		// spaces and the whitespace was collapsed in transit; verify against
		// the upstream file before relying on these vectors.
		cases := []string{
			"hello",
			"hello ",
			"hello ",
			" hello",
			" hello ",
			" hello ",
			"hello world",
			"请考试我的软件!12345",
		}
		for _, want := range cases {
			ids, err := tokenizer.Encode(want, true)
			if err != nil {
				t.Error(err)
			}
			if got, err := tokenizer.Decode(ids); err != nil {
				t.Fatal(err)
			} else if got != want {
				t.Errorf("got %q, want %q", got, want)
			}
		}
	})
	t.Run("special", func(t *testing.T) {
		t.Parallel()
		// Special markers must be matched as whole tokens (128000/128001)
		// wherever they appear, with surrounding text tokenized normally.
		cases := map[string][]int32{
			"<|begin_of_text|>A B!":                                                {128000, 32, 426, 0},
			"<|begin_of_text|>A<|end_of_text|>B!":                                  {128000, 32, 128001, 33, 0},
			"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!":                 {128000, 32, 128001, 33, 128000, 0},
			"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!<|end_of_text|>":  {128000, 32, 128001, 33, 128000, 0, 128001},
		}
		for s, want := range cases {
			ids, err := tokenizer.Encode(s, true)
			if err != nil {
				t.Fatal(err)
			}
			if diff := cmp.Diff(want, ids); diff != "" {
				t.Errorf("no match (-theirs +ours):\n%s", diff)
			}
		}
	})
	t.Run("split", func(t *testing.T) {
		t.Parallel()
		// Verifies the pretokenizer regex directly: contractions, digit runs
		// of at most three, punctuation runs, and whitespace handling.
		cases := map[string][]string{
			"Hello World!":                   {"Hello", " World", "!"},
			"I'm don't won't":                {"I", "'m", " don", "'t", " won", "'t"},
			"In 2024 there are 366 days":     {"In", " ", "202", "4", " there", " are", " ", "366", " days"},
			"Hello!! ...world":               {"Hello", "!!", " ...", "world"},
			"Hello World":                    {"Hello", " ", " World"},
			"Hello\nWorld":                   {"Hello", "\n", "World"},
			"Hello, WORLD!! How's it going?": {"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?"},
		}
		for s, want := range cases {
			got := slices.Collect(tokenizer.split(s))
			if diff := cmp.Diff(want, got); diff != "" {
				t.Errorf("no match (-theirs +ours):\n%s", diff)
			}
		}
	})
}
  177. func BenchmarkBytePairEncoding(b *testing.B) {
  178. tokenizer := llama(b)
  179. bts, err := os.ReadFile(filepath.Join("testdata", "war-and-peace.txt"))
  180. if err != nil {
  181. b.Fatal(err)
  182. }
  183. for i := range 8 {
  184. n := min(int(math.Pow10(i)), len(bts))
  185. bts := bts[:n]
  186. b.Run("encode"+strconv.Itoa(n), func(b *testing.B) {
  187. b.ResetTimer()
  188. for range b.N {
  189. _, err := tokenizer.Encode(string(bts), true)
  190. if err != nil {
  191. b.Fatal(err)
  192. }
  193. }
  194. })
  195. b.Run("decode"+strconv.Itoa(n), func(b *testing.B) {
  196. ids, err := tokenizer.Encode(string(bts), true)
  197. if err != nil {
  198. b.Fatal(err)
  199. }
  200. b.ResetTimer()
  201. for range b.N {
  202. _, err := tokenizer.Decode(ids)
  203. if err != nil {
  204. b.Fatal(err)
  205. }
  206. }
  207. })
  208. b.Run("split"+strconv.Itoa(n), func(b *testing.B) {
  209. b.ResetTimer()
  210. for range b.N {
  211. slices.Collect(tokenizer.split(string(bts)))
  212. }
  213. })
  214. }
  215. }