transforms_test.go 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. package sample
  2. import (
  3. "encoding/binary"
  4. "errors"
  5. "math"
  6. "math/rand/v2"
  7. "os"
  8. "path/filepath"
  9. "runtime"
  10. "testing"
  11. )
  12. // Helper to convert float32 slice to logit slice
  13. func toTokens(values []float32) []token {
  14. tokens := make([]token, len(values))
  15. for i, v := range values {
  16. tokens[i] = token{
  17. id: int32(i),
  18. value: v,
  19. }
  20. }
  21. return tokens
  22. }
  23. // Helper to compare logit slices
  24. func compareLogits(t *testing.T, name string, want []float32, got []token) {
  25. t.Helper()
  26. if len(want) != len(got) {
  27. t.Errorf("%s: length mismatch: want %d, got %d", name, len(want), len(got))
  28. return
  29. }
  30. for i := range want {
  31. if math.Abs(float64(got[i].value-want[i])) > 1e-6 {
  32. t.Errorf("%s: index %d: want %f, got %f", name, i, want[i], got[i].value)
  33. }
  34. }
  35. }
  36. func TestTemperatureAndSoftmax(t *testing.T) {
  37. input := []float32{1, 4, -2, 0}
  38. got := temperature(toTokens(input), 0.5)
  39. // Check probabilities sum to 1
  40. var sum float32
  41. for _, token := range got {
  42. sum += token.value
  43. }
  44. if math.Abs(float64(sum-1.0)) > 1e-6 {
  45. t.Errorf("probabilities don't sum to 1: got %f", sum)
  46. }
  47. got = temperature(toTokens(input), 1)
  48. // Check probabilities sum to 1
  49. sum = 0.0
  50. for _, token := range got {
  51. sum += token.value
  52. }
  53. if math.Abs(float64(sum-1.0)) > 1e-6 {
  54. t.Errorf("probabilities don't sum to 1: got %f", sum)
  55. }
  56. }
  57. func TestTopK(t *testing.T) {
  58. input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
  59. // Test k=3
  60. got := topK(toTokens(input), 5)
  61. if len(got) != 5 {
  62. t.Errorf("topK(5): wrong length: want 5, got %d", len(got))
  63. }
  64. // Should keep highest 3 values in descending order
  65. want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154}
  66. compareLogits(t, "topK(3)", want, got)
  67. got = topK(toTokens(input), 20)
  68. if len(got) != len(input) {
  69. t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(got))
  70. }
  71. }
  72. func TestTopP(t *testing.T) {
  73. input := []float32{-3, -2, -1, 0, 1, 2, 4}
  74. tokens := toTokens(input)
  75. // First apply temperature and softmax to get probabilities
  76. tokens = temperature(tokens, 1)
  77. sortLogits(tokens)
  78. // Then apply topP
  79. got := topP(tokens, 0.95)
  80. // Should keep tokens until cumsum > 0.95
  81. if len(got) > 3 {
  82. t.Errorf("topP(0.95): kept too many tokens: got %d", len(got))
  83. t.Logf("got: %v", got)
  84. }
  85. }
  86. func TestMinP(t *testing.T) {
  87. input := []float32{-3, -2, -1, 0, 1, 2, 4, 3}
  88. tokens := toTokens(input)
  89. // First apply temperature and softmax
  90. tokens = temperature(tokens, 1)
  91. // Then apply minP
  92. got := minP(tokens, 0.2)
  93. // Should keep tokens with prob >= 0.2 * max_prob
  94. if len(got) > 3 {
  95. t.Errorf("minP(0.2): kept too many tokens: got %d", len(got))
  96. }
  97. }
  98. func TestSortLogits(t *testing.T) {
  99. input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
  100. tokens := toTokens(input)
  101. sortLogits(tokens)
  102. for i := 1; i < len(tokens); i++ {
  103. if tokens[i].value > tokens[i-1].value {
  104. t.Errorf("sortLogits: tokens not sorted in descending order at index %d: %f > %f",
  105. i, tokens[i].value, tokens[i-1].value)
  106. }
  107. }
  108. want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
  109. compareLogits(t, "sortLogits", want, tokens)
  110. }
  111. // TestSortLogitsWithRealData tests sorting behavior using real model logit distributions
  112. func TestSortLogitsWithRealData(t *testing.T) {
  113. // This will be populated from testdata/logits.bin
  114. // Format: 32-bit float array in binary format
  115. logits, err := loadTestLogits(t)
  116. if err != nil {
  117. t.Skipf("Skipping real logit test: %v", err)
  118. return
  119. }
  120. tokens := toTokens(logits)
  121. sortLogits(tokens)
  122. // Calculate n for verification
  123. n := int(math.Sqrt(float64(len(tokens)))) + 1
  124. if n > 1000 {
  125. n = 1000
  126. } else if n < 100 {
  127. n = 100
  128. }
  129. t.Logf("Testing with %d tokens, partial sorting top %d", len(tokens), n)
  130. // Only verify the top n elements are sorted (which is what we guarantee)
  131. // This is much faster than checking the entire array
  132. topN := tokens[:n]
  133. for i := 1; i < len(topN); i++ {
  134. if topN[i].value > topN[i-1].value {
  135. t.Fatalf("top %d tokens not properly sorted at index %d: %.15f > %.15f",
  136. n, i, topN[i].value, topN[i-1].value)
  137. }
  138. }
  139. // Verify we didn't lose any high value tokens by checking that
  140. // all tokens after position n are <= the nth token
  141. // Do this in chunks to avoid timeouts on large arrays
  142. nthValue := tokens[n-1].value
  143. const chunkSize = 1000
  144. for start := n; start < len(tokens); start += chunkSize {
  145. end := min(start+chunkSize, len(tokens))
  146. for i := start; i < end; i++ {
  147. if tokens[i].value > nthValue {
  148. t.Fatalf("found higher value token after position %d: tokens[%d].value = %.15f > %.15f",
  149. n, i, tokens[i].value, nthValue)
  150. }
  151. }
  152. }
  153. }
  // loadTestLogits loads logit test data from testdata/logits.bin, resolved
