transforms.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233
  1. package sample
  2. import (
  3. "container/heap"
  4. "math"
  5. "math/rand"
  6. "slices"
  7. )
  8. // tokenHeap implements heap.Interface and holds tokens as a min-heap to track k largest elements
  9. type tokenHeap []token
  10. func (h tokenHeap) Len() int { return len(h) }
  11. func (h tokenHeap) Less(i, j int) bool { return h[i].value < h[j].value } // Use < for min-heap to track largest elements
  12. func (h tokenHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
  13. func (h *tokenHeap) Push(x any) {
  14. *h = append(*h, x.(token))
  15. }
  16. func (h *tokenHeap) Pop() any {
  17. old := *h
  18. n := len(old)
  19. x := old[n-1]
  20. *h = old[0 : n-1]
  21. return x
  22. }
  23. // topK limits the number of tokens considered to the k highest logits
  24. func topK(ts []token, k int) []token {
  25. if k >= len(ts) {
  26. sortLogits(ts)
  27. return ts
  28. }
  29. // Initialize min-heap with first k elements
  30. h := make(tokenHeap, k)
  31. copy(h, ts[:k])
  32. heap.Init(&h)
  33. // Process remaining elements
  34. for i := k; i < len(ts); i++ {
  35. if ts[i].value > h[0].value {
  36. heap.Pop(&h)
  37. heap.Push(&h, ts[i])
  38. }
  39. }
  40. // Convert heap to sorted slice in descending order
  41. result := make([]token, k)
  42. for i := k - 1; i >= 0; i-- {
  43. result[i] = heap.Pop(&h).(token)
  44. }
  45. return result
  46. }
  47. // topP limits tokens to those with cumulative probability p
  48. func topP(ts []token, p float32) []token {
  49. if p == 1.0 {
  50. return ts
  51. }
  52. // Find cutoff index where cumulative sum exceeds p
  53. var sum float32
  54. for i, t := range ts {
  55. sum += t.value
  56. if sum > float32(p) {
  57. ts = ts[:i+1]
  58. return ts
  59. }
  60. }
  61. return ts
  62. }
  63. // minP limits tokens to those with cumulative probability p
  64. func minP(ts []token, p float32) []token {
  65. if p == 1.0 {
  66. return ts
  67. }
  68. maxProb := float32(math.Inf(-1))
  69. for _, token := range ts {
  70. if token.value > maxProb {
  71. maxProb = token.value
  72. }
  73. }
  74. threshold := maxProb * float32(p)
  75. // Filter tokens in-place
  76. validTokens := ts[:0]
  77. for i, token := range ts {
  78. if token.value >= threshold {
  79. validTokens = append(validTokens, ts[i])
  80. }
  81. }
  82. ts = validTokens
  83. return ts
  84. }
  85. // partialSortLogits uses quickselect to efficiently find and sort the top n tokens
  86. func partialSortLogits(ts []token, n int) []token {
  87. if n >= len(ts) {
  88. n = len(ts)
  89. }
  90. left, right := 0, len(ts)-1
  91. target := n - 1
  92. // Quickselect algorithm to partition array around pivot
  93. for left < right {
  94. // Choose middle element as pivot and move it to the end
  95. pivot := left + (right-left)/2
  96. ts[pivot], ts[right] = ts[right], ts[pivot]
  97. // storeIndex tracks where to put next element greater than pivot
  98. storeIndex := left
  99. pivotValue := ts[right].value
  100. // Partition array into elements >= pivot and < pivot
  101. // Elements >= pivot go to the left side
  102. for i := left; i < right; i++ {
  103. if ts[i].value >= pivotValue {
  104. ts[storeIndex], ts[i] = ts[i], ts[storeIndex]
  105. storeIndex++
  106. }
  107. }
  108. // Move pivot to its final position
  109. ts[right], ts[storeIndex] = ts[storeIndex], ts[right]
  110. // If pivot is at target position, we're done
  111. // Otherwise recursively partition the half containing target
  112. if storeIndex == target {
  113. break
  114. } else if storeIndex < target {
  115. left = storeIndex + 1 // Target is in right half
  116. } else {
  117. right = storeIndex - 1 // Target is in left half
  118. }
  119. }
  120. // Sort just the top n elements in descending order
  121. slices.SortFunc(ts[:n], func(a, b token) int {
  122. if a.value > b.value {
  123. return -1
  124. }
  125. if a.value < b.value {
  126. return 1
  127. }
  128. return 0
  129. })
  130. return ts[:n]
  131. }
  132. // sortLogits uses partialSortLogits to efficiently sort tokens
  133. // It sorts approximately sqrt(len(tokens)) elements which balances
  134. // between having enough tokens for sampling while avoiding full sort
  135. func sortLogits(ts []token) {
  136. // Use sqrt of token length as a heuristic for partial sort size
  137. // This provides a good balance between performance and having enough tokens
  138. n := int(math.Sqrt(float64(len(ts)))) + 1
  139. // Ensure we have at least 100 tokens and at most 1000
  140. switch {
  141. case n < 100:
  142. n = 100
  143. case n > 1000:
  144. n = 1000
  145. }
  146. partialSortLogits(ts, n)
  147. }
  148. func temperature(ts []token, temp float32) {
  149. for i := range ts {
  150. ts[i].value /= temp
  151. }
  152. }
  153. func softmax(ts []token) {
  154. if len(ts) == 0 {
  155. return
  156. }
  157. // Find max logit for numerical stability
  158. maxLogit := ts[0].value
  159. for _, t := range ts {
  160. if t.value > maxLogit {
  161. maxLogit = t.value
  162. }
  163. }
  164. // Compute exp(logit - maxLogit) and sum them
  165. var sumExp float32
  166. for i, t := range ts {
  167. expVal := float32(math.Exp(float64(t.value - maxLogit)))
  168. ts[i].value = expVal
  169. sumExp += expVal
  170. }
  171. // Normalize probabilities
  172. for i := range ts {
  173. ts[i].value /= sumExp
  174. }
  175. }
  176. // applyDist selects a token based on probabilities and seed
  177. func dist(ts []token, seed int64) int {
  178. rng := rand.New(rand.NewSource(seed))
  179. cdf := make([]float32, len(ts))
  180. var cumSum float32
  181. for i, t := range ts {
  182. cumSum += t.value
  183. cdf[i] = cumSum
  184. }
  185. r := rng.Float32() * cumSum
  186. // Select token based on CDF
  187. for i, probSum := range cdf {
  188. if r < probSum {
  189. return i
  190. }
  191. }
  192. return len(ts) - 1
  193. }