transforms_test.go 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. package sample
  2. import (
  3. "encoding/binary"
  4. "errors"
  5. "math"
  6. "math/rand/v2"
  7. "os"
  8. "path/filepath"
  9. "runtime"
  10. "testing"
  11. )
  12. // Helper to convert float32 slice to logit slice
  13. func toTokens(values []float32) []token {
  14. tokens := make([]token, len(values))
  15. for i, v := range values {
  16. tokens[i] = token{
  17. id: int32(i),
  18. value: v,
  19. }
  20. }
  21. return tokens
  22. }
  23. // Helper to compare logit slices
  24. func compareLogits(t *testing.T, name string, want []float32, got []token) {
  25. t.Helper()
  26. if len(want) != len(got) {
  27. t.Errorf("%s: length mismatch: want %d, got %d", name, len(want), len(got))
  28. return
  29. }
  30. for i := range want {
  31. if math.Abs(float64(got[i].value-want[i])) > 1e-6 {
  32. t.Errorf("%s: index %d: want %f, got %f", name, i, want[i], got[i].value)
  33. }
  34. }
  35. }
  36. func TestTemperatureAndSoftmax(t *testing.T) {
  37. input := []float32{1, 4, -2, 0}
  38. got := temperature(toTokens(input), 0.5)
  39. // Check probabilities sum to 1
  40. var sum float32
  41. for _, token := range got {
  42. sum += token.value
  43. }
  44. if math.Abs(float64(sum-1.0)) > 1e-6 {
  45. t.Errorf("probabilities don't sum to 1: got %f", sum)
  46. }
  47. got = temperature(toTokens(input), 1)
  48. // Check probabilities sum to 1
  49. sum = 0.0
  50. for _, token := range got {
  51. sum += token.value
  52. }
  53. if math.Abs(float64(sum-1.0)) > 1e-6 {
  54. t.Errorf("probabilities don't sum to 1: got %f", sum)
  55. }
  56. }
  57. func TestTopK(t *testing.T) {
  58. input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
  59. // Test k=5
  60. got := topK(toTokens(input), 5)
  61. if len(got) != 5 {
  62. t.Errorf("topK(5): wrong length: want 5, got %d", len(got))
  63. }
  64. // Should keep highest 3 values in descending order
  65. want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154}
  66. compareLogits(t, "topK(3)", want, got)
  67. got = topK(toTokens(input), 20)
  68. if len(got) != len(input) {
  69. t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(got))
  70. }
  71. // Test k=-1
  72. input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
  73. want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
  74. got = topK(toTokens(input), -1)
  75. if len(got) != len(input) {
  76. t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got))
  77. }
  78. compareLogits(t, "topK(-1)", want, got)
  79. // Test k=0
  80. input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
  81. want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
  82. got = topK(toTokens(input), 0)
  83. if len(got) != len(input) {
  84. t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got))
  85. }
  86. compareLogits(t, "topK(-1)", want, got)
  87. }
  88. func TestTopP(t *testing.T) {
  89. input := []float32{-3, -2, -1, 0, 1, 2, 4}
  90. tokens := toTokens(input)
  91. // First apply temperature and softmax to get probabilities
  92. tokens = temperature(tokens, 1)
  93. tokens = topK(tokens, 20)
  94. // Then apply topP
  95. got := topP(tokens, 0.95)
  96. // Should keep tokens until cumsum > 0.95
  97. if len(got) > 3 {
  98. t.Errorf("topP(0.95): kept too many tokens: got %d", len(got))
  99. t.Logf("got: %v", got)
  100. }
  101. }
  102. func TestMinP(t *testing.T) {
  103. input := []float32{-3, -2, -1, 0, 1, 2, 4, 3}
  104. tokens := toTokens(input)
  105. // First apply temperature and softmax
  106. tokens = temperature(tokens, 1)
  107. // Then apply minP
  108. got := minP(tokens, 0.2)
  109. // Should keep tokens with prob >= 0.2 * max_prob
  110. if len(got) > 3 {
  111. t.Errorf("minP(0.2): kept too many tokens: got %d", len(got))
  112. }
  113. }
  114. func TestSortLogits(t *testing.T) {
  115. input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
  116. tokens := toTokens(input)
  117. tokens = topK(tokens, 20)
  118. for i := 1; i < len(tokens); i++ {
  119. if tokens[i].value > tokens[i-1].value {
  120. t.Errorf("sortLogits: tokens not sorted in descending order at index %d: %f > %f",
  121. i, tokens[i].value, tokens[i-1].value)
  122. }
  123. }
  124. want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
  125. compareLogits(t, "sortLogits", want, tokens)
  126. }
  127. // TestSortLogitsWithRealData tests sorting behavior using real model logit distributions
  128. func TestSortLogitsWithRealData(t *testing.T) {
  129. // This will be populated from testdata/logits.bin
  130. // Format: 32-bit float array in binary format
  131. logits, err := loadTestLogits(t)
  132. if err != nil {
  133. t.Skipf("Skipping real logit test: %v", err)
  134. return
  135. }
  136. tokens := toTokens(logits)
  137. sortLogits(tokens)
  138. // Calculate n for verification
  139. n := int(math.Sqrt(float64(len(tokens)))) + 1
  140. if n > 1000 {
  141. n = 1000
  142. } else if n < 100 {
  143. n = 100
  144. }
  145. t.Logf("Testing with %d tokens, partial sorting top %d", len(tokens), n)
  146. // Only verify the top n elements are sorted (which is what we guarantee)
  147. // This is much faster than checking the entire array
  148. topN := tokens[:n]
  149. for i := 1; i < len(topN); i++ {
  150. if topN[i].value > topN[i-1].value {
  151. t.Fatalf("top %d tokens not properly sorted at index %d: %.15f > %.15f",
  152. n, i, topN[i].value, topN[i-1].value)
  153. }
  154. }
  155. // Verify we didn't lose any high value tokens by checking that
  156. // all tokens after position n are <= the nth token
  157. // Do this in chunks to avoid timeouts on large arrays
  158. nthValue := tokens[n-1].value
  159. const chunkSize = 1000
  160. for start := n; start < len(tokens); start += chunkSize {
  161. end := min(start+chunkSize, len(tokens))
  162. for i := start; i < end; i++ {
  163. if tokens[i].value > nthValue {
  164. t.Fatalf("found higher value token after position %d: tokens[%d].value = %.15f > %.15f",
  165. n, i, tokens[i].value, nthValue)
  166. }
  167. }
  168. }
  169. }
  170. // loadTestLogits loads logit test data from testdata/logits.bin
  171. func loadTestLogits(t *testing.T) ([]float32, error) {
  172. t.Helper()
  173. _, currFile, _, ok := runtime.Caller(0)
  174. if !ok {
  175. return nil, errors.New("could not determine test file path")
  176. }
  177. testDataPath := filepath.Join(filepath.Dir(currFile), "testdata", "logits.bin")
  178. file, err := os.Open(testDataPath)
  179. if err != nil {
  180. return nil, err
  181. }
  182. defer file.Close()
  183. stat, err := file.Stat()
  184. if err != nil {
  185. return nil, err
  186. }
  187. numFloats := stat.Size() / 4 // each float32 is 4 bytes
  188. if numFloats*4 != stat.Size() {
  189. return nil, errors.New("logits.bin has invalid size: not a multiple of 4 bytes")
  190. }
  191. logits := make([]float32, numFloats)
  192. for i := range logits {
  193. var val uint32
  194. if err := binary.Read(file, binary.LittleEndian, &val); err != nil {
  195. return nil, err
  196. }
  197. logits[i] = math.Float32frombits(val)
  198. }
  199. if len(logits) == 0 {
  200. return nil, errors.New("logits.bin is empty")
  201. }
  202. return logits, nil
  203. }
  204. func BenchmarkTransforms(b *testing.B) {
  205. // Generate random logits
  206. tokens := make([]token, 1<<16)
  207. for i := range tokens {
  208. tokens[i] = token{
  209. id: int32(i),
  210. value: rand.Float32(),
  211. }
  212. }
  213. tokensCopy := make([]token, len(tokens))
  214. b.Run("Temperature", func(b *testing.B) {
  215. b.ResetTimer()
  216. for b.Loop() {
  217. copy(tokensCopy, tokens)
  218. temperature(tokensCopy, 0.5)
  219. }
  220. })
  221. b.Run("TopK", func(b *testing.B) {
  222. b.ResetTimer()
  223. for b.Loop() {
  224. copy(tokensCopy, tokens)
  225. topK(tokensCopy, 10)
  226. }
  227. })
  228. b.Run("TopP", func(b *testing.B) {
  229. b.ResetTimer()
  230. for b.Loop() {
  231. copy(tokensCopy, tokens)
  232. topP(tokensCopy, 0.9)
  233. }
  234. })
  235. b.Run("MinP", func(b *testing.B) {
  236. b.ResetTimer()
  237. for b.Loop() {
  238. copy(tokensCopy, tokens)
  239. minP(tokensCopy, 0.2)
  240. }
  241. })
  242. b.Run("SortTokens", func(b *testing.B) {
  243. b.ResetTimer()
  244. for b.Loop() {
  245. copy(tokensCopy, tokens)
  246. topK(tokensCopy, 200000)
  247. }
  248. })
  249. }