Переглянути джерело

sample: use partial sort for sorting

ParthSareen 2 місяців тому
батько
коміт
310b235626
1 змінених файлів з 62 додано та 46 видалено
  1. 62 46
      sample/transforms.go

+ 62 - 46
sample/transforms.go

@@ -126,61 +126,77 @@ func minP(ts []token, p float32) []token {
 	return ts
 }
 
-// TODO(parthsareen): possibly replace with simpler implementation https://github.com/ollama/ollama/issues/9584
-// sortLogits sorts implementation to sort tokens by logits using counting sort
-// counting sort is faster than built-in sort for this use case
-func sortLogits(tokens []token) {
-	if len(tokens) <= 1 {
-		return
-	}
-
-	// Find max/min in a single pass
-	minLogit, maxLogit := tokens[0].value, tokens[0].value
-	for _, t := range tokens[1:] {
-		if t.value < minLogit {
-			minLogit = t.value
-		} else if t.value > maxLogit {
-			maxLogit = t.value
+// partialSortLogits uses quickselect to efficiently find and sort the top n tokens
+func partialSortLogits(ts []token, n int) []token {
+	if n >= len(ts) {
+		n = len(ts)
+	}
+
+	left, right := 0, len(ts)-1
+	target := n - 1
+
+	// Quickselect algorithm to partition array around pivot
+	for left < right {
+		// Choose middle element as pivot and move it to the end
+		pivot := left + (right-left)/2
+		ts[pivot], ts[right] = ts[right], ts[pivot]
+
+		// storeIndex tracks where to put next element greater than pivot
+		storeIndex := left
+		pivotValue := ts[right].value
+
+		// Partition array into elements >= pivot and < pivot
+		// Elements >= pivot go to the left side
+		for i := left; i < right; i++ {
+			if ts[i].value >= pivotValue {
+				ts[storeIndex], ts[i] = ts[i], ts[storeIndex]
+				storeIndex++
 		}
 	}
 
-	// Calculate scaling to map to uint32 range
-	logitRange := maxLogit - minLogit
-	if logitRange < 1e-6 {
-		return // All values effectively equal
-	}
-
-	// Count frequencies directly from tokens
-	const maxInt = (1 << 24) - 1 // Use 24 bits for good granularity
-	var counts [256]int          // For first byte
+		// Move pivot to its final position
+		ts[right], ts[storeIndex] = ts[storeIndex], ts[right]
 
-	// First pass: count frequencies
-	for _, t := range tokens {
-		// Map to [0, maxInt] range
-		score := min(uint32((t.value-minLogit)*float32(maxInt)/logitRange), maxInt)
-		counts[score>>16]++
+		// If pivot is at target position, we're done
+		// Otherwise recursively partition the half containing target
+		if storeIndex == target {
+			break
+		} else if storeIndex < target {
+			left = storeIndex + 1 // Target is in right half
+		} else {
+			right = storeIndex - 1 // Target is in left half
+		}
 	}
 
-	// Calculate offsets
-	var offset int
-	for i := range counts {
-		count := counts[i]
-		counts[i] = offset
-		offset += count
-	}
+	// Sort just the top n elements in descending order
+	slices.SortFunc(ts[:n], func(a, b token) int {
+		if a.value > b.value {
+			return -1
+		}
+		if a.value < b.value {
+			return 1
+		}
+		return 0
+	})
 
-	// Second pass: place elements in correct position
-	output := make([]token, len(tokens))
-	// Track current positions
-	countsCopy := counts
+	return ts[:n]
+	}
 
-	for i, t := range tokens {
-		score := min(uint32((t.value-minLogit)*float32(maxInt)/logitRange), maxInt)
+// sortLogits uses partialSortLogits to efficiently sort tokens
+// It sorts approximately sqrt(len(tokens)) elements which balances
+// between having enough tokens for sampling while avoiding full sort
+func sortLogits(ts []token) {
+	// Use sqrt of token length as a heuristic for partial sort size
+	// This provides a good balance between performance and having enough tokens
+	n := int(math.Sqrt(float64(len(ts)))) + 1
 
-		pos := countsCopy[score>>16]
-		countsCopy[score>>16]++
-		output[len(tokens)-1-pos] = tokens[i]
+	// Ensure we have at least 100 tokens and at most 1000
+	switch {
+	case n < 100:
+		n = 100
+	case n > 1000:
+		n = 1000
 	}
 
-	copy(tokens, output)
+	partialSortLogits(ts, n)
 }