  1. package sample
  2. import (
  3. "errors"
  4. "math"
  5. "math/rand/v2"
  6. "slices"
  7. "sync"
  8. "github.com/ollama/ollama/llama"
  9. )
// token represents information about a single token during sampling
type token struct {
	id    int32   // the token's unique identifier (its index into the logits slice)
	value float32 // the raw logit from the model, later overwritten with a probability by the sampling pipeline
}
// Sampler selects a token id from model logits using the configured
// top-k, top-p, min-p, and temperature parameters, optionally constrained
// by a grammar.
type Sampler struct {
	rng         *rand.Rand // seeded source of randomness; nil means use the package-global generator
	topK        int
	topP        float32
	minP        float32
	temperature float32 // 0 selects greedy (argmax) decoding
	grammar     *Grammar // optional; restricts sampling to grammar-accepted tokens
}
  23. func (s *Sampler) Sample(logits []float32) (int32, error) {
  24. tokens := make([]token, len(logits))
  25. for i := range logits {
  26. tokens[i].id = int32(i)
  27. tokens[i].value = logits[i]
  28. }
  29. t, err := s.sample(tokens)
  30. if err != nil {
  31. return -1, err
  32. }
  33. if s.grammar != nil {
  34. // optimization: first check if the max logit is accepted by the grammar
  35. // if the max logit is rejected, apply the grammar to all logits (slower)
  36. top := []token{t}
  37. s.grammar.Apply(top)
  38. if !math.IsInf(float64(top[0].value), -1) {
  39. s.grammar.Accept(top[0].id)
  40. return top[0].id, nil
  41. }
  42. // since .sample has side effects of modifying the tokens
  43. // we need to reset them before applying the grammar and
  44. // sampling again
  45. for i := range logits {
  46. tokens[i].id = int32(i)
  47. tokens[i].value = logits[i]
  48. }
  49. s.grammar.Apply(tokens)
  50. t, err = s.sample(tokens)
  51. if err != nil {
  52. return -1, err
  53. }
  54. s.grammar.Accept(t.id)
  55. }
  56. return t.id, nil
  57. }
  58. // greedy returns the highest probability token from the tokens
  59. func greedy(tokens []token) token {
  60. max := tokens[0]
  61. for i := 1; i < len(tokens); i++ {
  62. if tokens[i].value > max.value {
  63. max = tokens[i]
  64. }
  65. }
  66. return max
  67. }
  68. // sample returns the highest probability token from the tokens
  69. // given sampler parameters. It also has side effects of modifying the tokens
  70. func (s *Sampler) sample(tokens []token) (token, error) {
  71. if s.temperature == 0 {
  72. return greedy(tokens), nil
  73. }
  74. // topK also sorts the tokens in descending order of logits
  75. tokens = topK(tokens, s.topK)
  76. // token logit values are updated to probabilities
  77. tokens = temperature(tokens, s.temperature)
  78. tokens = topP(tokens, s.topP)
  79. tokens = minP(tokens, s.minP)
  80. // TODO: this should fall back to greedy sampling
  81. // or topP, topK values etc should be such that
  82. // there are always tokens to sample from
  83. if len(tokens) == 0 {
  84. return token{}, errors.New("no tokens to sample from")
  85. }
  86. var r float32
  87. if s.rng != nil {
  88. r = s.rng.Float32()
  89. } else {
  90. r = rand.Float32()
  91. }
  92. // Calculate cumulative sum of probabilities
  93. var sum float32
  94. for i := range tokens {
  95. sum += tokens[i].value
  96. tokens[i].value = sum
  97. }
  98. r *= tokens[len(tokens)-1].value
  99. idx, _ := slices.BinarySearchFunc(tokens, r, func(token token, target float32) int {
  100. if token.value < target {
  101. return -1
  102. }
  103. return 1
  104. })
  105. return tokens[idx], nil
  106. }
  107. // TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
  108. func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
  109. var rng *rand.Rand
  110. if seed != -1 {
  111. // PCG requires two parameters: sequence and stream
  112. // Use original seed for sequence
  113. sequence := uint64(seed)
  114. // Use golden ratio hash to generate statistically independent seeds
  115. rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9))
  116. }
  117. if temperature < 0.0 {
  118. temperature = 0.0
  119. }
  120. if topP < 0.0 {
  121. topP = 0.0
  122. }
  123. if topP >= 1.0 {
  124. topP = 1.0
  125. }
  126. if minP < 0.0 {
  127. minP = 0.0
  128. }
  129. if minP >= 1.0 {
  130. minP = 1.0
  131. }
  132. return Sampler{
  133. rng: rng,
  134. topK: topK,
  135. topP: topP,
  136. minP: minP,
  137. temperature: temperature,
  138. grammar: grammar,
  139. }
  140. }
// Grammar constrains sampling to tokens permitted by a llama.cpp grammar.
type Grammar struct {
	vocab   *Vocab // lazily-loaded vocabulary the grammar sampler was built from
	grammar string // the grammar definition text passed to llama.NewGrammarSampler
	sampler *llama.Sampler
}
  146. func NewGrammar(vocab *Vocab, grammar string) (*Grammar, error) {
  147. v, err := vocab.Load()
  148. if err != nil {
  149. return nil, err
  150. }
  151. return &Grammar{
  152. vocab: vocab,
  153. grammar: grammar,
  154. sampler: llama.NewGrammarSampler(v, grammar),
  155. }, nil
  156. }
  157. func (g *Grammar) Apply(tokens []token) {
  158. tds := make([]llama.TokenData, len(tokens))
  159. for i, token := range tokens {
  160. tds[i].Id = token.id
  161. tds[i].Logit = token.value
  162. }
  163. g.sampler.Apply(tds)
  164. for i := range tokens {
  165. tokens[i].value = tds[i].Logit
  166. }
  167. }
// Accept advances the grammar state past the given token id, so that
// subsequent Apply calls reflect the tokens already emitted.
func (g *Grammar) Accept(token int32) {
	g.sampler.Accept(token)
}
// Vocab lazily loads a llama vocabulary from a file, at most once.
// The zero value is not usable; construct with NewVocab.
type Vocab struct {
	once  sync.Once    // guards the one-time load performed in Load
	vocab *llama.Vocab // the loaded vocabulary; nil if loading failed
	err   error        // sticky error from the load attempt, if any
	path  string       // filesystem path the vocabulary is read from
}
  177. func NewVocab(path string) *Vocab {
  178. return &Vocab{path: path}
  179. }
  180. // Load returns the lazily-loaded vocabulary
  181. func (v *Vocab) Load() (*llama.Vocab, error) {
  182. v.once.Do(func() {
  183. vocab, err := llama.LoadVocabFromFile(v.path)
  184. if err != nil {
  185. v.err = err
  186. return
  187. }
  188. v.vocab = vocab
  189. })
  190. return v.vocab, v.err
  191. }