浏览代码

sample: do all sorting in topK

ParthSareen 1 月之前
父节点
当前提交
4aeb67ef4c
共有 3 个文件被更改,包括 35 次插入25 次删除
  1. 2 5
      sample/samplers.go
  2. 11 16
      sample/transforms.go
  3. 22 4
      sample/transforms_test.go

+ 2 - 5
sample/samplers.go

@@ -84,11 +84,8 @@ func (s *Sampler) sample(tokens []token) (token, error) {
 		return greedy(tokens), nil
 	}
 
-	if s.topK > 0 {
-		tokens = topK(tokens, s.topK)
-	} else {
-		sortLogits(tokens)
-	}
+	// topK also sorts the tokens in descending order of logits
+	tokens = topK(tokens, s.topK)
 
 	// token logit values are updated to probabilities
 	tokens = temperature(tokens, s.temperature)

+ 11 - 16
sample/transforms.go

@@ -53,8 +53,17 @@ func temperature(ts []token, temp float32) []token {
 
 // topK limits the number of tokens considered to the k highest logits
 func topK(ts []token, k int) []token {
-	if k >= len(ts) {
-		sortLogits(ts)
+	if k >= len(ts) || k <= 0 {
+		slices.SortFunc(ts, func(a, b token) int {
+			switch {
+			case a.value < b.value:
+				return 1
+			case a.value > b.value:
+				return -1
+			default:
+				return 0
+			}
+		})
 		return ts
 	}
 
@@ -125,17 +134,3 @@ func minP(ts []token, p float32) []token {
 	ts = validTokens
 	return ts
 }
-
-// sortLogits sorts the tokens in descending order of logits
-func sortLogits(ts []token) {
-	slices.SortFunc(ts, func(a, b token) int {
-		switch {
-		case a.value < b.value:
-			return 1
-		case a.value > b.value:
-			return -1
-		default:
-			return 0
-		}
-	})
-}

+ 22 - 4
sample/transforms_test.go

@@ -59,7 +59,7 @@ func TestTemperatureAndSoftmax(t *testing.T) {
 func TestTopK(t *testing.T) {
 	input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
 
-	// Test k=3
+	// Test k=5
 	got := topK(toTokens(input), 5)
 	if len(got) != 5 {
 		t.Errorf("topK(5): wrong length: want 5, got %d", len(got))
@@ -72,6 +72,24 @@ func TestTopK(t *testing.T) {
 	if len(got) != len(input) {
 		t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(got))
 	}
+
+	// Test k=-1
+	input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
+	want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
+	got = topK(toTokens(input), -1)
+	if len(got) != len(input) {
+		t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got))
+	}
+	compareLogits(t, "topK(-1)", want, got)
+
+	// Test k=0
+	input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
+	want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
+	got = topK(toTokens(input), 0)
+	if len(got) != len(input) {
+		t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got))
+	}
+	compareLogits(t, "topK(-1)", want, got)
 }
 
 func TestTopP(t *testing.T) {
@@ -80,7 +98,7 @@ func TestTopP(t *testing.T) {
 
 	// First apply temperature and softmax to get probabilities
 	tokens = temperature(tokens, 1)
-	sortLogits(tokens)
+	tokens = topK(tokens, 20)
 
 	// Then apply topP
 	got := topP(tokens, 0.95)
@@ -112,7 +130,7 @@ func TestSortLogits(t *testing.T) {
 	input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
 	tokens := toTokens(input)
 
-	sortLogits(tokens)
+	tokens = topK(tokens, 20)
 
 	for i := 1; i < len(tokens); i++ {
 		if tokens[i].value > tokens[i-1].value {
@@ -173,7 +191,7 @@ func BenchmarkTransforms(b *testing.B) {
 		b.ResetTimer()
 		for b.Loop() {
 			copy(tokensCopy, tokens)
-			sortLogits(tokensCopy)
+			topK(tokensCopy, 200000)
 		}
 	})
 }