samplers_test.go 4.3 KB


  1. package sample
  2. import (
  3. "math"
  4. "math/rand/v2"
  5. "testing"
  6. "github.com/google/go-cmp/cmp"
  7. )
  8. func TestWeighted(t *testing.T) {
  9. got, err := Weighted(nil).Sample([]float32{float32(math.Inf(-1)), 2, float32(math.Inf(-1)), float32(math.Inf(-1))})
  10. if err != nil {
  11. t.Error(err)
  12. return
  13. }
  14. want := int32(1)
  15. if want != got {
  16. t.Errorf("index mismatch: want %d, got %d", want, got)
  17. }
  18. got, err = Weighted(nil).Sample([]float32{float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1))})
  19. if err == nil {
  20. t.Error("expected error for no valid tokens, got index", got)
  21. }
  22. seed := uint64(42)
  23. got, err = Weighted(&seed).Sample([]float32{1, 2, 3, 4})
  24. if err != nil {
  25. t.Error(err)
  26. return
  27. }
  28. // With seed 42, we expect a consistent sample
  29. want = int32(3) // This will be deterministic due to the seed
  30. if want != got {
  31. t.Errorf("index mismatch: want %d, got %d", want, got)
  32. }
  33. }
  // testTransform is a pass-through transform for tests. When callOrder is
// non-nil, each Apply call appends id to it, so a test can observe the
// sequence in which a sampler invokes its transforms.
type testTransform struct {
	id        int
	callOrder *[]int
}

// Apply records the receiver's id when a recorder is attached, then hands the
// logits back untouched: the very same slice, not a copy.
func (tr *testTransform) Apply(logits []float64) []float64 {
	if tr.callOrder == nil {
		return logits
	}
	*tr.callOrder = append(*tr.callOrder, tr.id)
	return logits
}
  44. func TestSample(t *testing.T) {
  45. input := []float32{1, 2, 3, 4}
  46. var callOrder []int
  47. mock1 := &testTransform{
  48. id: 1,
  49. callOrder: &callOrder,
  50. }
  51. mock2 := &testTransform{
  52. id: 2,
  53. callOrder: &callOrder,
  54. }
  55. mock3 := &testTransform{
  56. id: 3,
  57. callOrder: &callOrder,
  58. }
  59. _, err := Weighted(nil, mock1, mock2, mock3).Sample(input)
  60. if err != nil {
  61. t.Error(err)
  62. return
  63. }
  64. wantOrder := []int{1, 2, 3}
  65. if diff := cmp.Diff(wantOrder, callOrder); diff != "" {
  66. t.Errorf("call order mismatch (-want +got):\n%s", diff)
  67. }
  68. }
  69. func TestNewSampler(t *testing.T) {
  70. tests := []struct {
  71. name string
  72. temperature float32
  73. topK int
  74. topP float32
  75. minP float32
  76. seed int
  77. wantErr bool
  78. }{
  79. {
  80. name: "no transforms",
  81. // temperature is 0, so greedy should be used
  82. wantErr: false,
  83. },
  84. {
  85. name: "temperature",
  86. temperature: 0.5,
  87. wantErr: false,
  88. },
  89. {
  90. name: "invalid temperature negative",
  91. temperature: -1,
  92. wantErr: true,
  93. },
  94. {
  95. name: "invalid temperature too high",
  96. temperature: 2.1,
  97. wantErr: true,
  98. },
  99. {
  100. name: "top k",
  101. topK: 10,
  102. temperature: 0.8,
  103. wantErr: false,
  104. },
  105. {
  106. name: "invalid top k negative",
  107. topK: -1,
  108. temperature: 0.8,
  109. wantErr: true,
  110. },
  111. {
  112. name: "top p",
  113. topP: 0.9,
  114. temperature: 0.8,
  115. wantErr: false,
  116. },
  117. {
  118. name: "invalid top p negative",
  119. topP: -0.1,
  120. temperature: 0.8,
  121. wantErr: true,
  122. },
  123. {
  124. name: "invalid top p one",
  125. topP: 1.0,
  126. temperature: 0.8,
  127. wantErr: true,
  128. },
  129. {
  130. name: "min p",
  131. minP: 0.2,
  132. temperature: 0.8,
  133. wantErr: false,
  134. },
  135. {
  136. name: "invalid min p negative",
  137. minP: -0.1,
  138. temperature: 0.8,
  139. wantErr: true,
  140. },
  141. {
  142. name: "invalid min p one",
  143. minP: 1.0,
  144. temperature: 0.8,
  145. wantErr: true,
  146. },
  147. {
  148. name: "default values",
  149. temperature: 0.8,
  150. topK: 40,
  151. topP: 0.9,
  152. minP: 0.0,
  153. seed: 0,
  154. wantErr: false,
  155. },
  156. {
  157. name: "all zeroes",
  158. temperature: 0.0,
  159. topK: 0,
  160. topP: 0.0,
  161. minP: 0.0,
  162. seed: 0,
  163. wantErr: false, // all zeroes means no transforms
  164. },
  165. {
  166. name: "all transforms",
  167. temperature: 0.8,
  168. topK: 50,
  169. topP: 0.95,
  170. minP: 0.1,
  171. seed: 42,
  172. wantErr: false,
  173. },
  174. }
  175. for _, tt := range tests {
  176. t.Run(tt.name, func(t *testing.T) {
  177. _, err := NewSampler(tt.temperature, tt.topK, tt.topP, tt.minP, tt.seed)
  178. if (err != nil) != tt.wantErr {
  179. t.Errorf("NewSampler() error = %v, wantErr %v", err, tt.wantErr)
  180. }
  181. })
  182. }
  183. }
  184. func BenchmarkSample(b *testing.B) {
  185. transforms := []Transform{
  186. Temperature(0.5),
  187. TopK(10),
  188. TopP(0.9),
  189. MinP(0.2),
  190. }
  191. samplers := map[string]Sampler{
  192. "Greedy": Greedy(),
  193. "Weighted": Weighted(nil, transforms...),
  194. }
  195. logits := make([]float32, 1<<16)
  196. for i := range logits {
  197. logits[i] = rand.Float32()
  198. }
  199. for name, s := range samplers {
  200. b.Run(name, func(b *testing.B) {
  201. b.ResetTimer()
  202. for range b.N {
  203. if _, err := s.Sample(logits); err != nil {
  204. b.Error(err)
  205. }
  206. }
  207. })
  208. }
  209. }