samplers_test.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. package sample
  2. import (
  3. "math/rand/v2"
  4. "testing"
  5. )
  6. func TestWeighted(t *testing.T) {
  7. logits := []float32{-10, 3, -10, -10}
  8. sampler := NewSampler(0, 0, 0, 0, 0)
  9. got, err := sampler.Sample(logits)
  10. if err != nil {
  11. t.Error(err)
  12. return
  13. }
  14. want := int32(1)
  15. if want != got {
  16. t.Errorf("index mismatch: want %d, got %d", want, got)
  17. }
  18. logits = []float32{-100, -10, 0, 10}
  19. sampler = NewSampler(0, 0, 0, 0, 0)
  20. got, err = sampler.Sample(logits)
  21. if err != nil {
  22. t.Error(err)
  23. return
  24. }
  25. want = int32(3) // Should pick highest probability with this r value
  26. if want != got {
  27. t.Errorf("index mismatch: want %d, got %d", want, got)
  28. }
  29. }
  30. func TestNewSampler(t *testing.T) {
  31. tests := []struct {
  32. name string
  33. temperature float32
  34. topK int
  35. topP float32
  36. minP float32
  37. seed int
  38. wantGreedy bool // Instead of wantErr, check if we get greedy sampler
  39. }{
  40. {
  41. name: "temperature",
  42. temperature: 0.5,
  43. wantGreedy: false,
  44. },
  45. {
  46. name: "zero temperature - greedy",
  47. temperature: 0,
  48. wantGreedy: true,
  49. },
  50. {
  51. name: "top k",
  52. temperature: 0.1,
  53. topK: 10,
  54. wantGreedy: false,
  55. },
  56. {
  57. name: "top p",
  58. temperature: 0.1,
  59. topP: 0.9,
  60. wantGreedy: false,
  61. },
  62. {
  63. name: "min p",
  64. temperature: 0.1,
  65. minP: 0.2,
  66. wantGreedy: false,
  67. },
  68. {
  69. name: "seed - weighted",
  70. temperature: 0.1,
  71. seed: 42,
  72. wantGreedy: false,
  73. },
  74. {
  75. name: "default values",
  76. temperature: 0.8,
  77. topK: 40,
  78. topP: 0.9,
  79. minP: 0.0,
  80. seed: 0,
  81. wantGreedy: false,
  82. },
  83. {
  84. name: "all zeroes - greedy",
  85. temperature: 0.0,
  86. topK: 0,
  87. topP: 0.0,
  88. minP: 0.0,
  89. seed: 0,
  90. wantGreedy: true,
  91. },
  92. {
  93. name: "all transforms",
  94. temperature: 0.8,
  95. topK: 50,
  96. topP: 0.95,
  97. minP: 0.1,
  98. seed: 42,
  99. wantGreedy: false,
  100. },
  101. }
  102. for _, tt := range tests {
  103. t.Run(tt.name, func(t *testing.T) {
  104. sampler := NewSampler(tt.temperature, tt.topK, tt.topP, tt.minP, tt.seed)
  105. _, isGreedy := sampler.(*greedy)
  106. if isGreedy != tt.wantGreedy {
  107. t.Errorf("NewSampler() got greedy = %v, want %v", isGreedy, tt.wantGreedy)
  108. }
  109. })
  110. }
  111. }
  112. func BenchmarkSample(b *testing.B) {
  113. weighted := NewSampler(0.5, 10, 0.9, 0.2, -1)
  114. samplers := map[string]Sampler{
  115. "Greedy": NewSampler(0, 0, 0, 0, 0), // Use NewSampler with temp=0 for greedy
  116. "Weighted": weighted,
  117. }
  118. // Generate random logits for benchmarking
  119. logits := make([]float32, 1<<16)
  120. for i := range logits {
  121. logits[i] = rand.Float32()
  122. }
  123. for name, s := range samplers {
  124. b.Run(name, func(b *testing.B) {
  125. b.ResetTimer()
  126. for b.Loop() {
  127. if _, err := s.Sample(logits); err != nil {
  128. b.Error(err)
  129. }
  130. }
  131. })
  132. }
  133. }