process_text_benchmark_test.go 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. package model
  2. import (
  3. "testing"
  4. )
  5. // BenchmarkVocabulary is a reusable test vocabulary for benchmarks
  6. var BenchmarkVocabulary = &Vocabulary{
  7. Values: []string{
  8. "Hello",
  9. "World",
  10. "!",
  11. "How",
  12. "are",
  13. "you",
  14. "t",
  15. "o",
  16. "d",
  17. "a",
  18. "y",
  19. "to",
  20. "tod",
  21. "toda",
  22. "today",
  23. " ",
  24. "<s>",
  25. "</s>",
  26. "<pad>",
  27. "'s",
  28. "'t",
  29. "'re",
  30. "'ve",
  31. "'m",
  32. "'ll",
  33. "'d",
  34. },
  35. Types: []uint32{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1}, // 3 for special tokens
  36. Merges: []string{
  37. "to",
  38. "tod",
  39. "toda",
  40. "today",
  41. },
  42. BOS: 16, // <s>
  43. EOS: 17, // </s>
  44. }
  45. func BenchmarkBytePairEncoding(b *testing.B) {
  46. bpe := BytePairEncoding{
  47. Pretokenizer: `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
  48. Vocabulary: BenchmarkVocabulary,
  49. }
  50. benchmarks := []struct {
  51. name string
  52. input string
  53. }{
  54. {
  55. name: "simple_hello_world",
  56. input: "Hello World!",
  57. },
  58. {
  59. name: "with_special_tokens",
  60. input: "<s>Hello World!</s>",
  61. },
  62. {
  63. name: "with_merges",
  64. input: "today is today and today",
  65. },
  66. {
  67. name: "with_contractions",
  68. input: "I'm don't won't can't they're we've you'll he'd",
  69. },
  70. {
  71. name: "long_text",
  72. input: "Hello World! How are you today? I'm doing great! This is a longer text to test the performance of the encoding and decoding process with multiple sentences and various tokens including special ones like <s> and </s> and contractions like don't and won't.",
  73. },
  74. }
  75. for _, bm := range benchmarks {
  76. // Benchmark Encoding
  77. b.Run("Encode_"+bm.name, func(b *testing.B) {
  78. b.ReportAllocs()
  79. for range b.N {
  80. tokens, err := bpe.Encode(bm.input)
  81. if err != nil {
  82. b.Fatal(err)
  83. }
  84. b.SetBytes(int64(len(tokens) * 4)) // Each token is 4 bytes (int32)
  85. }
  86. })
  87. // First encode the input to get tokens for decode benchmark
  88. tokens, err := bpe.Encode(bm.input)
  89. if err != nil {
  90. b.Fatal(err)
  91. }
  92. // Benchmark Decoding
  93. b.Run("Decode_"+bm.name, func(b *testing.B) {
  94. b.ReportAllocs()
  95. for range b.N {
  96. decoded, err := bpe.Decode(tokens)
  97. if err != nil {
  98. b.Fatal(err)
  99. }
  100. b.SetBytes(int64(len(decoded)))
  101. }
  102. })
  103. }
  104. }
  105. func BenchmarkBytePairEncodingSplit(b *testing.B) {
  106. bpe := BytePairEncoding{
  107. Pretokenizer: `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
  108. }
  109. benchmarks := []struct {
  110. name string
  111. input string
  112. }{
  113. {
  114. name: "simple_text",
  115. input: "Hello World!",
  116. },
  117. {
  118. name: "with_contractions",
  119. input: "I'm don't won't",
  120. },
  121. {
  122. name: "with_numbers",
  123. input: "In 2024 there are 365 days",
  124. },
  125. {
  126. name: "with_special_chars",
  127. input: "Hello!! ...world",
  128. },
  129. {
  130. name: "with_spaces",
  131. input: "Hello World",
  132. },
  133. {
  134. name: "with_newlines",
  135. input: "Hello\nWorld\nHow\nAre\nYou",
  136. },
  137. }
  138. for _, bm := range benchmarks {
  139. b.Run("Split_"+bm.name, func(b *testing.B) {
  140. b.ReportAllocs()
  141. for range b.N {
  142. splits, err := bpe.split(bm.input)
  143. if err != nil {
  144. b.Fatal(err)
  145. }
  146. b.SetBytes(int64(len(splits)))
  147. }
  148. })
  149. }
  150. }