// samplers.go
package sample

import (
	"errors"
	"math"
	"math/rand"
	"sync"
	"time"

	"github.com/ollama/ollama/llama"
)
// token represents information about a single token during sampling.
type token struct {
	id    int32   // The token's unique identifier (its index into the logits slice)
	value float32 // The raw logit or probability from the model
}
// Sampler selects a token id from model logits according to its
// configured parameters (temperature, top-k, top-p, min-p), optionally
// constrained by a grammar. Construct with NewSampler.
type Sampler struct {
	rng         *rand.Rand // seeded randomness source used when sampling from the distribution
	topK        int        // keep only the k highest logits; <= 0 disables top-k
	topP        float32    // nucleus (cumulative probability) cutoff
	minP        float32    // minimum-probability cutoff relative to the max
	temperature float32    // 0 selects greedy (argmax) decoding
	grammar     *Grammar   // optional grammar constraint; may be nil
}
  22. func (s *Sampler) Sample(logits []float32) (int32, error) {
  23. tokens := make([]token, len(logits))
  24. for i := range logits {
  25. tokens[i].id = int32(i)
  26. tokens[i].value = logits[i]
  27. }
  28. t, err := s.sample(tokens)
  29. if err != nil {
  30. return -1, err
  31. }
  32. if s.grammar != nil {
  33. // optimization: first check if the max logit is accepted by the grammar
  34. // if the max logit is rejected, apply the grammar to all logits (slower)
  35. top := []token{t}
  36. s.grammar.Apply(top)
  37. if !math.IsInf(float64(top[0].value), -1) {
  38. s.grammar.Accept(top[0].id)
  39. return top[0].id, nil
  40. }
  41. // since .sample has side effects of modifying the tokens
  42. // we need to reset them before applying the grammar and
  43. // sampling again
  44. for i := range logits {
  45. tokens[i].id = int32(i)
  46. tokens[i].value = logits[i]
  47. }
  48. s.grammar.Apply(tokens)
  49. t, err = s.sample(tokens)
  50. if err != nil {
  51. return -1, err
  52. }
  53. s.grammar.Accept(t.id)
  54. }
  55. return t.id, nil
  56. }
  57. // greedy returns the highest probability token from the tokens
  58. func greedy(tokens []token) token {
  59. max := tokens[0]
  60. for i := 1; i < len(tokens); i++ {
  61. if tokens[i].value > max.value {
  62. max = tokens[i]
  63. }
  64. }
  65. return max
  66. }
  67. // sample returns the highest probability token from the tokens
  68. // given sampler parameters. It also has side effects of modifying the tokens
  69. func (s *Sampler) sample(tokens []token) (token, error) {
  70. if s.temperature == 0 {
  71. return greedy(tokens), nil
  72. }
  73. if s.topK > 0 {
  74. tokens = topK(tokens, s.topK)
  75. } else {
  76. sortLogits(tokens)
  77. }
  78. tokens = topP(tokens, s.topP)
  79. tokens = minP(tokens, s.minP)
  80. // token logit values are updated to probabilities
  81. temperature(tokens, s.temperature)
  82. softmax(tokens)
  83. return tokens[dist(tokens, s.rng.Int63())], nil
  84. // // TODO: this should fall back to greedy sampling
  85. // // or topP, topK values etc should be such that
  86. // // there are always tokens to sample from
  87. // if len(tokens) == 0 {
  88. // return token{}, errors.New("no tokens to sample from")
  89. // }
  90. // var r float32
  91. // if s.rng != nil {
  92. // r = s.rng.Float32()
  93. // } else {
  94. // r = rand.Float32()
  95. // }
  96. // // Calculate cumulative sum of probabilities
  97. // var sum float32
  98. // for i := range tokens {
  99. // sum += tokens[i].value
  100. // tokens[i].value = sum
  101. // }
  102. // r *= tokens[len(tokens)-1].value
  103. // idx, _ := slices.BinarySearchFunc(tokens, r, func(token token, target float32) int {
  104. // if token.value < target {
  105. // return -1
  106. // }
  107. // return 1
  108. // })
  109. // return tokens[idx], nil
  110. }
  111. // TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
  112. func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
  113. var rng *rand.Rand
  114. if seed != -1 {
  115. rng = rand.New(rand.NewSource(int64(seed)))
  116. } else {
  117. rng = rand.New(rand.NewSource(time.Now().UnixNano()))
  118. }
  119. if temperature < 0.0 {
  120. temperature = 0.0
  121. }
  122. if topP < 0.0 {
  123. topP = 0.0
  124. }
  125. if topP >= 1.0 {
  126. topP = 1.0
  127. }
  128. if minP < 0.0 {
  129. minP = 0.0
  130. }
  131. if minP >= 1.0 {
  132. minP = 1.0
  133. }
  134. return Sampler{
  135. rng: rng,
  136. topK: topK,
  137. topP: topP,
  138. minP: minP,
  139. temperature: temperature,
  140. grammar: grammar,
  141. }
  142. }
// Grammar constrains sampling to tokens that are valid under a grammar
// definition, backed by a llama grammar sampler.
type Grammar struct {
	vocab   *Vocab         // lazily-loaded vocabulary the grammar sampler is built against
	grammar string         // the grammar definition text
	sampler *llama.Sampler // underlying llama grammar sampler
}
  148. func NewGrammar(vocab *Vocab, grammar string) (*Grammar, error) {
  149. v, err := vocab.Load()
  150. if err != nil {
  151. return nil, err
  152. }
  153. return &Grammar{
  154. vocab: vocab,
  155. grammar: grammar,
  156. sampler: llama.NewGrammarSampler(v, grammar),
  157. }, nil
  158. }
  159. func (g *Grammar) Apply(tokens []token) {
  160. tds := make([]llama.TokenData, len(tokens))
  161. for i, token := range tokens {
  162. tds[i].Id = token.id
  163. tds[i].Logit = token.value
  164. }
  165. g.sampler.Apply(tds)
  166. for i := range tokens {
  167. tokens[i].value = tds[i].Logit
  168. }
  169. }
// Accept advances the grammar state to reflect that the given token id
// was emitted.
func (g *Grammar) Accept(token int32) {
	g.sampler.Accept(token)
}
// Vocab lazily loads a llama vocabulary from a file, caching the result
// (or the load error) after the first call to Load.
type Vocab struct {
	once  sync.Once    // guards the one-time load in Load
	vocab *llama.Vocab // loaded vocabulary; nil until Load succeeds
	err   error        // sticky error from a failed load
	path  string       // filesystem path of the vocabulary file
}
// NewVocab returns a Vocab that will load from path on first use.
func NewVocab(path string) *Vocab {
	return &Vocab{path: path}
}
  182. // Load returns the lazily-loaded vocabulary
  183. func (v *Vocab) Load() (*llama.Vocab, error) {
  184. v.once.Do(func() {
  185. vocab, err := llama.LoadVocabFromFile(v.path)
  186. if err != nil {
  187. v.err = err
  188. return
  189. }
  190. v.vocab = vocab
  191. })
  192. return v.vocab, v.err
  193. }