samplers_test.go 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. package sample
  2. import (
  3. "math"
  4. "math/rand/v2"
  5. "testing"
  6. )
  7. func TestWeighted(t *testing.T) {
  8. logits := []float32{-10, 3, -10, -10}
  9. sampler := NewSampler(0, 0, 0, 0, 0, nil)
  10. got, err := sampler.Sample(logits)
  11. if err != nil {
  12. t.Error(err)
  13. return
  14. }
  15. want := int32(1)
  16. if want != got {
  17. t.Errorf("index mismatch: want %d, got %d", want, got)
  18. }
  19. logits = []float32{-100, -10, 0, 10}
  20. sampler = NewSampler(0, 0, 0, 0, 0, nil)
  21. got, err = sampler.Sample(logits)
  22. if err != nil {
  23. t.Error(err)
  24. return
  25. }
  26. want = int32(3) // Should pick highest probability with this r value
  27. if want != got {
  28. t.Errorf("index mismatch: want %d, got %d", want, got)
  29. }
  30. // Test very high p
  31. logits = []float32{1.0, 0.9999999999999999, 0.5, 0.1}
  32. // Use extremely small topP to filter out all tokens
  33. sampler = NewSampler(1.0, 0, 1e-10, 0, 0, nil)
  34. got, err = sampler.Sample(logits)
  35. if err != nil {
  36. t.Error(err)
  37. return
  38. }
  39. // Should get the token with the highest logit
  40. want = int32(0)
  41. if want != got {
  42. t.Errorf("index mismatch: want %d, got %d", want, got)
  43. }
  44. logits = []float32{float32(math.NaN()), float32(math.NaN()), float32(math.NaN())}
  45. sampler = NewSampler(1, 0, 0.95, 0.05, 0, nil)
  46. got, err = sampler.Sample(logits)
  47. if err == nil {
  48. t.Errorf("expected error, got %d", got)
  49. return
  50. }
  51. }
  52. func BenchmarkSample(b *testing.B) {
  53. samplers := map[string]Sampler{
  54. "Greedy": NewSampler(0, 0, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy
  55. "Weighted": NewSampler(0.5, 10, 0.9, 0.2, -1, nil),
  56. }
  57. // Generate random logits for benchmarking
  58. logits := make([]float32, 1<<16)
  59. for i := range logits {
  60. logits[i] = rand.Float32()
  61. }
  62. for name, s := range samplers {
  63. b.Run(name, func(b *testing.B) {
  64. b.ResetTimer()
  65. for b.Loop() {
  66. if _, err := s.Sample(logits); err != nil {
  67. b.Fatalf("error sampling: %v", err)
  68. }
  69. }
  70. })
  71. }
  72. }