123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146 |
- package sample
- import (
- "errors"
- "math/rand/v2"
- "slices"
- )
- // Sampler is not thread-safe. Each goroutine should have its own instance
- type Sampler interface {
- Sample([]float32) (int32, error)
- }
- // logit represents information about a single token during sampling
- type logit struct {
- id int32 // The token's unique identifier
- value float32 // The raw logit or probability from the model
- }
- type weighted struct {
- rng *rand.Rand
- tokens []logit
- topK int
- topP float32
- minP float32
- temperature float32
- }
- func (s *weighted) Sample(logits []float32) (int32, error) {
- if len(s.tokens) < len(logits) {
- s.tokens = make([]logit, len(logits))
- }
- tokens := s.tokens[:len(logits)]
- for i, v := range logits {
- tokens[i].id = int32(i)
- tokens[i].value = v
- }
- // Tokens are sorted by logits in TopK or SortTokens
- if s.topK > 0 {
- tokens = topK(tokens, s.topK)
- } else {
- sortLogits(tokens)
- }
- tokens = temperature(tokens, s.temperature)
- tokens = softmax(tokens)
- tokens = topP(tokens, s.topP)
- tokens = minP(tokens, s.minP)
- if len(tokens) == 0 {
- return -1, errors.New("no valid logits found for weighted sampling")
- }
- var r float32
- if s.rng != nil {
- r = s.rng.Float32()
- } else {
- r = rand.Float32()
- }
- // Calculate cumulative sum of probabilities
- var sum float32
- for i := range tokens {
- sum += tokens[i].value
- tokens[i].value = sum
- }
- r *= tokens[len(tokens)-1].value
- idx, _ := slices.BinarySearchFunc(tokens, r, func(token logit, target float32) int {
- // Compare cumulative probabilities
- if token.value < target {
- return -1
- }
- // First token that exceeds target
- return 1
- })
- if idx >= len(tokens) {
- idx = len(tokens) - 1
- }
- return tokens[idx].id, nil
- }
- type greedy struct{}
- // Greedy sample returns the index of the maximum value in logits.
- func (s greedy) Sample(logits []float32) (int32, error) {
- if len(logits) == 0 {
- return -1, errors.New("no logits provided for greedy sampling")
- }
- maxIdx := 0
- maxVal := logits[0]
- for i := 1; i < len(logits); i++ {
- if logits[i] > maxVal {
- maxVal = logits[i]
- maxIdx = i
- }
- }
- return int32(maxIdx), nil
- }
- // TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
- func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int) Sampler {
- if temperature == 0 {
- return &greedy{}
- }
- var rng *rand.Rand
- if seed != -1 {
- // PCG requires two parameters: sequence and stream
- // Use original seed for sequence
- sequence := uint64(seed)
- // Use golden ratio hash to generate statistically independent seeds
- rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9))
- }
- temperature = max(temperature, 1)
- if topP < 0.0 {
- topP = 0.0
- }
- if topP >= 1.0 {
- topP = 1.0
- }
- if minP < 0.0 {
- minP = 0.0
- }
- if minP >= 1.0 {
- minP = 1.0
- }
- return &weighted{
- rng: rng,
- topK: topK,
- topP: topP,
- minP: minP,
- temperature: temperature,
- }
- }
|