|
@@ -59,7 +59,7 @@ func TestTemperatureAndSoftmax(t *testing.T) {
|
|
|
func TestTopK(t *testing.T) {
|
|
|
input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
|
|
|
|
|
|
- // Test k=3
|
|
|
+ // Test k=5
|
|
|
got := topK(toTokens(input), 5)
|
|
|
if len(got) != 5 {
|
|
|
t.Errorf("topK(5): wrong length: want 5, got %d", len(got))
|
|
@@ -72,6 +72,24 @@ func TestTopK(t *testing.T) {
|
|
|
if len(got) != len(input) {
|
|
|
t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(got))
|
|
|
}
|
|
|
+
|
|
|
+ // Test k=-1
|
|
|
+ input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
|
|
|
+ want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
|
|
|
+ got = topK(toTokens(input), -1)
|
|
|
+ if len(got) != len(input) {
|
|
|
+ t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got))
|
|
|
+ }
|
|
|
+ compareLogits(t, "topK(-1)", want, got)
|
|
|
+
|
|
|
+ // Test k=0
|
|
|
+ input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
|
|
|
+ want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
|
|
|
+ got = topK(toTokens(input), 0)
|
|
|
+ if len(got) != len(input) {
|
|
|
+ t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got))
|
|
|
+ }
|
|
|
+ compareLogits(t, "topK(-1)", want, got)
|
|
|
}
|
|
|
|
|
|
func TestTopP(t *testing.T) {
|
|
@@ -80,7 +98,7 @@ func TestTopP(t *testing.T) {
|
|
|
|
|
|
// First apply temperature and softmax to get probabilities
|
|
|
tokens = temperature(tokens, 1)
|
|
|
- sortLogits(tokens)
|
|
|
+ tokens = topK(tokens, 20)
|
|
|
|
|
|
// Then apply topP
|
|
|
got := topP(tokens, 0.95)
|
|
@@ -112,7 +130,7 @@ func TestSortLogits(t *testing.T) {
|
|
|
input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
|
|
|
tokens := toTokens(input)
|
|
|
|
|
|
- sortLogits(tokens)
|
|
|
+ tokens = topK(tokens, 20)
|
|
|
|
|
|
for i := 1; i < len(tokens); i++ {
|
|
|
if tokens[i].value > tokens[i-1].value {
|
|
@@ -173,7 +191,7 @@ func BenchmarkTransforms(b *testing.B) {
|
|
|
b.ResetTimer()
|
|
|
for b.Loop() {
|
|
|
copy(tokensCopy, tokens)
|
|
|
- sortLogits(tokensCopy)
|
|
|
+ topK(tokensCopy, 200000)
|
|
|
}
|
|
|
})
|
|
|
}
|