transforms.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. package sample
  2. import (
  3. "math"
  4. "slices"
  5. )
  6. func softmax(ts []logit) []logit {
  7. var sum float32
  8. for i, v := range ts {
  9. ts[i].value = float32(math.Exp(float64(v.value)))
  10. sum += ts[i].value
  11. }
  12. for i := range ts {
  13. ts[i].value /= sum
  14. }
  15. return ts
  16. }
  17. func temperature(ti []logit, t float32) []logit {
  18. if t == 1 {
  19. return ti
  20. }
  21. temp := max(t, 1e-7)
  22. maxLogit := float32(math.Inf(-1))
  23. for _, token := range ti {
  24. if token.value > maxLogit {
  25. maxLogit = token.value
  26. }
  27. }
  28. // subtracting max logit to avoid under/overflow
  29. for i := range ti {
  30. ti[i].value = (ti[i].value - maxLogit) / temp
  31. }
  32. return ti
  33. }
  34. // siftDown maintains a min-heap property by recursively moving larger elements down the heap.
  35. //
  36. // The heap is represented as an array where for any node at index i:
  37. // - Left child is at index 2i + 1
  38. // - Right child is at index 2i + 2
  39. // - Parent is at index (i-1)/2
  40. //
  41. // The function compares a node with its children and:
  42. // 1. Finds the smallest value between the node and its children
  43. // 2. If the node is not the smallest, swaps it with its smallest child
  44. // 3. Continues this process down the affected path until the min-heap property is restored
  45. func siftDown(data []logit, start, end int) {
  46. root := start
  47. for {
  48. child := 2*root + 1
  49. if child >= end {
  50. break
  51. }
  52. // Find smaller child (we want min heap)
  53. if child+1 < end && data[child+1].value < data[child].value {
  54. child++
  55. }
  56. // Exit if root is already smaller than children
  57. if data[root].value <= data[child].value {
  58. break
  59. }
  60. // Swap with smaller child and continue
  61. data[root], data[child] = data[child], data[root]
  62. root = child
  63. }
  64. }
  65. // topK limits the number of tokens considered to the k highest logits
  66. func topK(ts []logit, k int) []logit {
  67. if k >= len(ts) {
  68. return ts
  69. }
  70. // Heapify + siftDown - O(nlog(k))
  71. // Build min-heap of first k elements
  72. heap := ts[:k]
  73. for i := k/2 - 1; i >= 0; i-- {
  74. siftDown(heap, i, k)
  75. }
  76. // Process remaining elements - if larger than heap root, replace root
  77. for i := k; i < len(ts); i++ {
  78. if ts[i].value > heap[0].value {
  79. heap[0] = ts[i]
  80. siftDown(heap, 0, k)
  81. }
  82. }
  83. slices.Reverse(heap)
  84. ts = heap
  85. return ts
  86. }
  87. // topP limits tokens to those with cumulative probability p
  88. func topP(ts []logit, p float32) []logit {
  89. if p == 1.0 {
  90. return ts
  91. }
  92. // Find cutoff index where cumulative sum exceeds p
  93. var sum float32
  94. for i, t := range ts {
  95. sum += t.value
  96. if sum > float32(p) {
  97. ts = ts[:i+1]
  98. return ts
  99. }
  100. }
  101. return ts
  102. }
  103. // minP limits tokens to those with cumulative probability p
  104. func minP(ts []logit, p float32) []logit {
  105. if p == 1.0 {
  106. return ts
  107. }
  108. maxProb := float32(math.Inf(-1))
  109. for _, token := range ts {
  110. if token.value > maxProb {
  111. maxProb = token.value
  112. }
  113. }
  114. threshold := maxProb * float32(p)
  115. // Filter tokens in-place
  116. validTokens := ts[:0]
  117. for i, token := range ts {
  118. if token.value >= threshold {
  119. validTokens = append(validTokens, ts[i])
  120. }
  121. }
  122. ts = validTokens
  123. return ts
  124. }
  125. // TODO(parthsareen): possibly replace with simpler implementation https://github.com/ollama/ollama/issues/9584
  126. // Conting sort implementation to sort tokens by logits
  127. func sortLogits(tokens []logit) {
  128. if len(tokens) <= 1 {
  129. return
  130. }
  131. // Find max/min in a single pass
  132. minLogit, maxLogit := tokens[0].value, tokens[0].value
  133. for _, t := range tokens[1:] {
  134. if t.value < minLogit {
  135. minLogit = t.value
  136. } else if t.value > maxLogit {
  137. maxLogit = t.value
  138. }
  139. }
  140. // Calculate scaling to map to uint32 range
  141. logitRange := maxLogit - minLogit
  142. if logitRange < 1e-6 {
  143. return // All values effectively equal
  144. }
  145. // Count frequencies directly from tokens
  146. const maxInt = (1 << 24) - 1 // Use 24 bits for good granularity
  147. var counts [256]int // For first byte
  148. // First pass: count frequencies
  149. for _, t := range tokens {
  150. // Map to [0, maxInt] range
  151. score := min(uint32((t.value-minLogit)*float32(maxInt)/logitRange), maxInt)
  152. counts[score>>16]++
  153. }
  154. // Calculate offsets
  155. var offset int
  156. for i := range counts {
  157. count := counts[i]
  158. counts[i] = offset
  159. offset += count
  160. }
  161. // Second pass: place elements in correct position
  162. output := make([]logit, len(tokens))
  163. // Track current positions
  164. countsCopy := counts
  165. for i, t := range tokens {
  166. score := min(uint32((t.value-minLogit)*float32(maxInt)/logitRange), maxInt)
  167. pos := countsCopy[score>>16]
  168. countsCopy[score>>16]++
  169. output[len(tokens)-1-pos] = tokens[i]
  170. }
  171. copy(tokens, output)
  172. }