jmorganca 1 month ago
parent
commit
9622b928b4
2 files changed with 97 additions and 67 deletions
  1. +40 -41
      sample/samplers.go
  2. +57 -26
      sample/transforms.go

+ 40 - 41
sample/samplers.go

@@ -1,11 +1,10 @@
 package sample
 
 import (
-	"errors"
 	"math"
-	"math/rand/v2"
-	"slices"
+	"math/rand"
 	"sync"
+	"time"
 
 	"github.com/ollama/ollama/llama"
 )
@@ -90,53 +89,53 @@ func (s *Sampler) sample(tokens []token) (token, error) {
 		sortLogits(tokens)
 	}
 
-	// token logit values are updated to probabilities
-	tokens = temperature(tokens, s.temperature)
-
 	tokens = topP(tokens, s.topP)
 	tokens = minP(tokens, s.minP)
 
-	// TODO: this should fall back to greedy sampling
-	// or topP, topK values etc should be such that
-	// there are always tokens to sample from
-	if len(tokens) == 0 {
-		return token{}, errors.New("no tokens to sample from")
-	}
-
-	var r float32
-	if s.rng != nil {
-		r = s.rng.Float32()
-	} else {
-		r = rand.Float32()
-	}
-
-	// Calculate cumulative sum of probabilities
-	var sum float32
-	for i := range tokens {
-		sum += tokens[i].value
-		tokens[i].value = sum
-	}
-	r *= tokens[len(tokens)-1].value
-
-	idx, _ := slices.BinarySearchFunc(tokens, r, func(token token, target float32) int {
-		if token.value < target {
-			return -1
-		}
-		return 1
-	})
-
-	return tokens[idx], nil
+	// token logit values are updated to probabilities
+	temperature(tokens, s.temperature)
+	softmax(tokens)
+	return tokens[dist(tokens, s.rng.Int63())], nil
+
+	// // TODO: this should fall back to greedy sampling
+	// // or topP, topK values etc should be such that
+	// // there are always tokens to sample from
+	// if len(tokens) == 0 {
+	// 	return token{}, errors.New("no tokens to sample from")
+	// }
+
+	// var r float32
+	// if s.rng != nil {
+	// 	r = s.rng.Float32()
+	// } else {
+	// 	r = rand.Float32()
+	// }
+
+	// // Calculate cumulative sum of probabilities
+	// var sum float32
+	// for i := range tokens {
+	// 	sum += tokens[i].value
+	// 	tokens[i].value = sum
+	// }
+	// r *= tokens[len(tokens)-1].value
+
+	// idx, _ := slices.BinarySearchFunc(tokens, r, func(token token, target float32) int {
+	// 	if token.value < target {
+	// 		return -1
+	// 	}
+	// 	return 1
+	// })
+
+	// return tokens[idx], nil
 }
 
 // TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
 func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
 	var rng *rand.Rand
 	if seed != -1 {
-		// PCG requires two parameters: sequence and stream
-		// Use original seed for sequence
-		sequence := uint64(seed)
-		// Use golden ratio hash to generate statistically independent seeds
-		rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9))
+		rng = rand.New(rand.NewSource(int64(seed)))
+	} else {
+		rng = rand.New(rand.NewSource(time.Now().UnixNano()))
 	}
 	if temperature < 0.0 {
 		temperature = 0.0

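The net effect in samplers.go is that the sampler now always owns a non-nil *rand.Rand: a deterministic source when a seed is supplied, a time-based one otherwise, so sample() can call s.rng.Int63() unconditionally and hand the draw to dist in transforms.go. Below is a minimal sketch of that seeding behavior using only the standard math/rand and time packages; the newRNG helper name is illustrative, not part of the package.

package main

import (
	"fmt"
	"math/rand"
	"time"
)

// newRNG mirrors the seeding logic added to NewSampler: a fixed source when
// seed != -1, otherwise a time-based source. The helper name is hypothetical.
func newRNG(seed int) *rand.Rand {
	if seed != -1 {
		return rand.New(rand.NewSource(int64(seed)))
	}
	return rand.New(rand.NewSource(time.Now().UnixNano()))
}

func main() {
	a, b := newRNG(42), newRNG(42)
	// Equal seeds yield equal draws, which keeps seeded sampling reproducible.
	fmt.Println(a.Int63() == b.Int63()) // true
	fmt.Println(newRNG(-1).Int63())     // varies per run
}

One trade-off worth noting: unlike the removed math/rand/v2 PCG construction, a source from rand.NewSource is not safe for concurrent use, so concurrent callers would need their own synchronization.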
+ 57 - 26
sample/transforms.go

@@ -3,6 +3,7 @@ package sample
 import (
 	"container/heap"
 	"math"
+	"math/rand"
 	"slices"
 )
 
@@ -25,32 +26,6 @@ func (h *tokenHeap) Pop() any {
 	return x
 }
 
-// temperature applies scaling and softmax to the logits
-func temperature(ts []token, temp float32) []token {
-	// Find max logit for numerical stability
-	maxLogit := float32(math.Inf(-1))
-	for _, t := range ts {
-		if t.value > maxLogit {
-			maxLogit = t.value
-		}
-	}
-
-	// Apply temperature and compute exp(x - max)
-	temp = max(temp, 1e-7)
-	var sum float32
-	for i, v := range ts {
-		ts[i].value = float32(math.Exp(float64((v.value - maxLogit) / temp)))
-		sum += ts[i].value
-	}
-
-	// Normalize
-	for i := range ts {
-		ts[i].value /= sum
-	}
-
-	return ts
-}
-
 // topK limits the number of tokens considered to the k highest logits
 func topK(ts []token, k int) []token {
 	if k >= len(ts) {
@@ -200,3 +175,59 @@ func sortLogits(ts []token) {
 
 	partialSortLogits(ts, n)
 }
+
+func temperature(ts []token, temp float32) {
+	for i := range ts {
+		ts[i].value /= temp
+	}
+}
+
+func softmax(ts []token) {
+	if len(ts) == 0 {
+		return
+	}
+
+	// Find max logit for numerical stability
+	maxLogit := ts[0].value
+	for _, t := range ts {
+		if t.value > maxLogit {
+			maxLogit = t.value
+		}
+	}
+
+	// Compute exp(logit - maxLogit) and sum them
+	var sumExp float32
+	for i, t := range ts {
+		expVal := float32(math.Exp(float64(t.value - maxLogit)))
+		ts[i].value = expVal
+		sumExp += expVal
+	}
+
+	// Normalize probabilities
+	for i := range ts {
+		ts[i].value /= sumExp
+	}
+}
+
+// dist selects a token based on probabilities and seed
+func dist(ts []token, seed int64) int {
+	rng := rand.New(rand.NewSource(seed))
+
+	cdf := make([]float32, len(ts))
+	var cumSum float32
+	for i, t := range ts {
+		cumSum += t.value
+		cdf[i] = cumSum
+	}
+
+	r := rng.Float32() * cumSum
+
+	// Select token based on CDF
+	for i, probSum := range cdf {
+		if r < probSum {
+			return i
+		}
+	}
+
+	return len(ts) - 1
+}