123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238 |
- package sample
- import (
- "math"
- "math/rand/v2"
- "testing"
- "github.com/google/go-cmp/cmp"
- )
- func TestWeighted(t *testing.T) {
- got, err := Weighted(nil).Sample([]float32{float32(math.Inf(-1)), 2, float32(math.Inf(-1)), float32(math.Inf(-1))})
- if err != nil {
- t.Error(err)
- return
- }
- want := int32(1)
- if want != got {
- t.Errorf("index mismatch: want %d, got %d", want, got)
- }
- got, err = Weighted(nil).Sample([]float32{float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1))})
- if err == nil {
- t.Error("expected error for no valid tokens, got index", got)
- }
- seed := uint64(42)
- got, err = Weighted(&seed).Sample([]float32{1, 2, 3, 4})
- if err != nil {
- t.Error(err)
- return
- }
- // With seed 42, we expect a consistent sample
- want = int32(3) // This will be deterministic due to the seed
- if want != got {
- t.Errorf("index mismatch: want %d, got %d", want, got)
- }
- }
// testTransform is a pass-through transform used to observe how samplers
// invoke their transforms. Apply hands back its input untouched and, when
// callOrder is set, appends id to the slice it points at.
type testTransform struct {
	id        int    // value logged on every Apply call
	callOrder *[]int // optional shared call log; nil disables logging
}

// Apply logs tr.id into *tr.callOrder (when non-nil) and returns logits as-is.
func (tr *testTransform) Apply(logits []float64) []float64 {
	if tr.callOrder == nil {
		return logits
	}
	*tr.callOrder = append(*tr.callOrder, tr.id)
	return logits
}
- func TestSample(t *testing.T) {
- input := []float32{1, 2, 3, 4}
- var callOrder []int
- mock1 := &testTransform{
- id: 1,
- callOrder: &callOrder,
- }
- mock2 := &testTransform{
- id: 2,
- callOrder: &callOrder,
- }
- mock3 := &testTransform{
- id: 3,
- callOrder: &callOrder,
- }
- got, err := Greedy(mock1, mock2, mock3).Sample(input)
- if err != nil {
- t.Error(err)
- return
- }
- want := int32(3) // Greedy sampler should pick highest logit
- if want != got {
- t.Errorf("index mismatch: want %d, got %d", want, got)
- }
- wantOrder := []int{1, 2, 3}
- if diff := cmp.Diff(wantOrder, callOrder); diff != "" {
- t.Errorf("call order mismatch (-want +got):\n%s", diff)
- }
- callOrder = nil
- _, err = Weighted(nil, mock1, mock2, mock3).Sample(input)
- if err != nil {
- t.Error(err)
- return
- }
- wantOrder = []int{1, 2, 3}
- if diff := cmp.Diff(wantOrder, callOrder); diff != "" {
- t.Errorf("call order mismatch (-want +got):\n%s", diff)
- }
- }
- func TestNewSampler(t *testing.T) {
- tests := []struct {
- name string
- temperature float32
- topK int
- topP float32
- minP float32
- seed int
- wantErr bool
- }{
- {
- name: "no transforms",
- wantErr: true,
- },
- {
- name: "temperature",
- temperature: 0.5,
- wantErr: false,
- },
- {
- name: "invalid temperature negative",
- temperature: -1,
- wantErr: true,
- },
- {
- name: "invalid temperature too high",
- temperature: 2.1,
- wantErr: true,
- },
- {
- name: "top k",
- topK: 10,
- wantErr: false,
- },
- {
- name: "invalid top k negative",
- topK: -1,
- wantErr: true,
- },
- {
- name: "top p",
- topP: 0.9,
- wantErr: false,
- },
- {
- name: "invalid top p negative",
- topP: -0.1,
- wantErr: true,
- },
- {
- name: "invalid top p one",
- topP: 1.0,
- wantErr: true,
- },
- {
- name: "min p",
- minP: 0.2,
- wantErr: false,
- },
- {
- name: "invalid min p negative",
- minP: -0.1,
- wantErr: true,
- },
- {
- name: "invalid min p one",
- minP: 1.0,
- wantErr: true,
- },
- {
- name: "seed",
- seed: 42,
- wantErr: true, // seed alone is not valid without other transforms
- },
- {
- name: "default values",
- temperature: 0.8,
- topK: 40,
- topP: 0.9,
- minP: 0.0,
- seed: 0,
- wantErr: false,
- },
- {
- name: "all zeroes",
- temperature: 0.0,
- topK: 0,
- topP: 0.0,
- minP: 0.0,
- seed: 0,
- wantErr: true, // all zeroes means no transforms
- },
- {
- name: "all transforms",
- temperature: 0.8,
- topK: 50,
- topP: 0.95,
- minP: 0.1,
- seed: 42,
- wantErr: false,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- _, err := NewSampler(tt.temperature, tt.topK, tt.topP, tt.minP, tt.seed)
- if (err != nil) != tt.wantErr {
- t.Errorf("NewSampler() error = %v, wantErr %v", err, tt.wantErr)
- }
- })
- }
- }
- func BenchmarkSample(b *testing.B) {
- transforms := []Transform{
- Temperature(0.5),
- TopK(10),
- TopP(0.9),
- MinP(0.2),
- }
- samplers := map[string]Sampler{
- "Greedy": Greedy(transforms...),
- "Weighted": Weighted(nil, transforms...),
- }
- logits := make([]float32, 1<<16)
- for i := range logits {
- logits[i] = rand.Float32()
- }
- for name, s := range samplers {
- b.Run(name, func(b *testing.B) {
- b.ResetTimer()
- for range b.N {
- if _, err := s.Sample(logits); err != nil {
- b.Error(err)
- }
- }
- })
- }
- }
|