|
@@ -3,9 +3,14 @@ package sample
|
|
|
import (
|
|
|
"fmt"
|
|
|
"math"
|
|
|
+ "math/rand"
|
|
|
+ "os"
|
|
|
+ "runtime"
|
|
|
"slices"
|
|
|
"testing"
|
|
|
|
|
|
+ "runtime/trace"
|
|
|
+
|
|
|
"gonum.org/v1/gonum/floats"
|
|
|
)
|
|
|
|
|
@@ -14,9 +19,9 @@ func TestTemperature(t *testing.T) {
|
|
|
if err != nil {
|
|
|
t.Fatal(err)
|
|
|
}
|
|
|
- expectedlogits := []float64{-6, -4, -2, 0, 2, 4, 8}
|
|
|
- if !floats.Equal(logits, expectedlogits) {
|
|
|
- t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedlogits)
|
|
|
+ want := []float64{-6, -4, -2, 0, 2, 4, 8}
|
|
|
+ if !floats.Equal(logits, want) {
|
|
|
+ t.Fatalf("got: %v, want: %v", logits, want)
|
|
|
}
|
|
|
|
|
|
// Only expect the max value returned
|
|
@@ -24,9 +29,9 @@ func TestTemperature(t *testing.T) {
|
|
|
if err != nil {
|
|
|
t.Fatal(err)
|
|
|
}
|
|
|
- expectedlogits = []float64{4}
|
|
|
- if !floats.Equal(logits, expectedlogits) {
|
|
|
- t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedlogits)
|
|
|
+ want = []float64{4}
|
|
|
+ if !floats.Equal(logits, want) {
|
|
|
+ t.Fatalf("got: %v, want: %v", logits, want)
|
|
|
}
|
|
|
|
|
|
if _, err := Temperature(-1).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil {
|
|
@@ -35,7 +40,7 @@ func TestTemperature(t *testing.T) {
|
|
|
}
|
|
|
|
|
|
func TestSoftmax(t *testing.T) {
|
|
|
- probs, err := computeSoftmax([]float64{-3, -2, -1, 0, 1, 2, 4})
|
|
|
+ probs, err := Softmax().Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
|
|
|
if err != nil {
|
|
|
t.Fatal(err)
|
|
|
}
|
|
@@ -66,9 +71,9 @@ func TestTopP(t *testing.T) {
|
|
|
if err != nil {
|
|
|
t.Fatal(err)
|
|
|
}
|
|
|
- expectedLogits := []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), 2, 4}
|
|
|
- if !floats.Same(logits, expectedLogits) {
|
|
|
- t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedLogits)
|
|
|
+ want := []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), 2, 4}
|
|
|
+ if !floats.Same(logits, want) {
|
|
|
+ t.Fatalf("got: %v, want: %v", logits, want)
|
|
|
}
|
|
|
logits, err = TopP(1.0).Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
|
|
|
if err == nil {
|
|
@@ -85,9 +90,9 @@ func TestMinP(t *testing.T) {
|
|
|
if err != nil {
|
|
|
t.Fatal(err)
|
|
|
}
|
|
|
- expectedLogits := []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), 3, 4}
|
|
|
- if !floats.Same(logits, expectedLogits) {
|
|
|
- t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedLogits)
|
|
|
+ want := []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), 3, 4}
|
|
|
+ if !floats.Same(logits, want) {
|
|
|
+ t.Fatalf("got: %v, want: %v", logits, want)
|
|
|
}
|
|
|
logits, err = MinP(1.0).Sample([]float64{-3, -2, -1, 0, 1, 2, 3, 4})
|
|
|
if err == nil {
|
|
@@ -104,9 +109,9 @@ func TestWeighed(t *testing.T) {
|
|
|
if err != nil {
|
|
|
t.Fatal(err)
|
|
|
}
|
|
|
- expectedLogits := []float64{1}
|
|
|
- if !floats.Equal(logits, expectedLogits) {
|
|
|
- t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedLogits)
|
|
|
+ want := []float64{1}
|
|
|
+ if !floats.Equal(logits, want) {
|
|
|
+ t.Fatalf("got: %v, want: %v", logits, want)
|
|
|
}
|
|
|
logits, err = Weighed().Sample([]float64{math.NaN(), math.NaN(), math.NaN()})
|
|
|
if err == nil {
|
|
@@ -116,36 +121,36 @@ func TestWeighed(t *testing.T) {
|
|
|
|
|
|
func TestSample(t *testing.T) {
|
|
|
input := []float64{1, 2, 3, 4}
|
|
|
- expectedOutput := []float64{1, 2, 3, 4}
|
|
|
+ want := []float64{1, 2, 3, 4}
|
|
|
|
|
|
var callOrder []int
|
|
|
mock1 := &mockSampler{
|
|
|
id: 1,
|
|
|
callOrder: &callOrder,
|
|
|
- returnVals: expectedOutput,
|
|
|
+ returnVals: want,
|
|
|
}
|
|
|
mock2 := &mockSampler{
|
|
|
id: 2,
|
|
|
callOrder: &callOrder,
|
|
|
- returnVals: expectedOutput,
|
|
|
+ returnVals: want,
|
|
|
}
|
|
|
mock3 := &mockSampler{
|
|
|
id: 3,
|
|
|
callOrder: &callOrder,
|
|
|
- returnVals: expectedOutput,
|
|
|
+ returnVals: want,
|
|
|
}
|
|
|
|
|
|
- result, err := Sample(input, mock1, mock2, mock3)
|
|
|
+ got, err := Sample(input, mock1, mock2, mock3)
|
|
|
if err != nil {
|
|
|
t.Fatal(err)
|
|
|
}
|
|
|
|
|
|
if !slices.Equal(callOrder, []int{1, 2, 3}) {
|
|
|
- t.Errorf("Expected call order [1,2,3], got %v", callOrder)
|
|
|
+ t.Errorf("got %v, want %v", callOrder, []int{1, 2, 3})
|
|
|
}
|
|
|
|
|
|
- if !floats.Equal(result, expectedOutput) {
|
|
|
- t.Errorf("Expected output %v, got %v", expectedOutput, result)
|
|
|
+ if !floats.Equal(got, want) {
|
|
|
+ t.Errorf("got %v, want %v", got, want)
|
|
|
}
|
|
|
|
|
|
errMock := &mockSampler{
|