transforms.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. package sample
  2. import (
  3. "container/heap"
  4. "math"
  5. "slices"
  6. )
  7. // tokenHeap implements heap.Interface and holds tokens as a min-heap to track k largest elements
  8. type tokenHeap []token
  9. func (h tokenHeap) Len() int { return len(h) }
  10. func (h tokenHeap) Less(i, j int) bool { return h[i].value < h[j].value }
  11. func (h tokenHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
  12. func (h *tokenHeap) Push(x any) {
  13. *h = append(*h, x.(token))
  14. }
  15. func (h *tokenHeap) Pop() any {
  16. old := *h
  17. n := len(old)
  18. x := old[n-1]
  19. *h = old[0 : n-1]
  20. return x
  21. }
  22. // temperature applies scaling to the logits
  23. func temperature(ts []token, temp float32) []token {
  24. // Ensure temperature clipping near 0 to avoid numerical instability
  25. temp = max(temp, 1e-7)
  26. for i := range ts {
  27. ts[i].value = ts[i].value / temp
  28. }
  29. return ts
  30. }
  31. // softmax applies normalization to the logits
  32. func softmax(ts []token) []token {
  33. // Find max logit for numerical stability
  34. maxLogit := float32(math.Inf(-1))
  35. for _, t := range ts {
  36. if t.value > maxLogit {
  37. maxLogit = t.value
  38. }
  39. }
  40. // Compute exp(x - max)
  41. var sum float32
  42. for i, v := range ts {
  43. ts[i].value = float32(math.Exp(float64(v.value - maxLogit)))
  44. sum += ts[i].value
  45. }
  46. // exp(x - max) / sum(exp(x - max))
  47. for i := range ts {
  48. ts[i].value /= sum
  49. }
  50. return ts
  51. }
  52. // topK limits the number of tokens considered to the k highest logits
  53. func topK(ts []token, k int) []token {
  54. if k >= len(ts) || k <= 0 {
  55. slices.SortFunc(ts, func(a, b token) int {
  56. switch {
  57. case a.value < b.value:
  58. return 1
  59. case a.value > b.value:
  60. return -1
  61. default:
  62. return 0
  63. }
  64. })
  65. return ts
  66. }
  67. // Initialize min-heap with first k elements
  68. h := make(tokenHeap, k)
  69. copy(h, ts[:k])
  70. heap.Init(&h)
  71. // Process remaining elements
  72. for i := k; i < len(ts); i++ {
  73. if ts[i].value > h[0].value {
  74. heap.Pop(&h)
  75. heap.Push(&h, ts[i])
  76. }
  77. }
  78. // Convert heap to sorted slice in descending order
  79. result := make([]token, len(h))
  80. for i := k - 1; i >= 0; i-- {
  81. result[i] = heap.Pop(&h).(token)
  82. }
  83. return result
  84. }
  85. // topP limits tokens to those with cumulative probability p
  86. func topP(ts []token, p float32) []token {
  87. if p == 1.0 {
  88. return ts
  89. }
  90. // Find cutoff index where cumulative sum exceeds p
  91. var sum float32
  92. for i, t := range ts {
  93. sum += t.value
  94. if sum > float32(p) {
  95. ts = ts[:i+1]
  96. return ts
  97. }
  98. }
  99. return ts
  100. }
  101. // minP limits tokens to those with cumulative probability p
  102. func minP(ts []token, p float32) []token {
  103. if p == 1.0 {
  104. return ts
  105. }
  106. maxProb := float32(math.Inf(-1))
  107. for _, token := range ts {
  108. if token.value > maxProb {
  109. maxProb = token.value
  110. }
  111. }
  112. threshold := maxProb * float32(p)
  113. // Filter tokens in-place
  114. validTokens := ts[:0]
  115. for i, token := range ts {
  116. if token.value >= threshold {
  117. validTokens = append(validTokens, ts[i])
  118. }
  119. }
  120. ts = validTokens
  121. return ts
  122. }