// process_text_test.go
  1. package model
  2. import (
  3. "reflect"
  4. "testing"
  5. )
  6. func TestBytePairEncoding(t *testing.T) {
  7. // Create a simple test vocabulary
  8. vocab := &Vocabulary{
  9. Values: []string{
  10. "Hello",
  11. "World",
  12. "!",
  13. "How",
  14. "are",
  15. "you",
  16. "t",
  17. "o",
  18. "d",
  19. "a",
  20. "y",
  21. "to",
  22. "tod",
  23. "toda",
  24. "today",
  25. " ",
  26. },
  27. Types: []uint32{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3}, // 3 for special token (space)
  28. Merges: []string{
  29. "to",
  30. "tod",
  31. "toda",
  32. "today",
  33. },
  34. BOS: 0,
  35. EOS: 1,
  36. }
  37. bpe := BytePairEncoding{
  38. Pretokenizer: `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
  39. Vocabulary: vocab,
  40. }
  41. tests := []struct {
  42. name string
  43. input string
  44. want []int32
  45. wantErr bool
  46. }{
  47. {
  48. name: "simple hello world",
  49. input: "Hello World!",
  50. want: []int32{0, 15, 1, 2}, // indexes in the vocabulary
  51. wantErr: false,
  52. },
  53. {
  54. name: "empty string",
  55. input: "",
  56. want: []int32{},
  57. wantErr: false,
  58. },
  59. {
  60. name: "just spaces",
  61. input: " ",
  62. want: []int32{15, 15, 15}, // space token repeated
  63. wantErr: false,
  64. },
  65. {
  66. name: "today with merges",
  67. input: "today",
  68. want: []int32{14}, // should merge
  69. wantErr: false,
  70. },
  71. }
  72. for _, tt := range tests {
  73. t.Run(tt.name, func(t *testing.T) {
  74. got, err := bpe.Encode(tt.input)
  75. if (err != nil) != tt.wantErr {
  76. t.Errorf("BytePairEncoding.Encode() error = %v, wantErr %v", err, tt.wantErr)
  77. return
  78. }
  79. if !reflect.DeepEqual(got, tt.want) {
  80. t.Errorf("BytePairEncoding.Encode() = %v, want %v", got, tt.want)
  81. }
  82. // Test round trip if encoding succeeded
  83. if err == nil {
  84. decoded, err := bpe.Decode(got)
  85. if err != nil {
  86. t.Errorf("BytePairEncoding.Decode() error = %v", err)
  87. return
  88. }
  89. // Note: The decoded string might not exactly match the input due to
  90. // tokenization/normalization, so we re-encode it to compare
  91. reEncoded, err := bpe.Encode(decoded)
  92. if err != nil {
  93. t.Errorf("BytePairEncoding.Encode() error on round trip = %v", err)
  94. return
  95. }
  96. if !reflect.DeepEqual(reEncoded, got) {
  97. t.Errorf("Round trip failed: original tokens = %v, after round trip = %v", got, reEncoded)
  98. }
  99. }
  100. })
  101. }
  102. }
  103. func TestBytePairEncodingSpecialTokens(t *testing.T) {
  104. vocab := &Vocabulary{
  105. Values: []string{
  106. "<s>",
  107. "</s>",
  108. "<pad>",
  109. "Hello",
  110. "World",
  111. },
  112. Types: []uint32{3, 3, 3, 1, 1}, // 3 for special tokens
  113. BOS: 0,
  114. EOS: 1,
  115. }
  116. bpe := BytePairEncoding{
  117. Pretokenizer: `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
  118. Vocabulary: vocab,
  119. }
  120. tests := []struct {
  121. name string
  122. input string
  123. want []int32
  124. wantErr bool
  125. }{
  126. {
  127. name: "text with special token at start",
  128. input: "<s>Hello",
  129. want: []int32{0, 3},
  130. wantErr: false,
  131. },
  132. {
  133. name: "text with special token at end",
  134. input: "World</s>",
  135. want: []int32{4, 1},
  136. wantErr: false,
  137. },
  138. {
  139. name: "special token in middle",
  140. input: "Hello<pad>World",
  141. want: []int32{3, 2, 4},
  142. wantErr: false,
  143. },
  144. }
  145. for _, tt := range tests {
  146. t.Run(tt.name, func(t *testing.T) {
  147. got, err := bpe.Encode(tt.input)
  148. if (err != nil) != tt.wantErr {
  149. t.Errorf("BytePairEncoding.Encode() error = %v, wantErr %v", err, tt.wantErr)
  150. return
  151. }
  152. if !reflect.DeepEqual(got, tt.want) {
  153. t.Errorf("BytePairEncoding.Encode() = %v, want %v", got, tt.want)
  154. }
  155. })
  156. }
  157. }
  158. func TestBytePairEncodingSplit(t *testing.T) {
  159. bpe := BytePairEncoding{
  160. Pretokenizer: `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
  161. }
  162. tests := []struct {
  163. name string
  164. input string
  165. want []string
  166. wantErr bool
  167. }{
  168. {
  169. name: "basic splitting",
  170. input: "Hello World!",
  171. want: []string{"Hello", " World", "!"},
  172. },
  173. {
  174. name: "contractions",
  175. input: "I'm don't won't",
  176. want: []string{"I", "'m", " don", "'t", " won", "'t"},
  177. },
  178. {
  179. name: "numbers",
  180. input: "In 2024 there are 365 days",
  181. want: []string{"In", " ", "202", "4", " there", " are", " ", "365", " days"},
  182. },
  183. {
  184. name: "special characters",
  185. input: "Hello!! ...world",
  186. want: []string{"Hello", "!!", " ...", "world"},
  187. },
  188. {
  189. name: "multiple spaces",
  190. input: "Hello World",
  191. want: []string{"Hello", " ", " World"},
  192. },
  193. {
  194. name: "newlines",
  195. input: "Hello\nWorld",
  196. want: []string{"Hello", "\n", "World"},
  197. },
  198. {
  199. name: "mixed case and punctuation",
  200. input: "Hello, WORLD!! How's it going?",
  201. want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?"},
  202. },
  203. }
  204. for _, tt := range tests {
  205. t.Run(tt.name, func(t *testing.T) {
  206. got, err := bpe.split(tt.input)
  207. if (err != nil) != tt.wantErr {
  208. t.Errorf("BytePairEncoding.split() error = %v, wantErr %v", err, tt.wantErr)
  209. return
  210. }
  211. if !reflect.DeepEqual(got, tt.want) {
  212. t.Errorf("BytePairEncoding.split() = %v, want %v", got, tt.want)
  213. }
  214. })
  215. }
  216. }