@@ -2,43 +2,88 @@ package sample
 
 import (
 	"errors"
+	"math"
 	"math/rand/v2"
 	"slices"
-)
+	"sync"
 
-// Sampler is not thread-safe. Each goroutine should have its own instance
-type Sampler interface {
-	Sample([]float32) (int32, error)
-}
+	"github.com/ollama/ollama/llama"
+)
 
-// logit represents information about a single token during sampling
-type logit struct {
+// token represents information about a single token during sampling
+type token struct {
 	id    int32   // The token's unique identifier
 	value float32 // The raw logit or probability from the model
 }
 
-type weighted struct {
+type Sampler struct {
 	rng         *rand.Rand
-	tokens      []logit
 	topK        int
 	topP        float32
 	minP        float32
 	temperature float32
+	grammar     *Grammar
 }
 
-func (s *weighted) Sample(logits []float32) (int32, error) {
-	if len(s.tokens) < len(logits) {
-		s.tokens = make([]logit, len(logits))
+func (s *Sampler) Sample(logits []float32) (int32, error) {
+	tokens := make([]token, len(logits))
+	for i := range logits {
+		tokens[i].id = int32(i)
+		tokens[i].value = logits[i]
 	}
 
-	tokens := s.tokens[:len(logits)]
+	t, err := s.sample(tokens)
+	if err != nil {
+		return -1, err
+	}
 
-	for i, v := range logits {
-		tokens[i].id = int32(i)
-		tokens[i].value = v
+	if s.grammar != nil {
+		// optimization: first check if the sampled token is accepted by the grammar;
+		// only if it is rejected do we apply the grammar to all logits (slower)
+		top := []token{t}
+		s.grammar.Apply(top)
+		if !math.IsInf(float64(top[0].value), -1) {
+			s.grammar.Accept(top[0].id)
+			return top[0].id, nil
+		}
+
+		// since .sample mutates the tokens slice as a side
+		// effect, reset it before applying the grammar and
+		// sampling again
+		for i := range logits {
+			tokens[i].id = int32(i)
+			tokens[i].value = logits[i]
+		}
+		s.grammar.Apply(tokens)
+		t, err = s.sample(tokens)
+		if err != nil {
+			return -1, err
+		}
+		s.grammar.Accept(t.id)
+	}
+
+	return t.id, nil
+}
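
Note: the control flow above is worth spelling out. Rather than masking every logit through the grammar on each decode step, Sample first asks whether the one token it already picked is legal, and only re-applies the grammar to the full distribution on rejection. A self-contained sketch of that pattern, with a toy legality check standing in for the llama.cpp grammar sampler (toyAllowed and pickToken are illustrative names, not part of this change):

    package main

    import "fmt"

    // toyAllowed stands in for a grammar legality check: it reports
    // whether a token id is acceptable in the current grammar state
    // (illustrative only; the real check lives in llama.cpp).
    func toyAllowed(id int32) bool { return id%2 == 0 }

    // pickToken mirrors Sample's control flow: try the cheap
    // single-token check first, and only scan the whole candidate
    // set when that check fails.
    func pickToken(best int32, candidates []int32) int32 {
        if toyAllowed(best) {
            return best // fast path: one grammar check per step
        }
        for _, id := range candidates { // slow path: filter everything
            if toyAllowed(id) {
                return id
            }
        }
        return -1 // no legal token remains
    }

    func main() {
        fmt.Println(pickToken(3, []int32{3, 5, 8, 9})) // prints 8
    }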
+
+// greedy returns the token with the highest logit
+func greedy(tokens []token) token {
+	max := tokens[0]
+	for i := 1; i < len(tokens); i++ {
+		if tokens[i].value > max.value {
+			max = tokens[i]
+		}
+	}
+
+	return max
+}
+
+// sample returns a token drawn from the distribution described by the
+// sampler parameters. It mutates the provided tokens as a side effect
+func (s *Sampler) sample(tokens []token) (token, error) {
+	if s.temperature == 0 {
+		return greedy(tokens), nil
 	}
 
-	// Tokens are sorted by logits in TopK or SortTokens
 	if s.topK > 0 {
 		tokens = topK(tokens, s.topK)
 	} else {
@@ -47,12 +92,14 @@ func (s *weighted) Sample(logits []float32) (int32, error) {
 
 	tokens = temperature(tokens, s.temperature)
 	tokens = softmax(tokens)
-
 	tokens = topP(tokens, s.topP)
 	tokens = minP(tokens, s.minP)
 
+	// TODO: this should fall back to greedy sampling, or the
+	// topP/topK values should be constrained so that there are
+	// always tokens to sample from
 	if len(tokens) == 0 {
-		return -1, errors.New("no valid logits found for weighted sampling")
+		return token{}, errors.New("no tokens to sample from")
 	}
 
 	var r float32
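
Note: the helpers chained in this hunk (temperature, softmax, topP, minP) are defined elsewhere in the package and are not shown in this diff. For orientation, a minimal sketch of the standard definitions the first two follow, with assumed signatures rather than the package's own:

    package main

    import (
        "fmt"
        "math"
    )

    // softmaxWithTemperature shows the textbook form of the two
    // transforms: divide logits by the temperature, then normalize
    // with a numerically stable softmax.
    func softmaxWithTemperature(logits []float64, temp float64) []float64 {
        maxLogit := math.Inf(-1)
        for _, l := range logits {
            maxLogit = math.Max(maxLogit, l) // subtract max for stability
        }
        probs := make([]float64, len(logits))
        var sum float64
        for i, l := range logits {
            probs[i] = math.Exp((l - maxLogit) / temp)
            sum += probs[i]
        }
        for i := range probs {
            probs[i] /= sum
        }
        return probs
    }

    func main() {
        // lower temperature sharpens the distribution
        fmt.Println(softmaxWithTemperature([]float64{2, 1, 0}, 0.5))
    }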
@@ -70,48 +117,18 @@ func (s *weighted) Sample(logits []float32) (int32, error) {
 	}
 	r *= tokens[len(tokens)-1].value
 
-	idx, _ := slices.BinarySearchFunc(tokens, r, func(token logit, target float32) int {
-		// Compare cumulative probabilities
+	idx, _ := slices.BinarySearchFunc(tokens, r, func(token token, target float32) int {
 		if token.value < target {
 			return -1
 		}
-		// First token that exceeds target
 		return 1
 	})
 
-	if idx >= len(tokens) {
-		idx = len(tokens) - 1
-	}
-
-	return tokens[idx].id, nil
-}
-
-type greedy struct{}
-
-// Greedy sample returns the index of the maximum value in logits.
-func (s greedy) Sample(logits []float32) (int32, error) {
-	if len(logits) == 0 {
-		return -1, errors.New("no logits provided for greedy sampling")
-	}
-
-	maxIdx := 0
-	maxVal := logits[0]
-	for i := 1; i < len(logits); i++ {
-		if logits[i] > maxVal {
-			maxVal = logits[i]
-			maxIdx = i
-		}
-	}
-
-	return int32(maxIdx), nil
+	return tokens[idx], nil
 }
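
Note: by this point each tokens[i].value is expected to hold a running cumulative probability (the removed comment said as much), so drawing r uniformly from [0, last cumulative value) and binary-searching for the first entry that meets it is standard inverse-CDF (roulette-wheel) sampling. A self-contained sketch of the same idea over raw weights; the cumulative sums are computed inline here, whereas the diff assumes they come out of the earlier transforms:

    package main

    import (
        "fmt"
        "math/rand/v2"
        "slices"
    )

    // sampleIndex draws an index in proportion to weights using the
    // same trick as the diff: cumulative sums plus a binary search.
    func sampleIndex(rng *rand.Rand, weights []float32) int {
        cum := make([]float32, len(weights))
        var total float32
        for i, w := range weights {
            total += w
            cum[i] = total // running cumulative sum
        }
        r := rng.Float32() * total // uniform in [0, total)
        idx, _ := slices.BinarySearchFunc(cum, r, func(c, target float32) int {
            if c < target {
                return -1
            }
            return 1 // first cumulative value >= target
        })
        return idx
    }

    func main() {
        rng := rand.New(rand.NewPCG(1, 2))
        fmt.Println(sampleIndex(rng, []float32{0.1, 0.6, 0.3}))
    }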
 
 // TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
-func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int) Sampler {
-	if temperature == 0 {
-		return &greedy{}
-	}
-
+func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
 	var rng *rand.Rand
 	if seed != -1 {
 		// PCG requires two parameters: sequence and stream
@@ -120,7 +137,9 @@ func NewSampler(temperature float32, topK int, topP float32, minP float32, seed
 		// Use golden ratio hash to generate statistically independent seeds
 		rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9))
 	}
-	temperature = max(temperature, 1)
+	if temperature < 0.0 {
+		temperature = 0.0
+	}
 
 	if topP < 0.0 {
 		topP = 0.0
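
Note on the seeding above: math/rand/v2's PCG source takes two 64-bit words, and XOR-ing the user seed with the 32-bit golden-ratio constant 0x9E3779B9 is a cheap way to derive a second word that is well mixed relative to the first. The same construction in isolation:

    package main

    import (
        "fmt"
        "math/rand/v2"
    )

    func main() {
        // derive PCG's (sequence, stream) pair from one user seed,
        // as NewSampler does above
        seed := uint64(42)
        rng := rand.New(rand.NewPCG(seed, seed^0x9E3779B9))
        fmt.Println(rng.IntN(100), rng.IntN(100)) // deterministic for a fixed seed
    }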
@@ -136,11 +155,73 @@ func NewSampler(temperature float32, topK int, topP float32, minP float32, seed
 		minP = 1.0
 	}
 
-	return &weighted{
+	return Sampler{
 		rng:         rng,
 		topK:        topK,
 		topP:        topP,
 		minP:        minP,
 		temperature: temperature,
+		grammar:     grammar,
+	}
+}
+
+type Grammar struct {
+	vocab   *Vocab
+	grammar string
+	sampler *llama.Sampler
+}
+
+func NewGrammar(vocab *Vocab, grammar string) (*Grammar, error) {
+	v, err := vocab.Load()
+	if err != nil {
+		return nil, err
+	}
+
+	return &Grammar{
+		vocab:   vocab,
+		grammar: grammar,
+		sampler: llama.NewGrammarSampler(v, grammar),
+	}, nil
+}
+
+func (g *Grammar) Apply(tokens []token) {
+	tds := make([]llama.TokenData, len(tokens))
+	for i, token := range tokens {
+		tds[i].Id = token.id
+		tds[i].Logit = token.value
 	}
+
+	g.sampler.Apply(tds)
+
+	for i := range tokens {
+		tokens[i].value = tds[i].Logit
+	}
+}
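
Note: Apply round-trips the candidate logits through llama.cpp's grammar sampler; tokens the grammar rejects come back with their logit forced to negative infinity, which is what the math.IsInf(..., -1) test in the Sample fast path detects. A toy illustration of that masking convention:

    package main

    import (
        "fmt"
        "math"
    )

    func main() {
        logits := []float32{1.5, 0.2, -0.3}
        // mask token 1 the way a grammar rejection does: logit -> -Inf
        logits[1] = float32(math.Inf(-1))
        for id, l := range logits {
            rejected := math.IsInf(float64(l), -1)
            fmt.Printf("token %d: logit=%v rejected=%v\n", id, l, rejected)
        }
    }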
+
+func (g *Grammar) Accept(token int32) {
+	g.sampler.Accept(token)
+}
+
+type Vocab struct {
+	once  sync.Once
+	vocab *llama.Vocab
+	err   error
+	path  string
+}
+
+func NewVocab(path string) *Vocab {
+	return &Vocab{path: path}
+}
+
+// Load returns the lazily-loaded vocabulary
+func (v *Vocab) Load() (*llama.Vocab, error) {
+	v.once.Do(func() {
+		vocab, err := llama.LoadVocabFromFile(v.path)
+		if err != nil {
+			v.err = err
+			return
+		}
+		v.vocab = vocab
+	})
+	return v.vocab, v.err
 }
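
Note: Vocab.Load is the classic sync.Once lazy-initialization pattern with the error cached next to the value: the file is read at most once, and a failed load is replayed to every later caller rather than retried. The same shape in a self-contained sketch (the file path is hypothetical):

    package main

    import (
        "fmt"
        "os"
        "sync"
    )

    // lazyFile caches both the value and the error from a one-time
    // load, the same shape as the diff's Vocab.Load.
    type lazyFile struct {
        once sync.Once
        data []byte
        err  error
    }

    func (l *lazyFile) load(path string) ([]byte, error) {
        l.once.Do(func() {
            // runs exactly once, even across goroutines; an error is
            // cached and returned unchanged to every later caller
            l.data, l.err = os.ReadFile(path)
        })
        return l.data, l.err
    }

    func main() {
        var l lazyFile
        if _, err := l.load("vocab.gguf"); err != nil { // hypothetical path
            fmt.Println("load failed:", err)
        }
    }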