|
@@ -168,27 +168,53 @@ func TestTopP(t *testing.T) {
|
|
|
softmax(tokens)
|
|
|
tokens = topK(tokens, 20)
|
|
|
|
|
|
- // Then apply topP
|
|
|
- tokens = topP(tokens, 0.95)
|
|
|
+ // Test with very high p value
|
|
|
+ got := topP(tokens, 1.0)
|
|
|
|
|
|
- // Should keep tokens until cumsum > 0.95
|
|
|
- if len(tokens) > 3 {
|
|
|
+ // Should keep all tokens since p is 1
|
|
|
+ if len(got) != len(input) {
|
|
|
+ t.Errorf("topP(1.0): should keep all tokens, got %d, want %d", len(got), len(input))
|
|
|
+ }
|
|
|
+
|
|
|
+ // Test with normal p value
|
|
|
+ got = topP(tokens, 0.95)
|
|
|
+
|
|
|
+ if len(got) > 3 {
|
|
|
t.Errorf("topP(0.95): kept too many tokens: got %d", len(tokens))
|
|
|
- t.Logf("got: %v", tokens)
|
|
|
+ t.Logf("got: %v", got)
|
|
|
}
|
|
|
|
|
|
// Test edge case - ensure at least one token remains
|
|
|
- input = []float32{-1e6, -1e6, -1e6} // One dominant token
|
|
|
+ input = []float32{-1e6, -1e6, -1e7}
|
|
|
tokens = toTokens(input)
|
|
|
+ tokens = topK(tokens, 20)
|
|
|
softmax(tokens)
|
|
|
- tokens = topP(tokens, 0.0) // Very small p
|
|
|
- if len(tokens) < 1 {
|
|
|
+ got = topP(tokens, 0.0)
|
|
|
+ if len(got) < 1 {
|
|
|
t.Error("topP should keep at least one token")
|
|
|
}
|
|
|
+
|
|
|
+ // Test with zero p value
|
|
|
+ got = topP(tokens, 0.0)
|
|
|
+
|
|
|
+ // Should keep only the highest probability token
|
|
|
+ if len(got) != 1 {
|
|
|
+ t.Errorf("topP(0.0): should keep only one token, got %d", len(got))
|
|
|
+ t.Logf("got: %v", got)
|
|
|
+ }
|
|
|
+
|
|
|
+ tokens = toTokens(input)
|
|
|
+ tokens = topK(tokens, 20)
|
|
|
+ softmax(tokens)
|
|
|
+ got = topP(tokens, 1e-10)
|
|
|
+ if len(got) == 0 {
|
|
|
+ t.Errorf("topP(1e-10): should keep at least one token, got %d", len(got))
|
|
|
+ t.Logf("got: %v", got)
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
func TestMinP(t *testing.T) {
|
|
|
- input := []float32{-3, -2, -1, 0, 1, 2, 4, 3}
|
|
|
+ input := []float32{-2, 0, -1, -3, 2, 1, 4, 3}
|
|
|
tokens := toTokens(input)
|
|
|
|
|
|
// First apply temperature and softmax
|
|
@@ -225,30 +251,48 @@ func TestMinP(t *testing.T) {
|
|
|
t.Logf("got: %v", tokens)
|
|
|
}
|
|
|
|
|
|
+ // Test with single token
|
|
|
+ tokens = toTokens(input[:1])
|
|
|
+ tokens = topK(tokens, 20)
|
|
|
+ softmax(tokens)
|
|
|
+ tokens = minP(tokens, 0.1)
|
|
|
+
|
|
|
+ // Should keep only the highest probability token
|
|
|
+ if len(tokens) != 1 {
|
|
|
+ t.Errorf("minP(0.1): should return single token, got %d", len(tokens))
|
|
|
+ t.Logf("got: %v", tokens)
|
|
|
+ }
|
|
|
+
|
|
|
input = []float32{1e-10, 1e-10, 1e-10}
|
|
|
tokens = toTokens(input)
|
|
|
softmax(tokens)
|
|
|
tokens = minP(tokens, 1.0)
|
|
|
if len(tokens) < 1 {
|
|
|
t.Error("minP should keep at least one token even with extreme probabilities")
|
|
|
- }
|
|
|
-}
|
|
|
+ got := minP(tokens, 1.0)
|
|
|
|
|
|
-func TestSortLogits(t *testing.T) {
|
|
|
- input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
|
|
|
- tokens := toTokens(input)
|
|
|
+ if len(got) != 1 {
|
|
|
+ t.Errorf("minP(1.0): should keep all tokens, got %d, want %d", len(got), len(tokens))
|
|
|
+ }
|
|
|
|
|
|
- tokens = topK(tokens, 20)
|
|
|
+ // Test with normal p value
|
|
|
+ got = minP(tokens, 0.2)
|
|
|
|
|
|
- for i := 1; i < len(tokens); i++ {
|
|
|
- if tokens[i].value > tokens[i-1].value {
|
|
|
- t.Errorf("sortLogits: tokens not sorted in descending order at index %d: %f > %f",
|
|
|
- i, tokens[i].value, tokens[i-1].value)
|
|
|
+ // Should keep tokens with prob >= 0.2 * max_prob
|
|
|
+ if len(got) > 3 {
|
|
|
+ t.Errorf("minP(0.2): kept too many tokens: got %d", len(got))
|
|
|
+ t.Logf("got: %v", got)
|
|
|
}
|
|
|
- }
|
|
|
|
|
|
- want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
|
|
|
- compareLogits(t, "sortLogits", want, tokens)
|
|
|
+ // Test with zero p value
|
|
|
+ got = minP(tokens, 0.0)
|
|
|
+
|
|
|
+ // Should keep all tokens since p is 0
|
|
|
+ if len(got) != len(tokens) {
|
|
|
+ t.Errorf("minP(0.0): should keep only one token, got %d", len(got))
|
|
|
+ t.Logf("got: %v", got)
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
func BenchmarkTransforms(b *testing.B) {
|