transforms.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. package sample
  2. import (
  3. "container/heap"
  4. "math"
  5. "math/rand"
  6. "slices"
  7. )
  8. // tokenHeap implements heap.Interface and holds tokens as a min-heap to track k largest elements
  9. type tokenHeap []token
  10. func (h tokenHeap) Len() int { return len(h) }
  11. func (h tokenHeap) Less(i, j int) bool { return h[i].value < h[j].value }
  12. func (h tokenHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
  13. func (h *tokenHeap) Push(x any) {
  14. *h = append(*h, x.(token))
  15. }
  16. func (h *tokenHeap) Pop() any {
  17. old := *h
  18. n := len(old)
  19. x := old[n-1]
  20. *h = old[0 : n-1]
  21. return x
  22. }
  23. // topK limits the number of tokens considered to the k highest logits
  24. func topK(ts []token, k int) []token {
  25. if k >= len(ts) || k <= 0 {
  26. slices.SortFunc(ts, func(a, b token) int {
  27. switch {
  28. case a.value < b.value:
  29. return 1
  30. case a.value > b.value:
  31. return -1
  32. default:
  33. return 0
  34. }
  35. })
  36. return ts
  37. }
  38. // Initialize min-heap with first k elements
  39. h := make(tokenHeap, k)
  40. copy(h, ts[:k])
  41. heap.Init(&h)
  42. // Process remaining elements
  43. for i := k; i < len(ts); i++ {
  44. if ts[i].value > h[0].value {
  45. heap.Pop(&h)
  46. heap.Push(&h, ts[i])
  47. }
  48. }
  49. // Convert heap to sorted slice in descending order
  50. result := make([]token, len(h))
  51. for i := k - 1; i >= 0; i-- {
  52. result[i] = heap.Pop(&h).(token)
  53. }
  54. return result
  55. }
  56. // topP limits tokens to those with cumulative probability p
  57. func topP(ts []token, p float32) []token {
  58. if p == 1.0 {
  59. return ts
  60. }
  61. // Find cutoff index where cumulative sum exceeds p
  62. var sum float32
  63. for i, t := range ts {
  64. sum += t.value
  65. if sum > float32(p) {
  66. ts = ts[:i+1]
  67. return ts
  68. }
  69. }
  70. return ts
  71. }
  72. // minP limits tokens to those with cumulative probability p
  73. func minP(ts []token, p float32) []token {
  74. if p == 1.0 {
  75. return ts
  76. }
  77. maxProb := float32(math.Inf(-1))
  78. for _, token := range ts {
  79. if token.value > maxProb {
  80. maxProb = token.value
  81. }
  82. }
  83. threshold := maxProb * float32(p)
  84. // Filter tokens in-place
  85. validTokens := ts[:0]
  86. for i, token := range ts {
  87. if token.value >= threshold {
  88. validTokens = append(validTokens, ts[i])
  89. }
  90. }
  91. ts = validTokens
  92. return ts
  93. }
  94. func temperature(ts []token, temp float32) {
  95. for i := range ts {
  96. ts[i].value /= temp
  97. }
  98. }
  99. func softmax(ts []token) {
  100. if len(ts) == 0 {
  101. return
  102. }
  103. // Find max logit for numerical stability
  104. maxLogit := ts[0].value
  105. for _, t := range ts {
  106. if t.value > maxLogit {
  107. maxLogit = t.value
  108. }
  109. }
  110. // Compute exp(logit - maxLogit) and sum them
  111. var sumExp float32
  112. for i, t := range ts {
  113. expVal := float32(math.Exp(float64(t.value - maxLogit)))
  114. ts[i].value = expVal
  115. sumExp += expVal
  116. }
  117. // Normalize probabilities
  118. for i := range ts {
  119. ts[i].value /= sumExp
  120. }
  121. }
  122. // applyDist selects a token based on probabilities and seed
  123. func dist(ts []token, seed int64) int {
  124. rng := rand.New(rand.NewSource(seed))
  125. cdf := make([]float32, len(ts))
  126. var cumSum float32
  127. for i, t := range ts {
  128. cumSum += t.value
  129. cdf[i] = cumSum
  130. }
  131. r := rng.Float32() * cumSum
  132. // Select token based on CDF
  133. for i, probSum := range cdf {
  134. if r < probSum {
  135. return i
  136. }
  137. }
  138. return len(ts) - 1
  139. }