transforms.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. package sample
  2. import (
  3. "math"
  4. "slices"
  5. )
  6. // temperature applies scaling and softmax to the logits
  7. func temperature(ts []token, temp float32) []token {
  8. // Find max logit for numerical stability
  9. maxLogit := float32(math.Inf(-1))
  10. for _, t := range ts {
  11. if t.value > maxLogit {
  12. maxLogit = t.value
  13. }
  14. }
  15. // Apply temperature and compute exp(x - max)
  16. temp = max(temp, 1e-7)
  17. var sum float32
  18. for i, v := range ts {
  19. ts[i].value = float32(math.Exp(float64((v.value - maxLogit) / temp)))
  20. sum += ts[i].value
  21. }
  22. // Normalize
  23. for i := range ts {
  24. ts[i].value /= sum
  25. }
  26. return ts
  27. }
  28. // siftDown maintains a min-heap property by recursively moving larger elements down the heap.
  29. //
  30. // The heap is represented as an array where for any node at index i:
  31. // - Left child is at index 2i + 1
  32. // - Right child is at index 2i + 2
  33. // - Parent is at index (i-1)/2
  34. //
  35. // The function compares a node with its children and:
  36. // 1. Finds the smallest value between the node and its children
  37. // 2. If the node is not the smallest, swaps it with its smallest child
  38. // 3. Continues this process down the affected path until the min-heap property is restored
  39. func siftDown(data []token, start, end int) {
  40. root := start
  41. for {
  42. child := 2*root + 1
  43. if child >= end {
  44. break
  45. }
  46. // Find smaller child (we want min heap)
  47. if child+1 < end && data[child+1].value < data[child].value {
  48. child++
  49. }
  50. // Exit if root is already smaller than children
  51. if data[root].value <= data[child].value {
  52. break
  53. }
  54. // Swap with smaller child and continue
  55. data[root], data[child] = data[child], data[root]
  56. root = child
  57. }
  58. }
  59. // topK limits the number of tokens considered to the k highest logits
  60. func topK(ts []token, k int) []token {
  61. if k >= len(ts) {
  62. return ts
  63. }
  64. // Heapify + siftDown - O(nlog(k))
  65. // Build min-heap of first k elements
  66. heap := ts[:k]
  67. for i := k/2 - 1; i >= 0; i-- {
  68. siftDown(heap, i, k)
  69. }
  70. // Process remaining elements - if larger than heap root, replace root
  71. for i := k; i < len(ts); i++ {
  72. if ts[i].value > heap[0].value {
  73. heap[0] = ts[i]
  74. siftDown(heap, 0, k)
  75. }
  76. }
  77. slices.Reverse(heap)
  78. ts = heap
  79. return ts
  80. }
  81. // topP limits tokens to those with cumulative probability p
  82. func topP(ts []token, p float32) []token {
  83. if p == 1.0 {
  84. return ts
  85. }
  86. // Find cutoff index where cumulative sum exceeds p
  87. var sum float32
  88. for i, t := range ts {
  89. sum += t.value
  90. if sum > float32(p) {
  91. ts = ts[:i+1]
  92. return ts
  93. }
  94. }
  95. return ts
  96. }
  97. // minP limits tokens to those with cumulative probability p
  98. func minP(ts []token, p float32) []token {
  99. if p == 1.0 {
  100. return ts
  101. }
  102. maxProb := float32(math.Inf(-1))
  103. for _, token := range ts {
  104. if token.value > maxProb {
  105. maxProb = token.value
  106. }
  107. }
  108. threshold := maxProb * float32(p)
  109. // Filter tokens in-place
  110. validTokens := ts[:0]
  111. for i, token := range ts {
  112. if token.value >= threshold {
  113. validTokens = append(validTokens, ts[i])
  114. }
  115. }
  116. ts = validTokens
  117. return ts
  118. }
  119. // TODO(parthsareen): possibly replace with simpler implementation https://github.com/ollama/ollama/issues/9584
  120. // sortLogits sorts implementation to sort tokens by logits using counting sort
  121. // counting sort is faster than built-in sort for this use case
  122. func sortLogits(tokens []token) {
  123. if len(tokens) <= 1 {
  124. return
  125. }
  126. // Find max/min in a single pass
  127. minLogit, maxLogit := tokens[0].value, tokens[0].value
  128. for _, t := range tokens[1:] {
  129. if t.value < minLogit {
  130. minLogit = t.value
  131. } else if t.value > maxLogit {
  132. maxLogit = t.value
  133. }
  134. }
  135. // Calculate scaling to map to uint32 range
  136. logitRange := maxLogit - minLogit
  137. if logitRange < 1e-6 {
  138. return // All values effectively equal
  139. }
  140. // Count frequencies directly from tokens
  141. const maxInt = (1 << 24) - 1 // Use 24 bits for good granularity
  142. var counts [256]int // For first byte
  143. // First pass: count frequencies
  144. for _, t := range tokens {
  145. // Map to [0, maxInt] range
  146. score := min(uint32((t.value-minLogit)*float32(maxInt)/logitRange), maxInt)
  147. counts[score>>16]++
  148. }
  149. // Calculate offsets
  150. var offset int
  151. for i := range counts {
  152. count := counts[i]
  153. counts[i] = offset
  154. offset += count
  155. }
  156. // Second pass: place elements in correct position
  157. output := make([]token, len(tokens))
  158. // Track current positions
  159. countsCopy := counts
  160. for i, t := range tokens {
  161. score := min(uint32((t.value-minLogit)*float32(maxInt)/logitRange), maxInt)
  162. pos := countsCopy[score>>16]
  163. countsCopy[score>>16]++
  164. output[len(tokens)-1-pos] = tokens[i]
  165. }
  166. copy(tokens, output)
  167. }