|
@@ -0,0 +1,238 @@
|
|
|
+package sample
|
|
|
+
|
|
|
+import (
|
|
|
+ "math"
|
|
|
+ "math/rand/v2"
|
|
|
+ "testing"
|
|
|
+
|
|
|
+ "github.com/google/go-cmp/cmp"
|
|
|
+)
|
|
|
+
|
|
|
+func TestWeighted(t *testing.T) {
|
|
|
+ got, err := Weighted(nil).Sample([]float32{float32(math.Inf(-1)), 2, float32(math.Inf(-1)), float32(math.Inf(-1))})
|
|
|
+ if err != nil {
|
|
|
+ t.Error(err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ want := int32(1)
|
|
|
+ if want != got {
|
|
|
+ t.Errorf("index mismatch: want %d, got %d", want, got)
|
|
|
+ }
|
|
|
+
|
|
|
+ got, err = Weighted(nil).Sample([]float32{float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1))})
|
|
|
+ if err == nil {
|
|
|
+ t.Error("expected error for no valid tokens, got index", got)
|
|
|
+ }
|
|
|
+
|
|
|
+ seed := uint64(42)
|
|
|
+ got, err = Weighted(&seed).Sample([]float32{1, 2, 3, 4})
|
|
|
+ if err != nil {
|
|
|
+ t.Error(err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ // With seed 42, we expect a consistent sample
|
|
|
+ want = int32(3) // This will be deterministic due to the seed
|
|
|
+ if want != got {
|
|
|
+ t.Errorf("index mismatch: want %d, got %d", want, got)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
// testTransform is a pass-through test double: it leaves logits untouched and
// optionally records that it was invoked.
type testTransform struct {
	// id is the value appended to callOrder each time Apply runs.
	id int
	// callOrder points at a slice shared between several testTransforms so
	// tests can inspect invocation order; a nil pointer disables recording.
	callOrder *[]int
}

// Apply appends ts.id to the shared call-order slice when one is configured
// and returns logits unchanged.
func (ts *testTransform) Apply(logits []float64) []float64 {
	if order := ts.callOrder; order != nil {
		*order = append(*order, ts.id)
	}
	return logits
}
|
|
|
+
|
|
|
+func TestSample(t *testing.T) {
|
|
|
+ input := []float32{1, 2, 3, 4}
|
|
|
+
|
|
|
+ var callOrder []int
|
|
|
+ mock1 := &testTransform{
|
|
|
+ id: 1,
|
|
|
+ callOrder: &callOrder,
|
|
|
+ }
|
|
|
+ mock2 := &testTransform{
|
|
|
+ id: 2,
|
|
|
+ callOrder: &callOrder,
|
|
|
+ }
|
|
|
+ mock3 := &testTransform{
|
|
|
+ id: 3,
|
|
|
+ callOrder: &callOrder,
|
|
|
+ }
|
|
|
+
|
|
|
+ got, err := Greedy(mock1, mock2, mock3).Sample(input)
|
|
|
+ if err != nil {
|
|
|
+ t.Error(err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ want := int32(3) // Greedy sampler should pick highest logit
|
|
|
+ if want != got {
|
|
|
+ t.Errorf("index mismatch: want %d, got %d", want, got)
|
|
|
+ }
|
|
|
+ wantOrder := []int{1, 2, 3}
|
|
|
+ if diff := cmp.Diff(wantOrder, callOrder); diff != "" {
|
|
|
+ t.Errorf("call order mismatch (-want +got):\n%s", diff)
|
|
|
+ }
|
|
|
+
|
|
|
+ callOrder = nil
|
|
|
+
|
|
|
+ _, err = Weighted(nil, mock1, mock2, mock3).Sample(input)
|
|
|
+ if err != nil {
|
|
|
+ t.Error(err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ wantOrder = []int{1, 2, 3}
|
|
|
+ if diff := cmp.Diff(wantOrder, callOrder); diff != "" {
|
|
|
+ t.Errorf("call order mismatch (-want +got):\n%s", diff)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestNewSampler(t *testing.T) {
|
|
|
+ tests := []struct {
|
|
|
+ name string
|
|
|
+ temperature float32
|
|
|
+ topK int
|
|
|
+ topP float32
|
|
|
+ minP float32
|
|
|
+ seed int
|
|
|
+ wantErr bool
|
|
|
+ }{
|
|
|
+ {
|
|
|
+ name: "no transforms",
|
|
|
+ wantErr: true,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "temperature",
|
|
|
+ temperature: 0.5,
|
|
|
+ wantErr: false,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "invalid temperature negative",
|
|
|
+ temperature: -1,
|
|
|
+ wantErr: true,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "invalid temperature too high",
|
|
|
+ temperature: 2.1,
|
|
|
+ wantErr: true,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "top k",
|
|
|
+ topK: 10,
|
|
|
+ wantErr: false,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "invalid top k negative",
|
|
|
+ topK: -1,
|
|
|
+ wantErr: true,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "top p",
|
|
|
+ topP: 0.9,
|
|
|
+ wantErr: false,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "invalid top p negative",
|
|
|
+ topP: -0.1,
|
|
|
+ wantErr: true,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "invalid top p one",
|
|
|
+ topP: 1.0,
|
|
|
+ wantErr: true,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "min p",
|
|
|
+ minP: 0.2,
|
|
|
+ wantErr: false,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "invalid min p negative",
|
|
|
+ minP: -0.1,
|
|
|
+ wantErr: true,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "invalid min p one",
|
|
|
+ minP: 1.0,
|
|
|
+ wantErr: true,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "seed",
|
|
|
+ seed: 42,
|
|
|
+ wantErr: true, // seed alone is not valid without other transforms
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "default values",
|
|
|
+ temperature: 0.8,
|
|
|
+ topK: 40,
|
|
|
+ topP: 0.9,
|
|
|
+ minP: 0.0,
|
|
|
+ seed: 0,
|
|
|
+ wantErr: false,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "all zeroes",
|
|
|
+ temperature: 0.0,
|
|
|
+ topK: 0,
|
|
|
+ topP: 0.0,
|
|
|
+ minP: 0.0,
|
|
|
+ seed: 0,
|
|
|
+ wantErr: true, // all zeroes means no transforms
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "all transforms",
|
|
|
+ temperature: 0.8,
|
|
|
+ topK: 50,
|
|
|
+ topP: 0.95,
|
|
|
+ minP: 0.1,
|
|
|
+ seed: 42,
|
|
|
+ wantErr: false,
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ for _, tt := range tests {
|
|
|
+ t.Run(tt.name, func(t *testing.T) {
|
|
|
+ _, err := NewSampler(tt.temperature, tt.topK, tt.topP, tt.minP, tt.seed)
|
|
|
+ if (err != nil) != tt.wantErr {
|
|
|
+ t.Errorf("NewSampler() error = %v, wantErr %v", err, tt.wantErr)
|
|
|
+ }
|
|
|
+ })
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func BenchmarkSample(b *testing.B) {
|
|
|
+ transforms := []Transform{
|
|
|
+ Temperature(0.5),
|
|
|
+ TopK(10),
|
|
|
+ TopP(0.9),
|
|
|
+ MinP(0.2),
|
|
|
+ }
|
|
|
+
|
|
|
+ samplers := map[string]Sampler{
|
|
|
+ "Greedy": Greedy(transforms...),
|
|
|
+ "Weighted": Weighted(nil, transforms...),
|
|
|
+ }
|
|
|
+
|
|
|
+ logits := make([]float32, 1<<16)
|
|
|
+ for i := range logits {
|
|
|
+ logits[i] = rand.Float32()
|
|
|
+ }
|
|
|
+
|
|
|
+ for name, s := range samplers {
|
|
|
+ b.Run(name, func(b *testing.B) {
|
|
|
+ b.ResetTimer()
|
|
|
+ for range b.N {
|
|
|
+ if _, err := s.Sample(logits); err != nil {
|
|
|
+ b.Error(err)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ })
|
|
|
+ }
|
|
|
+}
|