Prechádzať zdrojové kódy

sample: add error handling for empty logits (#9740)

Parth Sareen 1 mesiac pred
rodič
commit
42a14f7f63
3 zmenil súbory, kde vykonal 97 pridanie a 29 odobranie
  1. 7 7
      sample/samplers.go
  2. 24 0
      sample/samplers_test.go
  3. 66 22
      sample/transforms_test.go

+ 7 - 7
sample/samplers.go

@@ -26,6 +26,10 @@ type Sampler struct {
 }
 
 func (s *Sampler) Sample(logits []float32) (int32, error) {
+	if len(logits) == 0 {
+		return -1, errors.New("sample: no logits provided to sample")
+	}
+
 	tokens := make([]token, len(logits))
 	for i := range logits {
 		tokens[i].id = int32(i)
@@ -94,13 +98,6 @@ func (s *Sampler) sample(tokens []token) (token, error) {
 	tokens = topP(tokens, s.topP)
 	tokens = minP(tokens, s.minP)
 
-	// TODO: this should fall back to greedy sampling
-	// or topP, topK values etc should be such that
-	// there are always tokens to sample from
-	if len(tokens) == 0 {
-		return token{}, errors.New("no tokens to sample from")
-	}
-
 	var r float32
 	if s.rng != nil {
 		r = s.rng.Float32()
@@ -123,6 +120,9 @@ func (s *Sampler) sample(tokens []token) (token, error) {
 		return 1
 	})
 
+	if math.IsNaN(float64(sum)) {
+		return token{}, errors.New("sample: logits sum to NaN, check model output")
+	}
 	return tokens[idx], nil
 }
 

+ 24 - 0
sample/samplers_test.go

@@ -1,6 +1,7 @@
 package sample
 
 import (
+	"math"
 	"math/rand/v2"
 	"testing"
 )
@@ -29,6 +30,29 @@ func TestWeighted(t *testing.T) {
 	if want != got {
 		t.Errorf("index mismatch: want %d, got %d", want, got)
 	}
+
+	// Test very high p
+	logits = []float32{1.0, 0.9999999999999999, 0.5, 0.1}
+	// Use extremely small topP to filter out all tokens
+	sampler = NewSampler(1.0, 0, 1e-10, 0, 0, nil)
+	got, err = sampler.Sample(logits)
+	if err != nil {
+		t.Error(err)
+		return
+	}
+	// Should get the token with the highest logit
+	want = int32(0)
+	if want != got {
+		t.Errorf("index mismatch: want %d, got %d", want, got)
+	}
+
+	logits = []float32{float32(math.NaN()), float32(math.NaN()), float32(math.NaN())}
+	sampler = NewSampler(1, 0, 0.95, 0.05, 0, nil)
+	got, err = sampler.Sample(logits)
+	if err == nil {
+		t.Errorf("expected error, got %d", got)
+		return
+	}
 }
 
 func BenchmarkSample(b *testing.B) {

+ 66 - 22
sample/transforms_test.go

@@ -168,27 +168,53 @@ func TestTopP(t *testing.T) {
 	softmax(tokens)
 	tokens = topK(tokens, 20)
 
-	// Then apply topP
-	tokens = topP(tokens, 0.95)
+	// Test with very high p value
+	got := topP(tokens, 1.0)
 
-	// Should keep tokens until cumsum > 0.95
-	if len(tokens) > 3 {
+	// Should keep all tokens since p is 1
+	if len(got) != len(input) {
+		t.Errorf("topP(1.0): should keep all tokens, got %d, want %d", len(got), len(input))
+	}
+
+	// Test with normal p value
+	got = topP(tokens, 0.95)
+
+	if len(got) > 3 {
 		t.Errorf("topP(0.95): kept too many tokens: got %d", len(tokens))
-		t.Logf("got: %v", tokens)
+		t.Logf("got: %v", got)
 	}
 
 	// Test edge case - ensure at least one token remains
-	input = []float32{-1e6, -1e6, -1e6} // One dominant token
+	input = []float32{-1e6, -1e6, -1e7}
 	tokens = toTokens(input)
+	tokens = topK(tokens, 20)
 	softmax(tokens)
-	tokens = topP(tokens, 0.0) // Very small p
-	if len(tokens) < 1 {
+	got = topP(tokens, 0.0)
+	if len(got) < 1 {
 		t.Error("topP should keep at least one token")
 	}
+
+	// Test with zero p value
+	got = topP(tokens, 0.0)
+
+	// Should keep only the highest probability token
+	if len(got) != 1 {
+		t.Errorf("topP(0.0): should keep only one token, got %d", len(got))
+		t.Logf("got: %v", got)
+	}
+
+	tokens = toTokens(input)
+	tokens = topK(tokens, 20)
+	softmax(tokens)
+	got = topP(tokens, 1e-10)
+	if len(got) == 0 {
+		t.Errorf("topP(1e-10): should keep at least one token, got %d", len(got))
+		t.Logf("got: %v", got)
+	}
 }
 
 func TestMinP(t *testing.T) {
-	input := []float32{-3, -2, -1, 0, 1, 2, 4, 3}
+	input := []float32{-2, 0, -1, -3, 2, 1, 4, 3}
 	tokens := toTokens(input)
 
 	// First apply temperature and softmax
@@ -225,30 +251,48 @@ func TestMinP(t *testing.T) {
 		t.Logf("got: %v", tokens)
 	}
 
+	// Test with single token
+	tokens = toTokens(input[:1])
+	tokens = topK(tokens, 20)
+	softmax(tokens)
+	tokens = minP(tokens, 0.1)
+
+	// Should keep only the highest probability token
+	if len(tokens) != 1 {
+		t.Errorf("minP(0.1): should return single token, got %d", len(tokens))
+		t.Logf("got: %v", tokens)
+	}
+
 	input = []float32{1e-10, 1e-10, 1e-10}
 	tokens = toTokens(input)
 	softmax(tokens)
 	tokens = minP(tokens, 1.0)
 	if len(tokens) < 1 {
 		t.Error("minP should keep at least one token even with extreme probabilities")
-	}
-}
+		got := minP(tokens, 1.0)
 
-func TestSortLogits(t *testing.T) {
-	input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
-	tokens := toTokens(input)
+		if len(got) != 1 {
+			t.Errorf("minP(1.0): should keep all tokens, got %d, want %d", len(got), len(tokens))
+		}
 
-	tokens = topK(tokens, 20)
+		// Test with normal p value
+		got = minP(tokens, 0.2)
 
-	for i := 1; i < len(tokens); i++ {
-		if tokens[i].value > tokens[i-1].value {
-			t.Errorf("sortLogits: tokens not sorted in descending order at index %d: %f > %f",
-				i, tokens[i].value, tokens[i-1].value)
+		// Should keep tokens with prob >= 0.2 * max_prob
+		if len(got) > 3 {
+			t.Errorf("minP(0.2): kept too many tokens: got %d", len(got))
+			t.Logf("got: %v", got)
 		}
-	}
 
-	want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
-	compareLogits(t, "sortLogits", want, tokens)
+		// Test with zero p value
+		got = minP(tokens, 0.0)
+
+		// Should keep only the highest probability token
+		if len(got) != len(tokens) {
+			t.Errorf("minP(0.0): should keep only one token, got %d", len(got))
+			t.Logf("got: %v", got)
+		}
+	}
 }
 
 func BenchmarkTransforms(b *testing.B) {