|
@@ -19,17 +19,7 @@ func TestTemperature(t *testing.T) {
|
|
if err != nil {
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
- want := []float64{-6, -4, -2, 0, 2, 4, 8}
|
|
|
|
- if !floats.Equal(logits, want) {
|
|
|
|
- t.Fatalf("got: %v, want: %v", logits, want)
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- // Only expect the max value returned
|
|
|
|
- logits, err = Temperature(0).Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
|
|
|
|
- if err != nil {
|
|
|
|
- t.Fatal(err)
|
|
|
|
- }
|
|
|
|
- want = []float64{4}
|
|
|
|
|
|
+ want := []float64{-14, -12, -10, -8, -6, -4, 0}
|
|
if !floats.Equal(logits, want) {
|
|
if !floats.Equal(logits, want) {
|
|
t.Fatalf("got: %v, want: %v", logits, want)
|
|
t.Fatalf("got: %v, want: %v", logits, want)
|
|
}
|
|
}
|
|
@@ -37,6 +27,9 @@ func TestTemperature(t *testing.T) {
|
|
if _, err := Temperature(-1).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil {
|
|
if _, err := Temperature(-1).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil {
|
|
t.Fatalf("expected error for temperature=-1, got %v", logits)
|
|
t.Fatalf("expected error for temperature=-1, got %v", logits)
|
|
}
|
|
}
|
|
|
|
+ if _, err := Temperature(2.1).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil {
|
|
|
|
+ t.Fatalf("expected error for temperature=2.1, got %v", logits)
|
|
|
|
+ }
|
|
}
|
|
}
|
|
|
|
|
|
func TestSoftmax(t *testing.T) {
|
|
func TestSoftmax(t *testing.T) {
|
|
@@ -124,17 +117,17 @@ func TestSample(t *testing.T) {
|
|
want := []float64{1, 2, 3, 4}
|
|
want := []float64{1, 2, 3, 4}
|
|
|
|
|
|
var callOrder []int
|
|
var callOrder []int
|
|
- mock1 := &mockSampler{
|
|
|
|
|
|
+ mock1 := &testSampler{
|
|
id: 1,
|
|
id: 1,
|
|
callOrder: &callOrder,
|
|
callOrder: &callOrder,
|
|
returnVals: want,
|
|
returnVals: want,
|
|
}
|
|
}
|
|
- mock2 := &mockSampler{
|
|
|
|
|
|
+ mock2 := &testSampler{
|
|
id: 2,
|
|
id: 2,
|
|
callOrder: &callOrder,
|
|
callOrder: &callOrder,
|
|
returnVals: want,
|
|
returnVals: want,
|
|
}
|
|
}
|
|
- mock3 := &mockSampler{
|
|
|
|
|
|
+ mock3 := &testSampler{
|
|
id: 3,
|
|
id: 3,
|
|
callOrder: &callOrder,
|
|
callOrder: &callOrder,
|
|
returnVals: want,
|
|
returnVals: want,
|
|
@@ -153,7 +146,7 @@ func TestSample(t *testing.T) {
|
|
t.Errorf("got %v, want %v", got, want)
|
|
t.Errorf("got %v, want %v", got, want)
|
|
}
|
|
}
|
|
|
|
|
|
- errMock := &mockSampler{
|
|
|
|
|
|
+ errMock := &testSampler{
|
|
returnErr: fmt.Errorf("mock error"),
|
|
returnErr: fmt.Errorf("mock error"),
|
|
}
|
|
}
|
|
_, err = Sample(input, mock1, errMock, mock2)
|
|
_, err = Sample(input, mock1, errMock, mock2)
|
|
@@ -162,19 +155,30 @@ func TestSample(t *testing.T) {
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
-type mockSampler struct {
|
|
|
|
|
|
+type testSampler struct {
|
|
id int
|
|
id int
|
|
callOrder *[]int
|
|
callOrder *[]int
|
|
returnVals []float64
|
|
returnVals []float64
|
|
returnErr error
|
|
returnErr error
|
|
}
|
|
}
|
|
|
|
|
|
-func (m *mockSampler) Sample(logits []float64) ([]float64, error) {
|
|
|
|
- if m.callOrder != nil {
|
|
|
|
- *m.callOrder = append(*m.callOrder, m.id)
|
|
|
|
|
|
+func (ts *testSampler) Sample(logits []float64) ([]float64, error) {
|
|
|
|
+ if ts.callOrder != nil {
|
|
|
|
+ *ts.callOrder = append(*ts.callOrder, ts.id)
|
|
}
|
|
}
|
|
- if m.returnErr != nil {
|
|
|
|
- return nil, m.returnErr
|
|
|
|
|
|
+ if ts.returnErr != nil {
|
|
|
|
+ return nil, ts.returnErr
|
|
|
|
+ }
|
|
|
|
+ return ts.returnVals, nil
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func TestSampleTemperatureZero(t *testing.T) {
|
|
|
|
+ logits, err := Sample([]float64{1, 2, 3, 4}, Temperature(0))
|
|
|
|
+ if err != nil {
|
|
|
|
+ t.Fatal(err)
|
|
|
|
+ }
|
|
|
|
+ want := []float64{3}
|
|
|
|
+ if !floats.Equal(logits, want) {
|
|
|
|
+ t.Fatalf("got: %v, want: %v", logits, want)
|
|
}
|
|
}
|
|
- return m.returnVals, nil
|
|
|
|
}
|
|
}
|