// samplers.go
  1. package sample
  2. import (
  3. "errors"
  4. "math"
  5. "math/rand/v2"
  6. "slices"
  7. "sync"
  8. "github.com/ollama/ollama/llama"
  9. )
// token represents information about a single token during sampling
type token struct {
	id    int32   // The token's unique identifier (its index into the logits slice)
	value float32 // The raw logit or probability from the model; repurposed as a cumulative sum during sampling
}
// Sampler holds the configuration used to pick a token from model logits.
type Sampler struct {
	rng         *rand.Rand // seeded source; nil means use the global math/rand/v2 source
	topK        int        // keep only the K highest logits when > 0
	topP        float32    // nucleus sampling threshold in [0, 1]
	minP        float32    // minimum-probability cutoff in [0, 1]
	temperature float32    // 0 selects greedy (argmax) sampling
	grammar     *Grammar   // optional grammar constraint; nil disables it
}
// Sample selects a token id from the given logits according to the sampler's
// configuration. When a grammar is set, the result is additionally constrained
// to tokens the grammar accepts, and the grammar state is advanced with the
// chosen token. Returns -1 and an error if no token could be sampled.
func (s *Sampler) Sample(logits []float32) (int32, error) {
	tokens := make([]token, len(logits))
	for i := range logits {
		tokens[i].id = int32(i)
		tokens[i].value = logits[i]
	}

	t, err := s.sample(tokens)
	if err != nil {
		return -1, err
	}

	if s.grammar != nil {
		// optimization: first check if the max logit is accepted by the grammar
		// if the max logit is rejected, apply the grammar to all logits (slower)
		top := []token{t}
		s.grammar.Apply(top)
		// Apply sets rejected tokens' values to -Inf, so a finite value means
		// the sampled token is allowed by the grammar.
		if !math.IsInf(float64(top[0].value), -1) {
			s.grammar.Accept(top[0].id)
			return top[0].id, nil
		}

		// since .sample has side effects of modifying the tokens
		// we need to reset them before applying the grammar and
		// sampling again
		for i := range logits {
			tokens[i].id = int32(i)
			tokens[i].value = logits[i]
		}

		s.grammar.Apply(tokens)
		t, err = s.sample(tokens)
		if err != nil {
			return -1, err
		}
		// Advance grammar state with the token we are about to emit.
		s.grammar.Accept(t.id)
	}

	return t.id, nil
}
  58. // greedy returns the highest probability token from the tokens
  59. func greedy(tokens []token) token {
  60. max := tokens[0]
  61. for i := 1; i < len(tokens); i++ {
  62. if tokens[i].value > max.value {
  63. max = tokens[i]
  64. }
  65. }
  66. return max
  67. }
  68. // sample returns the highest probability token from the tokens
  69. // given sampler parameters. It also has side effects of modifying the tokens
  70. func (s *Sampler) sample(tokens []token) (token, error) {
  71. if s.temperature == 0 {
  72. return greedy(tokens), nil
  73. }
  74. if s.topK > 0 {
  75. tokens = topK(tokens, s.topK)
  76. } else {
  77. sortLogits(tokens)
  78. }
  79. tokens = temperature(tokens, s.temperature)
  80. tokens = softmax(tokens)
  81. tokens = topP(tokens, s.topP)
  82. tokens = minP(tokens, s.minP)
  83. // TODO: this should fall back to greedy sampling
  84. // or topP, topK values etc should be such that
  85. // there are always tokens to sample from
  86. if len(tokens) == 0 {
  87. return token{}, errors.New("no tokens to sample from")
  88. }
  89. var r float32
  90. if s.rng != nil {
  91. r = s.rng.Float32()
  92. } else {
  93. r = rand.Float32()
  94. }
  95. // Calculate cumulative sum of probabilities
  96. var sum float32
  97. for i := range tokens {
  98. sum += tokens[i].value
  99. tokens[i].value = sum
  100. }
  101. r *= tokens[len(tokens)-1].value
  102. idx, _ := slices.BinarySearchFunc(tokens, r, func(token token, target float32) int {
  103. if token.value < target {
  104. return -1
  105. }
  106. return 1
  107. })
  108. return tokens[idx], nil
  109. }
  110. // TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
  111. func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
  112. var rng *rand.Rand
  113. if seed != -1 {
  114. // PCG requires two parameters: sequence and stream
  115. // Use original seed for sequence
  116. sequence := uint64(seed)
  117. // Use golden ratio hash to generate statistically independent seeds
  118. rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9))
  119. }
  120. if temperature < 0.0 {
  121. temperature = 0.0
  122. }
  123. if topP < 0.0 {
  124. topP = 0.0
  125. }
  126. if topP >= 1.0 {
  127. topP = 1.0
  128. }
  129. if minP < 0.0 {
  130. minP = 0.0
  131. }
  132. if minP >= 1.0 {
  133. minP = 1.0
  134. }
  135. return Sampler{
  136. rng: rng,
  137. topK: topK,
  138. topP: topP,
  139. minP: minP,
  140. temperature: temperature,
  141. grammar: grammar,
  142. }
  143. }
// Grammar constrains sampling to tokens permitted by a llama.cpp grammar.
type Grammar struct {
	vocab   *Vocab         // lazily-loaded vocabulary the grammar operates over
	grammar string         // the grammar definition text
	sampler *llama.Sampler // underlying llama.cpp grammar sampler
}
  149. func NewGrammar(vocab *Vocab, grammar string) (*Grammar, error) {
  150. v, err := vocab.Load()
  151. if err != nil {
  152. return nil, err
  153. }
  154. return &Grammar{
  155. vocab: vocab,
  156. grammar: grammar,
  157. sampler: llama.NewGrammarSampler(v, grammar),
  158. }, nil
  159. }
  160. func (g *Grammar) Apply(tokens []token) {
  161. tds := make([]llama.TokenData, len(tokens))
  162. for i, token := range tokens {
  163. tds[i].Id = token.id
  164. tds[i].Logit = token.value
  165. }
  166. g.sampler.Apply(tds)
  167. for i := range tokens {
  168. tokens[i].value = tds[i].Logit
  169. }
  170. }
  171. func (g *Grammar) Accept(token int32) {
  172. g.sampler.Accept(token)
  173. }
// Vocab lazily loads a llama vocabulary from a file on first use.
type Vocab struct {
	once  sync.Once    // guards the one-time load in Load
	vocab *llama.Vocab // cached vocabulary; nil until loaded successfully
	err   error        // cached load error; returned by every Load call
	path  string       // file path to load the vocabulary from
}
// NewVocab returns a Vocab that will load from path on first Load call.
func NewVocab(path string) *Vocab {
	return &Vocab{path: path}
}
  183. // Load returns the lazily-loaded vocabulary
  184. func (v *Vocab) Load() (*llama.Vocab, error) {
  185. v.once.Do(func() {
  186. vocab, err := llama.LoadVocabFromFile(v.path)
  187. if err != nil {
  188. v.err = err
  189. return
  190. }
  191. v.vocab = vocab
  192. })
  193. return v.vocab, v.err
  194. }