samplers_test.go 1.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. package sample
  2. import (
  3. "math/rand/v2"
  4. "testing"
  5. )
  6. func TestWeighted(t *testing.T) {
  7. logits := []float32{-10, 3, -10, -10}
  8. sampler := NewSampler(0, 0, 0, 0, 0, nil)
  9. got, err := sampler.Sample(logits)
  10. if err != nil {
  11. t.Error(err)
  12. return
  13. }
  14. want := int32(1)
  15. if want != got {
  16. t.Errorf("index mismatch: want %d, got %d", want, got)
  17. }
  18. logits = []float32{-100, -10, 0, 10}
  19. sampler = NewSampler(0, 0, 0, 0, 0, nil)
  20. got, err = sampler.Sample(logits)
  21. if err != nil {
  22. t.Error(err)
  23. return
  24. }
  25. want = int32(3) // Should pick highest probability with this r value
  26. if want != got {
  27. t.Errorf("index mismatch: want %d, got %d", want, got)
  28. }
  29. }
  30. func BenchmarkSample(b *testing.B) {
  31. samplers := map[string]Sampler{
  32. "Greedy": NewSampler(0, 0, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy
  33. "Weighted": NewSampler(0.5, 10, 0.9, 0.2, -1, nil),
  34. }
  35. // Generate random logits for benchmarking
  36. logits := make([]float32, 1<<16)
  37. for i := range logits {
  38. logits[i] = rand.Float32()
  39. }
  40. for name, s := range samplers {
  41. b.Run(name, func(b *testing.B) {
  42. b.ResetTimer()
  43. for b.Loop() {
  44. if _, err := s.Sample(logits); err != nil {
  45. b.Fatalf("error sampling: %v", err)
  46. }
  47. }
  48. })
  49. }
  50. }