samplers.go 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. package sample
  2. import (
  3. "errors"
  4. "math"
  5. "math/rand/v2"
  6. "slices"
  7. "sync"
  8. "github.com/ollama/ollama/llama"
  9. )
// token represents information about a single token during sampling
type token struct {
	id    int32   // The token's unique identifier
	value float32 // The raw logit or probability from the model
}
// Sampler selects a token id from a slice of model logits according to the
// configured sampling parameters and optional constrained-decoding filters.
type Sampler struct {
	rng         *rand.Rand // seeded source; nil means fall back to the global math/rand/v2 source (see sample)
	topK        int        // passed to topK to truncate the candidate list
	topP        float32    // nucleus threshold; clamped to [0,1] by NewSampler
	minP        float32    // minimum-probability threshold; clamped to [0,1] by NewSampler
	temperature float32    // 0 selects greedy decoding; otherwise scales logits before softmax

	grammar       *Grammar       // optional grammar constraint; nil disables
	JSONSampler   *JSONSampler   // optional logit filter applied first in Sample; nil disables
	PythonSampler *PythonSampler // optional logit mask applied second in Sample; nil disables
}
  25. func (s *Sampler) Sample(logits []float32) (int32, error) {
  26. if len(logits) == 0 {
  27. return -1, errors.New("sample: no logits provided to sample")
  28. }
  29. var err error
  30. if s.JSONSampler != nil {
  31. logits, err = s.JSONSampler.Apply(logits)
  32. if err != nil {
  33. return -1, err
  34. }
  35. }
  36. if s.PythonSampler != nil {
  37. logits, err = s.PythonSampler.ApplyMask(logits)
  38. if err != nil {
  39. return -1, err
  40. }
  41. }
  42. tokens := make([]token, len(logits))
  43. for i := range logits {
  44. tokens[i].id = int32(i)
  45. tokens[i].value = logits[i]
  46. }
  47. t, err := s.sample(tokens)
  48. if err != nil {
  49. return -1, err
  50. }
  51. if s.grammar != nil {
  52. // optimization: first check if the max logit is accepted by the grammar
  53. // if the max logit is rejected, apply the grammar to all logits (slower)
  54. top := []token{t}
  55. s.grammar.Apply(top)
  56. if !math.IsInf(float64(top[0].value), -1) {
  57. s.grammar.Accept(top[0].id)
  58. return top[0].id, nil
  59. }
  60. // since .sample has side effects of modifying the tokens
  61. // we need to reset them before applying the grammar and
  62. // sampling again
  63. for i := range logits {
  64. tokens[i].id = int32(i)
  65. tokens[i].value = logits[i]
  66. }
  67. s.grammar.Apply(tokens)
  68. t, err = s.sample(tokens)
  69. if err != nil {
  70. return -1, err
  71. }
  72. s.grammar.Accept(t.id)
  73. }
  74. return t.id, nil
  75. }
  76. // greedy returns the highest probability token from the tokens
  77. func greedy(tokens []token) token {
  78. max := tokens[0]
  79. for i := 1; i < len(tokens); i++ {
  80. if tokens[i].value > max.value {
  81. max = tokens[i]
  82. }
  83. }
  84. return max
  85. }
  86. // sample returns the highest probability token from the tokens
  87. // given sampler parameters. It also has side effects of modifying the tokens
  88. func (s *Sampler) sample(tokens []token) (token, error) {
  89. if s.temperature == 0 {
  90. return greedy(tokens), nil
  91. }
  92. // topK also sorts the tokens in descending order of logits
  93. tokens = topK(tokens, s.topK)
  94. // scale and normalize the tokens in place
  95. temperature(tokens, s.temperature)
  96. softmax(tokens)
  97. tokens = topP(tokens, s.topP)
  98. tokens = minP(tokens, s.minP)
  99. var r float32
  100. if s.rng != nil {
  101. r = s.rng.Float32()
  102. } else {
  103. r = rand.Float32()
  104. }
  105. // Calculate cumulative sum of probabilities
  106. var sum float32
  107. for i := range tokens {
  108. sum += tokens[i].value
  109. tokens[i].value = sum
  110. }
  111. r *= tokens[len(tokens)-1].value
  112. idx, _ := slices.BinarySearchFunc(tokens, r, func(token token, target float32) int {
  113. if token.value < target {
  114. return -1
  115. }
  116. return 1
  117. })
  118. if math.IsNaN(float64(sum)) {
  119. return token{}, errors.New("sample: logits sum to NaN, check model output")
  120. }
  121. return tokens[idx], nil
  122. }
  123. // TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
  124. func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar, jsonSampler *JSONSampler, pythonSampler *PythonSampler) Sampler {
  125. var rng *rand.Rand
  126. if seed != -1 {
  127. // PCG requires two parameters: sequence and stream
  128. // Use original seed for sequence
  129. sequence := uint64(seed)
  130. // Use golden ratio hash to generate statistically independent seeds
  131. rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9))
  132. }
  133. if temperature < 0.0 {
  134. temperature = 0.0
  135. }
  136. if topP < 0.0 {
  137. topP = 0.0
  138. }
  139. if topP >= 1.0 {
  140. topP = 1.0
  141. }
  142. if minP < 0.0 {
  143. minP = 0.0
  144. }
  145. if minP >= 1.0 {
  146. minP = 1.0
  147. }
  148. return Sampler{
  149. rng: rng,
  150. topK: topK,
  151. topP: topP,
  152. minP: minP,
  153. temperature: temperature,
  154. grammar: grammar,
  155. JSONSampler: jsonSampler,
  156. PythonSampler: pythonSampler,
  157. }
  158. }
