Browse Source

improve temperature sampler

ParthSareen 3 tháng trước
mục cha
commit
7cd9fbbbb1
2 tập tin đã thay đổi với 56 bổ sung và 32 xóa
  1. 30 10
      sample/sample.go
  2. 26 22
      sample/sample_test.go

+ 30 - 10
sample/sample.go

@@ -16,16 +16,19 @@ type Sampler interface {
 
 type Temperature float64
 
-func (s Temperature) Sample(logits []float64) ([]float64, error) {
-	if s < 0 || s > 1 {
-		return nil, errors.New("temperature must be between 0 and 1")
+func (t Temperature) Sample(logits []float64) ([]float64, error) { // rescales logits by temperature t; errors when t is outside [0, 2]
+	if t < 0 || t > 2 {
+		return nil, errors.New("temperature must be between 0 and 2")
 	}
 
-	// greedy sampling
-	if s == 0 {
-		return []float64{floats.Max(logits)}, nil
+	// Subtract the max logit from every entry (the largest becomes 0) to avoid under/overflow.
+	maxLogit := floats.Max(logits)
+
+	temp := math.Max(float64(t), 1e-7) // floor the divisor at 1e-7 so t == 0 does not divide by zero
+	for i := range logits { // overwrites the caller's slice in place
+		logits[i] = (logits[i] - maxLogit) / temp
 	}
-	floats.Scale(1.0/float64(s), logits)
+
 	return logits, nil
 }
 
@@ -47,10 +50,8 @@ func computeSoftmax(logits []float64) ([]float64, error) {
 	}
 
 	floatSum := floats.Sum(copiedLogits)
-	if floatSum == 0 {
-		return nil, errors.New("no valid tokens found")
-	}
 	floats.Scale(1.0/floatSum, copiedLogits)
+
 	return copiedLogits, nil
 }
 
@@ -175,9 +176,28 @@ func (s weighed) Sample(logits []float64) ([]float64, error) {
 	return nil, errors.New("weighed sampler failed")
 }
 
+// TODO: remove after next PR merge. greedy is a field-less Sampler that reduces the logits to the index of the largest one.
+type greedy struct{}
+
+func Greedy() Sampler { // Greedy returns a greedy sampler value.
+	return greedy{}
+}
+
+func (greedy) Sample(logits []float64) ([]float64, error) {
+	return []float64{float64(floats.MaxIdx(logits))}, nil // one-element result: the index (not the value) of the max logit, as a float64
+}
+
 func Sample(logits []float64, samplers ...Sampler) ([]float64, error) {
 	var err error
 	for _, sampler := range samplers {
+		if sampler == Temperature(0) {
+			// early return with greedy if temperature is 0
+			logits, err = Greedy().Sample(logits)
+			if err != nil {
+				return nil, err
+			}
+			return logits, nil
+		}
 		logits, err = sampler.Sample(logits)
 		if err != nil {
 			return nil, err

+ 26 - 22
sample/sample_test.go

@@ -19,17 +19,7 @@ func TestTemperature(t *testing.T) {
 	if err != nil {
 		t.Fatal(err)
 	}
-	want := []float64{-6, -4, -2, 0, 2, 4, 8}
-	if !floats.Equal(logits, want) {
-		t.Fatalf("got: %v, want: %v", logits, want)
-	}
-
-	// Only expect the max value returned
-	logits, err = Temperature(0).Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
-	if err != nil {
-		t.Fatal(err)
-	}
-	want = []float64{4}
+	want := []float64{-14, -12, -10, -8, -6, -4, 0}
 	if !floats.Equal(logits, want) {
 		t.Fatalf("got: %v, want: %v", logits, want)
 	}
@@ -37,6 +27,9 @@ func TestTemperature(t *testing.T) {
 	if _, err := Temperature(-1).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil {
 		t.Fatalf("expected error for temperature=-1, got %v", logits)
 	}
+	if _, err := Temperature(2.1).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil {
+		t.Fatalf("expected error for temperature=2.1, got %v", logits)
+	}
 }
 
 func TestSoftmax(t *testing.T) {
@@ -124,17 +117,17 @@ func TestSample(t *testing.T) {
 	want := []float64{1, 2, 3, 4}
 
 	var callOrder []int
-	mock1 := &mockSampler{
+	mock1 := &testSampler{
 		id:         1,
 		callOrder:  &callOrder,
 		returnVals: want,
 	}
-	mock2 := &mockSampler{
+	mock2 := &testSampler{
 		id:         2,
 		callOrder:  &callOrder,
 		returnVals: want,
 	}
-	mock3 := &mockSampler{
+	mock3 := &testSampler{
 		id:         3,
 		callOrder:  &callOrder,
 		returnVals: want,
@@ -153,7 +146,7 @@ func TestSample(t *testing.T) {
 		t.Errorf("got %v, want %v", got, want)
 	}
 
-	errMock := &mockSampler{
+	errMock := &testSampler{
 		returnErr: fmt.Errorf("mock error"),
 	}
 	_, err = Sample(input, mock1, errMock, mock2)
@@ -162,19 +155,30 @@ func TestSample(t *testing.T) {
 	}
 }
 
-type mockSampler struct {
+type testSampler struct { // test double for Sampler with canned results and optional call recording
 	id         int // value appended to *callOrder on each Sample call
 	callOrder  *[]int // shared call log; nil disables recording
 	returnVals []float64 // returned by Sample when returnErr is nil
 	returnErr  error // when non-nil, Sample returns (nil, returnErr)
 }
 
-func (m *mockSampler) Sample(logits []float64) ([]float64, error) {
-	if m.callOrder != nil {
-		*m.callOrder = append(*m.callOrder, m.id)
+func (ts *testSampler) Sample(logits []float64) ([]float64, error) { // ignores logits; records the call, then returns the canned error or values
+	if ts.callOrder != nil {
+		*ts.callOrder = append(*ts.callOrder, ts.id)
 	}
-	if m.returnErr != nil {
-		return nil, m.returnErr
+	if ts.returnErr != nil {
+		return nil, ts.returnErr
+	}
+	return ts.returnVals, nil
+}
+
+func TestSampleTemperatureZero(t *testing.T) { // checks that Sample with Temperature(0) yields the greedy result
+	logits, err := Sample([]float64{1, 2, 3, 4}, Temperature(0))
+	if err != nil {
+		t.Fatal(err)
+	}
+	want := []float64{3} // index of the largest logit (4 sits at position 3), not its value
+	if !floats.Equal(logits, want) {
+		t.Fatalf("got: %v, want: %v", logits, want)
 	}
-	return m.returnVals, nil
 }