Ver Fonte

addressing comments + cleanup

ParthSareen há 3 meses atrás
pai
commit
5b19d4941a
2 ficheiros alterados com 50 adições e 39 exclusões
  1. 21 15
      sample/sample.go
  2. 29 24
      sample/sample_test.go

+ 21 - 15
sample/sample.go

@@ -1,9 +1,10 @@
 package sample
 package sample
 
 
 import (
 import (
+	"cmp"
 	"errors"
 	"errors"
 	"math"
 	"math"
-	"sort"
+	"slices"
 
 
 	"gonum.org/v1/gonum/floats"
 	"gonum.org/v1/gonum/floats"
 	"gonum.org/v1/gonum/stat/sampleuv"
 	"gonum.org/v1/gonum/stat/sampleuv"
@@ -20,13 +21,12 @@ func (s Temperature) Sample(logits []float64) ([]float64, error) {
 		return nil, errors.New("temperature must be between 0 and 1")
 		return nil, errors.New("temperature must be between 0 and 1")
 	}
 	}
 
 
-	copiedLogits := append([]float64(nil), logits...)
-	// Greedy sampling
+	// greedy sampling
 	if s == 0 {
 	if s == 0 {
-		return []float64{floats.Max(copiedLogits)}, nil
+		return []float64{floats.Max(logits)}, nil
 	}
 	}
-	floats.Scale(1.0/float64(s), copiedLogits)
-	return copiedLogits, nil
+	floats.Scale(1.0/float64(s), logits)
+	return logits, nil
 }
 }
 
 
 type softmax struct{}
 type softmax struct{}
@@ -69,8 +69,9 @@ func (k TopK) Sample(logits []float64) ([]float64, error) {
 		indices[i] = i
 		indices[i] = i
 	}
 	}
 
 