// Grammar filters sampling through a llama grammar sampler so that only
// grammar-conformant tokens remain viable.
type Grammar struct {
	vocab   *Vocab         // lazily-loaded vocabulary the grammar was built against
	grammar string         // the grammar definition text passed to llama.NewGrammarSampler
	sampler *llama.Sampler // underlying llama grammar sampler doing the filtering
}
  164. func NewGrammar(vocab *Vocab, grammar string) (*Grammar, error) {
  165. v, err := vocab.Load()
  166. if err != nil {
  167. return nil, err
  168. }
  169. return &Grammar{
  170. vocab: vocab,
  171. grammar: grammar,
  172. sampler: llama.NewGrammarSampler(v, grammar),
  173. }, nil
  174. }
  175. func (g *Grammar) Apply(tokens []token) {
  176. tds := make([]llama.TokenData, len(tokens))
  177. for i, token := range tokens {
  178. tds[i].Id = token.id
  179. tds[i].Logit = token.value
  180. }
  181. g.sampler.Apply(tds)
  182. for i := range tokens {
  183. tokens[i].value = tds[i].Logit
  184. }
  185. }
  186. func (g *Grammar) Accept(token int32) {
  187. g.sampler.Accept(token)
  188. }
// Vocab lazily loads a llama vocabulary from a file. Construct with NewVocab;
// the first call to Load performs the read and the result (or error) is
// cached for all subsequent calls.
type Vocab struct {
	once  sync.Once    // guards the one-time load in Load
	vocab *llama.Vocab // loaded vocabulary; nil until Load succeeds
	err   error        // sticky error from the single load attempt
	path  string       // filesystem path of the vocabulary file
}
  195. func NewVocab(path string) *Vocab {
  196. return &Vocab{path: path}
  197. }
  198. // Load returns the lazily-loaded vocabulary
  199. func (v *Vocab) Load() (*llama.Vocab, error) {
  200. v.once.Do(func() {
  201. vocab, err := llama.LoadVocabFromFile(v.path)
  202. if err != nil {
  203. v.err = err
  204. return
  205. }
  206. v.vocab = vocab
  207. })
  208. return v.vocab, v.err
  209. }