|
@@ -10,7 +10,7 @@ import (
|
|
|
type tokenHeap []token
|
|
|
|
|
|
func (h tokenHeap) Len() int { return len(h) }
|
|
|
-func (h tokenHeap) Less(i, j int) bool { return h[i].value < h[j].value } // Use < for min-heap to track largest elements
|
|
|
+func (h tokenHeap) Less(i, j int) bool { return h[i].value < h[j].value }
|
|
|
func (h tokenHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
|
|
|
|
|
|
func (h *tokenHeap) Push(x any) {
|
|
@@ -72,7 +72,7 @@ func topK(ts []token, k int) []token {
|
|
|
}
|
|
|
|
|
|
// Convert heap to sorted slice in descending order
|
|
|
- result := make([]token, k)
|
|
|
+ result := make([]token, len(h))
|
|
|
for i := k - 1; i >= 0; i-- {
|
|
|
result[i] = heap.Pop(&h).(token)
|
|
|
}
|
|
@@ -126,77 +126,16 @@ func minP(ts []token, p float32) []token {
|
|
|
return ts
|
|
|
}
|
|
|
|
|
|
-// partialSortLogits uses quickselect to efficiently find and sort the top n tokens
|
|
|
-func partialSortLogits(ts []token, n int) []token {
|
|
|
- if n >= len(ts) {
|
|
|
- n = len(ts)
|
|
|
- }
|
|
|
-
|
|
|
- left, right := 0, len(ts)-1
|
|
|
- target := n - 1
|
|
|
-
|
|
|
- // Quickselect algorithm to partition array around pivot
|
|
|
- for left < right {
|
|
|
- // Choose middle element as pivot and move it to the end
|
|
|
- pivot := left + (right-left)/2
|
|
|
- ts[pivot], ts[right] = ts[right], ts[pivot]
|
|
|
-
|
|
|
- // storeIndex tracks where to put next element greater than pivot
|
|
|
- storeIndex := left
|
|
|
- pivotValue := ts[right].value
|
|
|
-
|
|
|
- // Partition array into elements >= pivot and < pivot
|
|
|
- // Elements >= pivot go to the left side
|
|
|
- for i := left; i < right; i++ {
|
|
|
- if ts[i].value >= pivotValue {
|
|
|
- ts[storeIndex], ts[i] = ts[i], ts[storeIndex]
|
|
|
- storeIndex++
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- // Move pivot to its final position
|
|
|
- ts[right], ts[storeIndex] = ts[storeIndex], ts[right]
|
|
|
-
|
|
|
- // If pivot is at target position, we're done
|
|
|
- // Otherwise recursively partition the half containing target
|
|
|
- if storeIndex == target {
|
|
|
- break
|
|
|
- } else if storeIndex < target {
|
|
|
- left = storeIndex + 1 // Target is in right half
|
|
|
- } else {
|
|
|
- right = storeIndex - 1 // Target is in left half
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- // Sort just the top n elements in descending order
|
|
|
- slices.SortFunc(ts[:n], func(a, b token) int {
|
|
|
- if a.value > b.value {
|
|
|
- return -1
|
|
|
- }
|
|
|
- if a.value < b.value {
|
|
|
+// sortLogits sorts the tokens in descending order of logits
|
|
|
+func sortLogits(ts []token) {
|
|
|
+ slices.SortFunc(ts, func(a, b token) int {
|
|
|
+ switch {
|
|
|
+ case a.value < b.value:
|
|
|
return 1
|
|
|
+ case a.value > b.value:
|
|
|
+ return -1
|
|
|
+ default:
|
|
|
+ return 0
|
|
|
}
|
|
|
- return 0
|
|
|
})
|
|
|
-
|
|
|
- return ts[:n]
|
|
|
-}
|
|
|
-
|
|
|
-// sortLogits uses partialSortLogits to efficiently sort tokens
|
|
|
-// It sorts approximately sqrt(len(tokens)) elements which balances
|
|
|
-// between having enough tokens for sampling while avoiding full sort
|
|
|
-func sortLogits(ts []token) {
|
|
|
- // Use sqrt of token length as a heuristic for partial sort size
|
|
|
- // This provides a good balance between performance and having enough tokens
|
|
|
- n := int(math.Sqrt(float64(len(ts)))) + 1
|
|
|
-
|
|
|
- // Ensure we have at least 100 tokens and at most 1000
|
|
|
- switch {
|
|
|
- case n < 100:
|
|
|
- n = 100
|
|
|
- case n > 1000:
|
|
|
- n = 1000
|
|
|
- }
|
|
|
-
|
|
|
- partialSortLogits(ts, n)
|
|
|
}
|