transforms.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. package sample
  2. import (
  3. "container/heap"
  4. "math"
  5. "slices"
  6. )
  7. // tokenHeap implements heap.Interface and holds tokens as a min-heap to track k largest elements
  8. type tokenHeap []token
  9. func (h tokenHeap) Len() int { return len(h) }
  10. func (h tokenHeap) Less(i, j int) bool { return h[i].value < h[j].value } // Use < for min-heap to track largest elements
  11. func (h tokenHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
  12. func (h *tokenHeap) Push(x any) {
  13. *h = append(*h, x.(token))
  14. }
  15. func (h *tokenHeap) Pop() any {
  16. old := *h
  17. n := len(old)
  18. x := old[n-1]
  19. *h = old[0 : n-1]
  20. return x
  21. }
  22. // temperature applies scaling and softmax to the logits
  23. func temperature(ts []token, temp float32) []token {
  24. // Find max logit for numerical stability
  25. maxLogit := float32(math.Inf(-1))
  26. for _, t := range ts {
  27. if t.value > maxLogit {
  28. maxLogit = t.value
  29. }
  30. }
  31. // Apply temperature and compute exp(x - max)
  32. temp = max(temp, 1e-7)
  33. var sum float32
  34. for i, v := range ts {
  35. ts[i].value = float32(math.Exp(float64((v.value - maxLogit) / temp)))
  36. sum += ts[i].value
  37. }
  38. // Normalize
  39. for i := range ts {
  40. ts[i].value /= sum
  41. }
  42. return ts
  43. }
  44. // topK limits the number of tokens considered to the k highest logits
  45. func topK(ts []token, k int) []token {
  46. if k >= len(ts) {
  47. sortLogits(ts)
  48. return ts
  49. }
  50. // Initialize min-heap with first k elements
  51. h := make(tokenHeap, k)
  52. copy(h, ts[:k])
  53. heap.Init(&h)
  54. // Process remaining elements
  55. for i := k; i < len(ts); i++ {
  56. if ts[i].value > h[0].value {
  57. heap.Pop(&h)
  58. heap.Push(&h, ts[i])
  59. }
  60. }
  61. // Convert heap to sorted slice in descending order
  62. result := make([]token, k)
  63. for i := k - 1; i >= 0; i-- {
  64. result[i] = heap.Pop(&h).(token)
  65. }
  66. return result
  67. }
  68. // topP limits tokens to those with cumulative probability p
  69. func topP(ts []token, p float32) []token {
  70. if p == 1.0 {
  71. return ts
  72. }
  73. // Find cutoff index where cumulative sum exceeds p
  74. var sum float32
  75. for i, t := range ts {
  76. sum += t.value
  77. if sum > float32(p) {
  78. ts = ts[:i+1]
  79. return ts
  80. }
  81. }
  82. return ts
  83. }
  84. // minP limits tokens to those with cumulative probability p
  85. func minP(ts []token, p float32) []token {
  86. if p == 1.0 {
  87. return ts
  88. }
  89. maxProb := float32(math.Inf(-1))
  90. for _, token := range ts {
  91. if token.value > maxProb {
  92. maxProb = token.value
  93. }
  94. }
  95. threshold := maxProb * float32(p)
  96. // Filter tokens in-place
  97. validTokens := ts[:0]
  98. for i, token := range ts {
  99. if token.value >= threshold {
  100. validTokens = append(validTokens, ts[i])
  101. }
  102. }
  103. ts = validTokens
  104. return ts
  105. }
  106. // partialSortLogits uses quickselect to efficiently find and sort the top n tokens
  107. func partialSortLogits(ts []token, n int) []token {
  108. if n >= len(ts) {
  109. n = len(ts)
  110. }
  111. left, right := 0, len(ts)-1
  112. target := n - 1
  113. // Quickselect algorithm to partition array around pivot
  114. for left < right {
  115. // Choose middle element as pivot and move it to the end
  116. pivot := left + (right-left)/2
  117. ts[pivot], ts[right] = ts[right], ts[pivot]
  118. // storeIndex tracks where to put next element greater than pivot
  119. storeIndex := left
  120. pivotValue := ts[right].value
  121. // Partition array into elements >= pivot and < pivot
  122. // Elements >= pivot go to the left side
  123. for i := left; i < right; i++ {
  124. if ts[i].value >= pivotValue {
  125. ts[storeIndex], ts[i] = ts[i], ts[storeIndex]
  126. storeIndex++
  127. }
  128. }
  129. // Move pivot to its final position
  130. ts[right], ts[storeIndex] = ts[storeIndex], ts[right]
  131. // If pivot is at target position, we're done
  132. // Otherwise recursively partition the half containing target
  133. if storeIndex == target {
  134. break
  135. } else if storeIndex < target {
  136. left = storeIndex + 1 // Target is in right half
  137. } else {
  138. right = storeIndex - 1 // Target is in left half
  139. }
  140. }
  141. // Sort just the top n elements in descending order
  142. slices.SortFunc(ts[:n], func(a, b token) int {
  143. if a.value > b.value {
  144. return -1
  145. }
  146. if a.value < b.value {
  147. return 1
  148. }
  149. return 0
  150. })
  151. return ts[:n]
  152. }
  153. // sortLogits uses partialSortLogits to efficiently sort tokens
  154. // It sorts approximately sqrt(len(tokens)) elements which balances
  155. // between having enough tokens for sampling while avoiding full sort
  156. func sortLogits(ts []token) {
  157. // Use sqrt of token length as a heuristic for partial sort size
  158. // This provides a good balance between performance and having enough tokens
  159. n := int(math.Sqrt(float64(len(ts)))) + 1
  160. // Ensure we have at least 100 tokens and at most 1000
  161. switch {
  162. case n < 100:
  163. n = 100
  164. case n > 1000:
  165. n = 1000
  166. }
  167. partialSortLogits(ts, n)
  168. }