123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242 |
- package sample
- import (
- "fmt"
- "math"
- "math/rand/v2"
- "testing"
- "github.com/google/go-cmp/cmp"
- )
// TestTemperature verifies the Temperature transform.
//
// For T=0.5 the expected output equals (x - max(x)) / T for every logit x, so the
// largest logit maps to 0. The temperatures -1, 0 and 2.1 must all be rejected
// with an error.
func TestTemperature(t *testing.T) {
	// fresh returns a new slice on every call so no case can observe a
	// modification made by an earlier Apply.
	fresh := func() []float64 { return []float64{-3, -2, -1, 0, 1, 2, 4} }

	got, err := Temperature(0.5).Apply([]float64{2, -1, 4, -3, 1, -2, 0})
	if err != nil {
		t.Fatal(err)
	}
	expected := []float64{-4, -10, 0, -14, -6, -12, -8}
	if diff := cmp.Diff(expected, got); diff != "" {
		t.Errorf("logits mismatch (-want +got):\n%s", diff)
	}

	if out, err := Temperature(-1).Apply(fresh()); err == nil {
		t.Errorf("expected error for temperature=-1, got %v", out)
	}
	if out, err := Temperature(0).Apply(fresh()); err == nil {
		t.Errorf("expected error for temperature=0, got %v", out)
	}
	if out, err := Temperature(2.1).Apply(fresh()); err == nil {
		t.Errorf("expected error for temperature=2.1, got %v", out)
	}
}
- func TestSoftmax(t *testing.T) {
- probs := softmax([]float64{-3, -2, -1, 0, 1, 2, 4})
- expectedProbs := []float64{0.000751406628089903, 0.0020425349829204676, 0.005552185728064613, 0.015092405572827691, 0.04102541181635154, 0.11151863144543739, 0.8240174238263085}
- if diff := cmp.Diff(expectedProbs, probs); diff != "" {
- t.Errorf("probs mismatch (-want +got):\n%s", diff)
- }
- }
// TestTopK verifies the TopK transform:
//   - with k=3 the three largest logits keep their values and positions,
//     and every other entry becomes -Inf;
//   - k=0 is rejected with an error;
//   - a k larger than the number of logits (10 > 7) leaves the input unchanged.
func TestTopK(t *testing.T) {
	negInf := math.Inf(-1)

	got, err := TopK(3).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
	if err != nil {
		t.Fatal(err)
	}
	want := []float64{negInf, negInf, negInf, negInf, 1, 2, 4}
	if diff := cmp.Diff(want, got); diff != "" {
		t.Errorf("logits mismatch (-want +got):\n%s", diff)
	}

	// err is nil inside this branch, so report the logits that were
	// (unexpectedly) returned rather than the error.
	if logits, err := TopK(0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil {
		t.Errorf("expected error for k=0, got %v", logits)
	}

	got, err = TopK(10).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
	if err != nil {
		t.Fatal(err)
	}
	want = []float64{-3, -2, -1, 0, 1, 2, 4}
	if diff := cmp.Diff(want, got); diff != "" {
		t.Errorf("logits mismatch (-want +got):\n%s", diff)
	}
}
// TestTopP verifies the TopP transform.
//
// With p=0.9 only the two largest logits (2 and 4) keep their values and
// every other entry is masked to -Inf in place. The boundary values p=1.0
// and p=0.0 must both be rejected.
func TestTopP(t *testing.T) {
	ascending := func() []float64 { return []float64{-3, -2, -1, 0, 1, 2, 4} }
	negInf := math.Inf(-1)

	got, err := TopP(0.9).Apply(ascending())
	if err != nil {
		t.Fatal(err)
	}
	expected := []float64{negInf, negInf, negInf, negInf, negInf, 2, 4}
	if diff := cmp.Diff(expected, got); diff != "" {
		t.Errorf("logits mismatch (-want +got):\n%s", diff)
	}

	if _, err := TopP(1.0).Apply(ascending()); err == nil {
		t.Error("expected error for p=1.0")
	}
	if _, err := TopP(0.0).Apply(ascending()); err == nil {
		t.Error("expected error for p=0.0")
	}
}
// TestMinP verifies the MinP transform.
//
// With p=0.2 only the logits 4 and 3 survive, in their original positions,
// and every other entry becomes -Inf. The boundary values p=1.0 and p=0.0
// must both be rejected.
func TestMinP(t *testing.T) {
	sorted := func() []float64 { return []float64{-3, -2, -1, 0, 1, 2, 3, 4} }
	negInf := math.Inf(-1)

	got, err := MinP(0.2).Apply([]float64{-3, -2, -1, 0, 1, 2, 4, 3})
	if err != nil {
		t.Fatal(err)
	}
	expected := []float64{negInf, negInf, negInf, negInf, negInf, negInf, 4, 3}
	if diff := cmp.Diff(expected, got); diff != "" {
		t.Errorf("logits mismatch (-want +got):\n%s", diff)
	}

	if _, err := MinP(1.0).Apply(sorted()); err == nil {
		t.Error("expected error for p=1.0")
	}
	if _, err := MinP(0.0).Apply(sorted()); err == nil {
		t.Error("expected error for p=0.0")
	}
}
// TestWeighed exercises the Weighted sampler on two edge cases. If exactly
// one logit is finite, that logit's index must be chosen. If every logit is
// -Inf, sampling must return an error.
func TestWeighed(t *testing.T) {
	negInf := float32(math.Inf(-1))

	idx, err := Weighted(nil).Sample([]float32{negInf, 2, negInf, negInf})
	if err != nil {
		t.Fatal(err)
	}
	if diff := cmp.Diff(1, idx); diff != "" {
		t.Errorf("index mismatch (-want +got):\n%s", diff)
	}

	if idx, err := Weighted(nil).Sample([]float32{negInf, negInf, negInf}); err == nil {
		t.Error("expected error for no valid tokens, got index", idx)
	}
}
// TestSample checks how samplers interact with a chain of transforms:
//   - Greedy returns the index of the largest logit (3 for {1, 2, 3, 4});
//   - after the Greedy and Weighted calls, the shared call log reads exactly
//     [1 2 3];
//   - an error from any transform in the chain is surfaced by Sample.
func TestSample(t *testing.T) {
	logits := []float32{1, 2, 3, 4}

	// Every mock appends its id to this one log, which accumulates across all
	// Sample calls below.
	var calls []int
	newMock := func(id int) *testTransform {
		return &testTransform{id: id, callOrder: &calls}
	}
	first, second, third := newMock(1), newMock(2), newMock(3)

	idx, err := Greedy().Sample(logits, first, second, third)
	if err != nil {
		t.Fatal(err)
	}
	// Greedy sampler should pick highest logit.
	if diff := cmp.Diff(3, idx); diff != "" {
		t.Errorf("sampled index mismatch (-want +got):\n%s", diff)
	}

	if _, err := Weighted(nil).Sample(logits, first, second, third); err != nil {
		t.Fatal(err)
	}
	if diff := cmp.Diff([]int{1, 2, 3}, calls); diff != "" {
		t.Errorf("call order mismatch (-want +got):\n%s", diff)
	}

	failing := &testTransform{returnErr: fmt.Errorf("mock error")}
	if _, err := Weighted(nil).Sample(logits, first, failing, second); err == nil {
		t.Error("Expected error from sampler")
	}
}
// testTransform is a test double for a logits transform. It records when it
// is invoked and can optionally fail.
type testTransform struct {
	id        int    // value appended to *callOrder on each Apply
	callOrder *[]int // shared invocation log; nil disables recording
	returnErr error  // when non-nil, Apply fails with this error
}

// Apply appends m.id to the shared log (when one is configured) and then
// returns either (nil, m.returnErr) or the input slice itself, unmodified and
// not copied. Recording happens before the failure check, so a failing mock
// still shows up in the log.
func (m *testTransform) Apply(logits []float64) ([]float64, error) {
	if rec := m.callOrder; rec != nil {
		*rec = append(*rec, m.id)
	}
	if failure := m.returnErr; failure != nil {
		return nil, failure
	}
	return logits, nil
}
// BenchmarkTransform measures Apply for each transform on a 64Ki-element
// logits vector.
//
// Sub-benchmarks run in a fixed order over input generated from a fixed seed,
// so repeated runs (e.g. for benchstat) measure the same work.
func BenchmarkTransform(b *testing.B) {
	transforms := []struct {
		name      string
		transform Transform
	}{
		{"Temperature", Temperature(0.5)},
		{"TopK", TopK(10)},
		{"TopP", TopP(0.9)},
		{"MinP", MinP(0.2)},
	}

	rng := rand.New(rand.NewPCG(1, 2))
	logits := make([]float64, 1<<16)
	for i := range logits {
		logits[i] = rng.Float64()
	}

	for _, tc := range transforms {
		b.Run(tc.name, func(b *testing.B) {
			// NOTE(review): Apply may modify its argument in place (the test
			// mock in this file returns the slice it was given). Each
			// iteration therefore works on a fresh copy of the pristine input,
			// so earlier iterations or sub-benchmarks cannot change what later
			// ones measure. The copy is a fixed memmove included in every
			// sub-benchmark's timing.
			scratch := make([]float64, len(logits))
			for range b.N {
				copy(scratch, logits)
				if _, err := tc.transform.Apply(scratch); err != nil {
					// Stop immediately instead of logging b.N identical errors.
					b.Fatal(err)
				}
			}
		})
	}
}
// BenchmarkSample measures Sample for each sampler on a 64Ki-element logits
// vector with no transforms applied.
//
// Sub-benchmarks run in a fixed order over input generated from a fixed seed,
// so repeated runs (e.g. for benchstat) measure the same work.
func BenchmarkSample(b *testing.B) {
	samplers := []struct {
		name    string
		sampler Sampler
	}{
		{"Greedy", Greedy()},
		{"Weighted", Weighted(nil)},
	}

	rng := rand.New(rand.NewPCG(1, 2))
	logits := make([]float32, 1<<16)
	for i := range logits {
		logits[i] = rng.Float32()
	}

	for _, tc := range samplers {
		b.Run(tc.name, func(b *testing.B) {
			for range b.N {
				if _, err := tc.sampler.Sample(logits); err != nil {
					// Stop immediately instead of logging b.N identical errors.
					b.Fatal(err)
				}
			}
		})
	}
}
|