@@ -1,11 +1,10 @@
 package sample
 
 import (
-	"errors"
 	"math"
-	"math/rand/v2"
-	"slices"
+	"math/rand"
 	"sync"
+	"time"
 
 	"github.com/ollama/ollama/llama"
 )
@@ -90,53 +89,53 @@ func (s *Sampler) sample(tokens []token) (token, error) {
 		sortLogits(tokens)
 	}
 
-	// token logit values are updated to probabilities
-	tokens = temperature(tokens, s.temperature)
-
 	tokens = topP(tokens, s.topP)
 	tokens = minP(tokens, s.minP)
 
-	// TODO: this should fall back to greedy sampling
-	// or topP, topK values etc should be such that
-	// there are always tokens to sample from
-	if len(tokens) == 0 {
-		return token{}, errors.New("no tokens to sample from")
-	}
-
-	var r float32
-	if s.rng != nil {
-		r = s.rng.Float32()
-	} else {
-		r = rand.Float32()
-	}
-
-	// Calculate cumulative sum of probabilities
-	var sum float32
-	for i := range tokens {
-		sum += tokens[i].value
-		tokens[i].value = sum
-	}
-	r *= tokens[len(tokens)-1].value
-
-	idx, _ := slices.BinarySearchFunc(tokens, r, func(token token, target float32) int {
-		if token.value < target {
-			return -1
-		}
-		return 1
-	})
-
-	return tokens[idx], nil
+	// token logit values are updated to probabilities
+	temperature(tokens, s.temperature)
+	softmax(tokens)
+	return tokens[dist(tokens, s.rng.Int63())], nil
+
+	// // TODO: this should fall back to greedy sampling
+	// // or topP, topK values etc should be such that
+	// // there are always tokens to sample from
+	// if len(tokens) == 0 {
+	// 	return token{}, errors.New("no tokens to sample from")
+	// }
+
+	// var r float32
+	// if s.rng != nil {
+	// 	r = s.rng.Float32()
+	// } else {
+	// 	r = rand.Float32()
+	// }
+
+	// // Calculate cumulative sum of probabilities
+	// var sum float32
+	// for i := range tokens {
+	// 	sum += tokens[i].value
+	// 	tokens[i].value = sum
+	// }
+	// r *= tokens[len(tokens)-1].value
+
+	// idx, _ := slices.BinarySearchFunc(tokens, r, func(token token, target float32) int {
+	// 	if token.value < target {
+	// 		return -1
+	// 	}
+	// 	return 1
+	// })
+
+	// return tokens[idx], nil
 }
 
 // TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
 func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
 	var rng *rand.Rand
 	if seed != -1 {
-		// PCG requires two parameters: sequence and stream
-		// Use original seed for sequence
-		sequence := uint64(seed)
-		// Use golden ratio hash to generate statistically independent seeds
-		rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9))
+		rng = rand.New(rand.NewSource(int64(seed)))
+	} else {
+		rng = rand.New(rand.NewSource(time.Now().UnixNano()))
 	}
 	if temperature < 0.0 {
 		temperature = 0.0
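
Note: softmax and dist are called by the new sampling path but are not defined in this hunk. A minimal sketch of what they might look like, assuming token carries a float32 value field and that dist maps an Int63 draw onto the probabilities left in tokens (assumed shapes, not the package's actual implementations):

// softmax rescales logit values in place into probabilities (sketch only).
func softmax(ts []token) {
	maxLogit := float32(math.Inf(-1))
	for _, t := range ts {
		if t.value > maxLogit {
			maxLogit = t.value
		}
	}
	var sum float32
	for i := range ts {
		ts[i].value = float32(math.Exp(float64(ts[i].value - maxLogit)))
		sum += ts[i].value
	}
	for i := range ts {
		ts[i].value /= sum
	}
}

// dist picks an index in proportion to each token's probability,
// deriving a uniform [0, 1) value from the rng's Int63 sample (sketch only).
func dist(ts []token, rv int64) int {
	r := float32(float64(rv) / float64(math.MaxInt64))
	var cum float32
	for i := range ts {
		cum += ts[i].value
		if r < cum {
			return i
		}
	}
	return len(ts) - 1
}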