// relative to the directory of this source file. The file is read as a flat
// array of little-endian IEEE-754 float32 values.
//
// It returns an error if the source path cannot be determined, if the file
// cannot be read, if its size is not a multiple of 4 bytes, or if it holds
// no values.
func loadTestLogits(t *testing.T) ([]float32, error) {
	t.Helper()
	_, currFile, _, ok := runtime.Caller(0)
	if !ok {
		return nil, errors.New("could not determine test file path")
	}
	testDataPath := filepath.Join(filepath.Dir(currFile), "testdata", "logits.bin")

	// Read the file once and decode from memory. Calling binary.Read per
	// value on an unbuffered *os.File costs a reflective decode plus a read
	// syscall for every 4 bytes.
	data, err := os.ReadFile(testDataPath)
	if err != nil {
		return nil, err
	}
	if len(data)%4 != 0 { // each float32 is 4 bytes
		return nil, errors.New("logits.bin has invalid size: not a multiple of 4 bytes")
	}
	if len(data) == 0 {
		return nil, errors.New("logits.bin is empty")
	}

	logits := make([]float32, len(data)/4)
	for i := range logits {
		logits[i] = math.Float32frombits(binary.LittleEndian.Uint32(data[4*i:]))
	}
	return logits, nil
}
  188. func BenchmarkTransforms(b *testing.B) {
  189. // Generate random logits
  190. tokens := make([]token, 1<<16)
  191. for i := range tokens {
  192. tokens[i] = token{
  193. id: int32(i),
  194. value: rand.Float32(),
  195. }
  196. }
  197. tokensCopy := make([]token, len(tokens))
  198. b.Run("Temperature", func(b *testing.B) {
  199. b.ResetTimer()
  200. for b.Loop() {
  201. copy(tokensCopy, tokens)
  202. temperature(tokensCopy, 0.5)
  203. }
  204. })
  205. b.Run("TopK", func(b *testing.B) {
  206. b.ResetTimer()
  207. for b.Loop() {
  208. copy(tokensCopy, tokens)
  209. topK(tokensCopy, 10)
  210. }
  211. })
  212. b.Run("TopP", func(b *testing.B) {
  213. b.ResetTimer()
  214. for b.Loop() {
  215. copy(tokensCopy, tokens)
  216. topP(tokensCopy, 0.9)
  217. }
  218. })
  219. b.Run("MinP", func(b *testing.B) {
  220. b.ResetTimer()
  221. for b.Loop() {
  222. copy(tokensCopy, tokens)
  223. minP(tokensCopy, 0.2)
  224. }
  225. })
  226. b.Run("SortTokens", func(b *testing.B) {
  227. b.ResetTimer()
  228. for b.Loop() {
  229. copy(tokensCopy, tokens)
  230. sortLogits(tokensCopy)
  231. }
  232. })
  233. }