process_text_test.go 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. package model
  2. import (
  3. "reflect"
  4. "testing"
  5. )
  6. func TestBytePairEncoding(t *testing.T) {
  7. // Create a simple test vocabulary
  8. vocab := &Vocabulary{
  9. Values: []string{
  10. "Hello",
  11. "World",
  12. "!",
  13. "How",
  14. "are",
  15. "you",
  16. "t",
  17. "o",
  18. "d",
  19. "a",
  20. "y",
  21. "to",
  22. "tod",
  23. "toda",
  24. "today",
  25. " ",
  26. },
  27. Types: []uint32{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3}, // 3 for special token (space)
  28. Merges: []string{
  29. "to",
  30. "tod",
  31. "toda",
  32. "today",
  33. },
  34. BOS: 0,
  35. EOS: 1,
  36. }
  37. bpe := BytePairEncoding{
  38. Pretokenizer: `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
  39. Vocabulary: vocab,
  40. }
  41. tests := []struct {
  42. name string
  43. input string
  44. want []int32
  45. wantErr bool
  46. }{
  47. {
  48. name: "simple hello world",
  49. input: "Hello World!",
  50. want: []int32{0, 15, 1, 2}, // indexes in the vocabulary
  51. wantErr: false,
  52. },
  53. {
  54. name: "empty string",
  55. input: "",
  56. wantErr: false,
  57. },
  58. {
  59. name: "just spaces",
  60. input: " ",
  61. want: []int32{15, 15, 15}, // space token repeated
  62. wantErr: false,
  63. },
  64. {
  65. name: "today with merges",
  66. input: "today",
  67. want: []int32{14}, // should merge
  68. wantErr: false,
  69. },
  70. }
  71. for _, tt := range tests {
  72. t.Run(tt.name, func(t *testing.T) {
  73. got, err := bpe.Encode(tt.input)
  74. if (err != nil) != tt.wantErr {
  75. t.Errorf("BytePairEncoding.Encode() error = %v, wantErr %v", err, tt.wantErr)
  76. return
  77. }
  78. if !reflect.DeepEqual(got, tt.want) {
  79. t.Errorf("BytePairEncoding.Encode() = %v, want %v", got, tt.want)
  80. }
  81. // Test round trip if encoding succeeded
  82. if err == nil {
  83. decoded, err := bpe.Decode(got)
  84. if err != nil {
  85. t.Errorf("BytePairEncoding.Decode() error = %v", err)
  86. return
  87. }
  88. // Note: The decoded string might not exactly match the input due to
  89. // tokenization/normalization, so we re-encode it to compare
  90. reEncoded, err := bpe.Encode(decoded)
  91. if err != nil {
  92. t.Errorf("BytePairEncoding.Encode() error on round trip = %v", err)
  93. return
  94. }
  95. if !reflect.DeepEqual(reEncoded, got) {
  96. t.Errorf("Round trip failed: original tokens = %v, after round trip = %v", got, reEncoded)
  97. }
  98. }
  99. })
  100. }
  101. }
  102. func TestBytePairEncodingSpecialTokens(t *testing.T) {
  103. vocab := &Vocabulary{
  104. Values: []string{
  105. "<s>",
  106. "</s>",
  107. "<pad>",
  108. "Hello",
  109. "World",
  110. },
  111. Types: []uint32{3, 3, 3, 1, 1}, // 3 for special tokens
  112. BOS: 0,
  113. EOS: 1,
  114. }
  115. bpe := BytePairEncoding{
  116. Pretokenizer: `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
  117. Vocabulary: vocab,
  118. }
  119. tests := []struct {
  120. name string
  121. input string
  122. want []int32
  123. wantErr bool
  124. }{
  125. {
  126. name: "text with special token at start",
  127. input: "<s>Hello",
  128. want: []int32{0, 3},
  129. wantErr: false,
  130. },
  131. {
  132. name: "text with special token at end",
  133. input: "World</s>",
  134. want: []int32{4, 1},
  135. wantErr: false,
  136. },
  137. {
  138. name: "special token in middle",
  139. input: "Hello<pad>World",
  140. want: []int32{3, 2, 4},
  141. wantErr: false,
  142. },
  143. }
  144. for _, tt := range tests {
  145. t.Run(tt.name, func(t *testing.T) {
  146. got, err := bpe.Encode(tt.input)
  147. if (err != nil) != tt.wantErr {
  148. t.Errorf("BytePairEncoding.Encode() error = %v, wantErr %v", err, tt.wantErr)
  149. return
  150. }
  151. if !reflect.DeepEqual(got, tt.want) {
  152. t.Errorf("BytePairEncoding.Encode() = %v, want %v", got, tt.want)
  153. }
  154. })
  155. }
  156. }
  157. func TestBytePairEncodingSplit(t *testing.T) {
  158. bpe := BytePairEncoding{
  159. Pretokenizer: `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
  160. }
  161. tests := []struct {
  162. name string
  163. input string
  164. want []string
  165. wantErr bool
  166. }{
  167. {
  168. name: "basic splitting",
  169. input: "Hello World!",
  170. want: []string{"Hello", " World", "!"},
  171. },
  172. {
  173. name: "contractions",
  174. input: "I'm don't won't",
  175. want: []string{"I", "'m", " don", "'t", " won", "'t"},
  176. },
  177. {
  178. name: "numbers",
  179. input: "In 2024 there are 365 days",
  180. want: []string{"In", " ", "202", "4", " there", " are", " ", "365", " days"},
  181. },
  182. {
  183. name: "special characters",
  184. input: "Hello!! ...world",
  185. want: []string{"Hello", "!!", " ...", "world"},
  186. },
  187. {
  188. name: "multiple spaces",
  189. input: "Hello World",
  190. want: []string{"Hello", " ", " World"},
  191. },
  192. {
  193. name: "newlines",
  194. input: "Hello\nWorld",
  195. want: []string{"Hello", "\n", "World"},
  196. },
  197. {
  198. name: "mixed case and punctuation",
  199. input: "Hello, WORLD!! How's it going?",
  200. want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?"},
  201. },
  202. }
  203. for _, tt := range tests {
  204. t.Run(tt.name, func(t *testing.T) {
  205. got, err := bpe.split(tt.input)
  206. if (err != nil) != tt.wantErr {
  207. t.Errorf("BytePairEncoding.split() error = %v, wantErr %v", err, tt.wantErr)
  208. return
  209. }
  210. if !reflect.DeepEqual(got, tt.want) {
  211. t.Errorf("BytePairEncoding.split() = %v, want %v", got, tt.want)
  212. }
  213. })
  214. }
  215. }