// samplers_test.go (4.5 KB)

  1. package sample
  2. import (
  3. "math"
  4. "math/rand/v2"
  5. "testing"
  6. "github.com/google/go-cmp/cmp"
  7. )
  8. func TestWeighted(t *testing.T) {
  9. got, err := Weighted(nil).Sample([]float32{float32(math.Inf(-1)), 2, float32(math.Inf(-1)), float32(math.Inf(-1))})
  10. if err != nil {
  11. t.Error(err)
  12. return
  13. }
  14. want := int32(1)
  15. if want != got {
  16. t.Errorf("index mismatch: want %d, got %d", want, got)
  17. }
  18. got, err = Weighted(nil).Sample([]float32{float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1))})
  19. if err == nil {
  20. t.Error("expected error for no valid tokens, got index", got)
  21. }
  22. seed := uint64(42)
  23. got, err = Weighted(&seed).Sample([]float32{1, 2, 3, 4})
  24. if err != nil {
  25. t.Error(err)
  26. return
  27. }
  28. // With seed 42, we expect a consistent sample
  29. want = int32(3) // This will be deterministic due to the seed
  30. if want != got {
  31. t.Errorf("index mismatch: want %d, got %d", want, got)
  32. }
  33. }
  34. type testTransform struct {
  35. id int
  36. callOrder *[]int
  37. }
  38. func (ts *testTransform) Apply(logits []float64) []float64 {
  39. if ts.callOrder != nil {
  40. *ts.callOrder = append(*ts.callOrder, ts.id)
  41. }
  42. return logits
  43. }
  44. func TestSample(t *testing.T) {
  45. input := []float32{1, 2, 3, 4}
  46. var callOrder []int
  47. mock1 := &testTransform{
  48. id: 1,
  49. callOrder: &callOrder,
  50. }
  51. mock2 := &testTransform{
  52. id: 2,
  53. callOrder: &callOrder,
  54. }
  55. mock3 := &testTransform{
  56. id: 3,
  57. callOrder: &callOrder,
  58. }
  59. got, err := Greedy(mock1, mock2, mock3).Sample(input)
  60. if err != nil {
  61. t.Error(err)
  62. return
  63. }
  64. want := int32(3) // Greedy sampler should pick highest logit
  65. if want != got {
  66. t.Errorf("index mismatch: want %d, got %d", want, got)
  67. }
  68. wantOrder := []int{1, 2, 3}
  69. if diff := cmp.Diff(wantOrder, callOrder); diff != "" {
  70. t.Errorf("call order mismatch (-want +got):\n%s", diff)
  71. }
  72. callOrder = nil
  73. _, err = Weighted(nil, mock1, mock2, mock3).Sample(input)
  74. if err != nil {
  75. t.Error(err)
  76. return
  77. }
  78. wantOrder = []int{1, 2, 3}
  79. if diff := cmp.Diff(wantOrder, callOrder); diff != "" {
  80. t.Errorf("call order mismatch (-want +got):\n%s", diff)
  81. }
  82. }
  83. func TestNewSampler(t *testing.T) {
  84. tests := []struct {
  85. name string
  86. temperature float32
  87. topK int
  88. topP float32
  89. minP float32
  90. seed int
  91. wantErr bool
  92. }{
  93. {
  94. name: "no transforms",
  95. wantErr: true,
  96. },
  97. {
  98. name: "temperature",
  99. temperature: 0.5,
  100. wantErr: false,
  101. },
  102. {
  103. name: "invalid temperature negative",
  104. temperature: -1,
  105. wantErr: true,
  106. },
  107. {
  108. name: "invalid temperature too high",
  109. temperature: 2.1,
  110. wantErr: true,
  111. },
  112. {
  113. name: "top k",
  114. topK: 10,
  115. wantErr: false,
  116. },
  117. {
  118. name: "invalid top k negative",
  119. topK: -1,
  120. wantErr: true,
  121. },
  122. {
  123. name: "top p",
  124. topP: 0.9,
  125. wantErr: false,
  126. },
  127. {
  128. name: "invalid top p negative",
  129. topP: -0.1,
  130. wantErr: true,
  131. },
  132. {
  133. name: "invalid top p one",
  134. topP: 1.0,
  135. wantErr: true,
  136. },
  137. {
  138. name: "min p",
  139. minP: 0.2,
  140. wantErr: false,
  141. },
  142. {
  143. name: "invalid min p negative",
  144. minP: -0.1,
  145. wantErr: true,
  146. },
  147. {
  148. name: "invalid min p one",
  149. minP: 1.0,
  150. wantErr: true,
  151. },
  152. {
  153. name: "seed",
  154. seed: 42,
  155. wantErr: true, // seed alone is not valid without other transforms
  156. },
  157. {
  158. name: "default values",
  159. temperature: 0.8,
  160. topK: 40,
  161. topP: 0.9,
  162. minP: 0.0,
  163. seed: 0,
  164. wantErr: false,
  165. },
  166. {
  167. name: "all zeroes",
  168. temperature: 0.0,
  169. topK: 0,
  170. topP: 0.0,
  171. minP: 0.0,
  172. seed: 0,
  173. wantErr: true, // all zeroes means no transforms
  174. },
  175. {
  176. name: "all transforms",
  177. temperature: 0.8,
  178. topK: 50,
  179. topP: 0.95,
  180. minP: 0.1,
  181. seed: 42,
  182. wantErr: false,
  183. },
  184. }
  185. for _, tt := range tests {
  186. t.Run(tt.name, func(t *testing.T) {
  187. _, err := NewSampler(tt.temperature, tt.topK, tt.topP, tt.minP, tt.seed)
  188. if (err != nil) != tt.wantErr {
  189. t.Errorf("NewSampler() error = %v, wantErr %v", err, tt.wantErr)
  190. }
  191. })
  192. }
  193. }
  194. func BenchmarkSample(b *testing.B) {
  195. transforms := []Transform{
  196. Temperature(0.5),
  197. TopK(10),
  198. TopP(0.9),
  199. MinP(0.2),
  200. }
  201. samplers := map[string]Sampler{
  202. "Greedy": Greedy(transforms...),
  203. "Weighted": Weighted(nil, transforms...),
  204. }
  205. logits := make([]float32, 1<<16)
  206. for i := range logits {
  207. logits[i] = rand.Float32()
  208. }
  209. for name, s := range samplers {
  210. b.Run(name, func(b *testing.B) {
  211. b.ResetTimer()
  212. for range b.N {
  213. if _, err := s.Sample(logits); err != nil {
  214. b.Error(err)
  215. }
  216. }
  217. })
  218. }
  219. }