1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556 |
- package sample
- import (
- "math/rand/v2"
- "testing"
- )
- func TestWeighted(t *testing.T) {
- logits := []float32{-10, 3, -10, -10}
- sampler := NewSampler(0, 0, 0, 0, 0, nil)
- got, err := sampler.Sample(logits)
- if err != nil {
- t.Error(err)
- return
- }
- want := int32(1)
- if want != got {
- t.Errorf("index mismatch: want %d, got %d", want, got)
- }
- logits = []float32{-100, -10, 0, 10}
- sampler = NewSampler(0, 0, 0, 0, 0, nil)
- got, err = sampler.Sample(logits)
- if err != nil {
- t.Error(err)
- return
- }
- want = int32(3) // Should pick highest probability with this r value
- if want != got {
- t.Errorf("index mismatch: want %d, got %d", want, got)
- }
- }
- func BenchmarkSample(b *testing.B) {
- samplers := map[string]Sampler{
- "Greedy": NewSampler(0, 0, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy
- "Weighted": NewSampler(0.5, 10, 0.9, 0.2, -1, nil),
- }
- // Generate random logits for benchmarking
- logits := make([]float32, 1<<16)
- for i := range logits {
- logits[i] = rand.Float32()
- }
- for name, s := range samplers {
- b.Run(name, func(b *testing.B) {
- b.ResetTimer()
- for b.Loop() {
- if _, err := s.Sample(logits); err != nil {
- b.Fatalf("error sampling: %v", err)
- }
- }
- })
- }
- }
|