-	sort.Slice(indices, func(i, j int) bool {
-		return logits[indices[i]] > logits[indices[j]]
+	// sort in descending order
+	slices.SortFunc(indices, func(i, j int) int {
+		return cmp.Compare(logits[j], logits[i])
 	})
 	})
 
 
 	for _, idx := range indices[k:] {
 	for _, idx := range indices[k:] {
@@ -96,8 +97,10 @@ func (p TopP) Sample(logits []float64) ([]float64, error) {
 	for i := range indices {
 	for i := range indices {
 		indices[i] = i
 		indices[i] = i
 	}
 	}
-	sort.Slice(indices, func(i, j int) bool {
-		return probs[indices[i]] > probs[indices[j]]
+
+	// sort in descending order
+	slices.SortFunc(indices, func(i, j int) int {
+		return cmp.Compare(probs[j], probs[i])
 	})
 	})
 
 
 	cumSum := 0.0
 	cumSum := 0.0
@@ -127,9 +130,9 @@ func (p MinP) Sample(logits []float64) ([]float64, error) {
 	copiedProbs := make([]float64, len(probs))
 	copiedProbs := make([]float64, len(probs))
 	copy(copiedProbs, probs)
 	copy(copiedProbs, probs)
 
 
-	sort.Slice(copiedProbs, func(i, j int) bool { return copiedProbs[i] > copiedProbs[j] })
+	slices.Sort(copiedProbs)
 
 
-	maxProb := floats.Max(probs)
+	maxProb := copiedProbs[len(copiedProbs)-1]
 	probThreshold := float64(p) * maxProb
 	probThreshold := float64(p) * maxProb
 
 
 	for i := range probs {
 	for i := range probs {
@@ -162,20 +165,23 @@ func (s weighed) Sample(logits []float64) ([]float64, error) {
 		return nil, errors.New("no valid tokens found")
 		return nil, errors.New("no valid tokens found")
 	}
 	}
 
 
+	// usually, a softmax is applied to sample from the logits
+	// in this case the uv sampler normalizes the logits so that the sum of the weights is 1
 	w := sampleuv.NewWeighted(logitsCopy, nil)
 	w := sampleuv.NewWeighted(logitsCopy, nil)
 	if v, ok := w.Take(); ok {
 	if v, ok := w.Take(); ok {
+		// returns the token ID
 		return []float64{float64(indices[v])}, nil
 		return []float64{float64(indices[v])}, nil
 	}
 	}
 	return nil, errors.New("weighed sampler failed")
 	return nil, errors.New("weighed sampler failed")
 }
 }
 
 
-func Sample(tokenID []float64, samplers ...Sampler) ([]float64, error) {
+func Sample(logits []float64, samplers ...Sampler) ([]float64, error) {
 	var err error
 	var err error
 	for _, sampler := range samplers {
 	for _, sampler := range samplers {
-		tokenID, err = sampler.Sample(tokenID)
+		logits, err = sampler.Sample(logits)
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
 	}
 	}
-	return tokenID, nil
+	return logits, nil
 }
 }

+ 29 - 24
sample/sample_test.go

@@ -3,9 +3,14 @@ package sample
 import (
 import (
 	"fmt"
 	"fmt"
 	"math"
 	"math"
+	"math/rand"
+	"os"
+	"runtime"
 	"slices"
 	"slices"
 	"testing"
 	"testing"
 
 
+	"runtime/trace"
+
 	"gonum.org/v1/gonum/floats"
 	"gonum.org/v1/gonum/floats"
 )
 )
 
 
@@ -14,9 +19,9 @@ func TestTemperature(t *testing.T) {
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
-	expectedlogits := []float64{-6, -4, -2, 0, 2, 4, 8}
-	if !floats.Equal(logits, expectedlogits) {
-		t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedlogits)
+	want := []float64{-6, -4, -2, 0, 2, 4, 8}
+	if !floats.Equal(logits, want) {
+		t.Fatalf("got: %v, want: %v", logits, want)
 	}
 	}
 
 
 	// Only expect the max value returned
 	// Only expect the max value returned
@@ -24,9 +29,9 @@ func TestTemperature(t *testing.T) {
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
-	expectedlogits = []float64{4}
-	if !floats.Equal(logits, expectedlogits) {
-		t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedlogits)
+	want = []float64{4}
+	if !floats.Equal(logits, want) {
+		t.Fatalf("got: %v, want: %v", logits, want)
 	}
 	}
 
 
 	if _, err := Temperature(-1).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil {
 	if _, err := Temperature(-1).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil {
@@ -35,7 +40,7 @@ func TestTemperature(t *testing.T) {
 }
 }
 
 
 func TestSoftmax(t *testing.T) {
 func TestSoftmax(t *testing.T) {
-	probs, err := computeSoftmax([]float64{-3, -2, -1, 0, 1, 2, 4})
+	probs, err := Softmax().Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
@@ -66,9 +71,9 @@ func TestTopP(t *testing.T) {
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
-	expectedLogits := []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), 2, 4}
-	if !floats.Same(logits, expectedLogits) {
-		t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedLogits)
+	want := []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), 2, 4}
+	if !floats.Same(logits, want) {
+		t.Fatalf("got: %v, want: %v", logits, want)
 	}
 	}
 	logits, err = TopP(1.0).Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
 	logits, err = TopP(1.0).Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
 	if err == nil {
 	if err == nil {
@@ -85,9 +90,9 @@ func TestMinP(t *testing.T) {
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
-	expectedLogits := []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), 3, 4}
-	if !floats.Same(logits, expectedLogits) {
-		t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedLogits)
+	want := []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), 3, 4}
+	if !floats.Same(logits, want) {
+		t.Fatalf("got: %v, want: %v", logits, want)
 	}
 	}
 	logits, err = MinP(1.0).Sample([]float64{-3, -2, -1, 0, 1, 2, 3, 4})
 	logits, err = MinP(1.0).Sample([]float64{-3, -2, -1, 0, 1, 2, 3, 4})
 	if err == nil {
 	if err == nil {
@@ -104,9 +109,9 @@ func TestWeighed(t *testing.T) {
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
-	expectedLogits := []float64{1}
-	if !floats.Equal(logits, expectedLogits) {
-		t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedLogits)
+	want := []float64{1}
+	if !floats.Equal(logits, want) {
+		t.Fatalf("got: %v, want: %v", logits, want)
 	}
 	}
 	logits, err = Weighed().Sample([]float64{math.NaN(), math.NaN(), math.NaN()})
 	logits, err = Weighed().Sample([]float64{math.NaN(), math.NaN(), math.NaN()})
 	if err == nil {
 	if err == nil {
@@ -116,36 +121,36 @@ func TestWeighed(t *testing.T) {
 
 
 func TestSample(t *testing.T) {
 func TestSample(t *testing.T) {
 	input := []float64{1, 2, 3, 4}
 	input := []float64{1, 2, 3, 4}
-	expectedOutput := []float64{1, 2, 3, 4}
+	want := []float64{1, 2, 3, 4}
 
 
 	var callOrder []int
 	var callOrder []int
 	mock1 := &mockSampler{
 	mock1 := &mockSampler{
 		id:         1,
 		id:         1,
 		callOrder:  &callOrder,
 		callOrder:  &callOrder,
-		returnVals: expectedOutput,
+		returnVals: want,
 	}
 	}
 	mock2 := &mockSampler{
 	mock2 := &mockSampler{
 		id:         2,
 		id:         2,
 		callOrder:  &callOrder,
 		callOrder:  &callOrder,
-		returnVals: expectedOutput,
+		returnVals: want,
 	}
 	}
 	mock3 := &mockSampler{
 	mock3 := &mockSampler{
 		id:         3,
 		id:         3,
 		callOrder:  &callOrder,
 		callOrder:  &callOrder,
-		returnVals: expectedOutput,
+		returnVals: want,
 	}
 	}
 
 
-	result, err := Sample(input, mock1, mock2, mock3)
+	got, err := Sample(input, mock1, mock2, mock3)
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
 
 
 	if !slices.Equal(callOrder, []int{1, 2, 3}) {
 	if !slices.Equal(callOrder, []int{1, 2, 3}) {
-		t.Errorf("Expected call order [1,2,3], got %v", callOrder)
+		t.Errorf("got %v, want %v", callOrder, []int{1, 2, 3})
 	}
 	}
 
 
-	if !floats.Equal(result, expectedOutput) {
-		t.Errorf("Expected output %v, got %v", expectedOutput, result)
+	if !floats.Equal(got, want) {
+		t.Errorf("got %v, want %v", got, want)
 	}
 	}
 
 
 	errMock := &mockSampler{
 	errMock := &mockSampler{