// samplers.go
  1. package sample
  2. import (
  3. "errors"
  4. "math"
  5. "math/rand/v2"
  6. "slices"
  7. "sync"
  8. "github.com/ollama/ollama/llama"
  9. )
// token represents information about a single token during sampling.
type token struct {
	id    int32   // The token's unique identifier (its index in the logits slice)
	value float32 // The raw logit or probability from the model
}
// Sampler selects a token id from model logits according to the configured
// sampling parameters and an optional grammar constraint.
type Sampler struct {
	rng         *rand.Rand // seeded random source; nil means the global math/rand/v2 source is used
	topK        int        // parameter passed to the topK filter
	topP        float32    // parameter passed to the topP (nucleus) filter
	minP        float32    // parameter passed to the minP filter
	temperature float32    // logit scaling factor; 0 selects greedy (argmax) decoding
	grammar     *Grammar   // optional grammar constraint; may be nil
}
  23. func (s *Sampler) Sample(logits []float32) (int32, error) {
  24. if len(logits) == 0 {
  25. return -1, errors.New("sample: no logits provided to sample")
  26. }
  27. tokens := make([]token, len(logits))
  28. for i := range logits {
  29. tokens[i].id = int32(i)
  30. tokens[i].value = logits[i]
  31. }
  32. t, err := s.sample(tokens)
  33. if err != nil {
  34. return -1, err
  35. }
  36. if s.grammar != nil {
  37. // optimization: first check if the max logit is accepted by the grammar
  38. // if the max logit is rejected, apply the grammar to all logits (slower)
  39. top := []token{t}
  40. s.grammar.Apply(top)
  41. if !math.IsInf(float64(top[0].value), -1) {
  42. s.grammar.Accept(top[0].id)
  43. return top[0].id, nil
  44. }
  45. // since .sample has side effects of modifying the tokens
  46. // we need to reset them before applying the grammar and
  47. // sampling again
  48. for i := range logits {
  49. tokens[i].id = int32(i)
  50. tokens[i].value = logits[i]
  51. }
  52. s.grammar.Apply(tokens)
  53. t, err = s.sample(tokens)
  54. if err != nil {
  55. return -1, err
  56. }
  57. s.grammar.Accept(t.id)
  58. }
  59. return t.id, nil
  60. }
  61. // greedy returns the highest probability token from the tokens
  62. func greedy(tokens []token) token {
  63. max := tokens[0]
  64. for i := 1; i < len(tokens); i++ {
  65. if tokens[i].value > max.value {
  66. max = tokens[i]
  67. }
  68. }
  69. return max
  70. }
  71. // sample returns the highest probability token from the tokens
  72. // given sampler parameters. It also has side effects of modifying the tokens
  73. func (s *Sampler) sample(tokens []token) (token, error) {
  74. if s.temperature == 0 {
  75. return greedy(tokens), nil
  76. }
  77. // topK also sorts the tokens in descending order of logits
  78. tokens = topK(tokens, s.topK)
  79. // scale and normalize the tokens in place
  80. temperature(tokens, s.temperature)
  81. softmax(tokens)
  82. tokens = topP(tokens, s.topP)
  83. tokens = minP(tokens, s.minP)
  84. var r float32
  85. if s.rng != nil {
  86. r = s.rng.Float32()
  87. } else {
  88. r = rand.Float32()
  89. }
  90. // Calculate cumulative sum of probabilities
  91. var sum float32
  92. for i := range tokens {
  93. sum += tokens[i].value
  94. tokens[i].value = sum
  95. }
  96. r *= tokens[len(tokens)-1].value
  97. idx, _ := slices.BinarySearchFunc(tokens, r, func(token token, target float32) int {
  98. if token.value < target {
  99. return -1
  100. }
  101. return 1
  102. })
  103. if math.IsNaN(float64(sum)) {
  104. return token{}, errors.New("sample: logits sum to NaN, check model output")
  105. }
  106. return tokens[idx], nil
  107. }
  108. // TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
  109. func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
  110. var rng *rand.Rand
  111. if seed != -1 {
  112. // PCG requires two parameters: sequence and stream
  113. // Use original seed for sequence
  114. sequence := uint64(seed)
  115. // Use golden ratio hash to generate statistically independent seeds
  116. rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9))
  117. }
  118. if temperature < 0.0 {
  119. temperature = 0.0
  120. }
  121. if topP < 0.0 {
  122. topP = 0.0
  123. }
  124. if topP >= 1.0 {
  125. topP = 1.0
  126. }
  127. if minP < 0.0 {
  128. minP = 0.0
  129. }
  130. if minP >= 1.0 {
  131. minP = 1.0
  132. }
  133. return Sampler{
  134. rng: rng,
  135. topK: topK,
  136. topP: topP,
  137. minP: minP,
  138. temperature: temperature,
  139. grammar: grammar,
  140. }
  141. }
// Grammar constrains sampling to tokens permitted by a llama grammar sampler.
type Grammar struct {
	vocab   *Vocab         // lazily-loaded vocabulary used to build the sampler
	grammar string         // grammar definition text passed to llama.NewGrammarSampler
	sampler *llama.Sampler // underlying grammar sampler
}
  147. func NewGrammar(vocab *Vocab, grammar string) (*Grammar, error) {
  148. v, err := vocab.Load()
  149. if err != nil {
  150. return nil, err
  151. }
  152. return &Grammar{
  153. vocab: vocab,
  154. grammar: grammar,
  155. sampler: llama.NewGrammarSampler(v, grammar),
  156. }, nil
  157. }
  158. func (g *Grammar) Apply(tokens []token) {
  159. tds := make([]llama.TokenData, len(tokens))
  160. for i, token := range tokens {
  161. tds[i].Id = token.id
  162. tds[i].Logit = token.value
  163. }
  164. g.sampler.Apply(tds)
  165. for i := range tokens {
  166. tokens[i].value = tds[i].Logit
  167. }
  168. }
// Accept informs the underlying grammar sampler that token was selected.
func (g *Grammar) Accept(token int32) {
	g.sampler.Accept(token)
}
// Vocab lazily loads a llama vocabulary from a file, caching the result
// (or the load error) after the first use.
type Vocab struct {
	once  sync.Once    // guards the one-time load performed in Load
	vocab *llama.Vocab // cached vocabulary; nil until a successful load
	err   error        // sticky error from a failed load
	path  string       // filesystem path the vocabulary is read from
}
// NewVocab returns a Vocab that loads from path on first use of Load.
func NewVocab(path string) *Vocab {
	return &Vocab{path: path}
}
  181. // Load returns the lazily-loaded vocabulary
  182. func (v *Vocab) Load() (*llama.Vocab, error) {
  183. v.once.Do(func() {
  184. vocab, err := llama.LoadVocabFromFile(v.path)
  185. if err != nil {
  186. v.err = err
  187. return
  188. }
  189. v.vocab = vocab
  190. })
  191. return v.vocab, v.err
  192. }