ParthSareen 1 bulan lalu
induk
melakukan
a5d638dfe7
4 mengubah file dengan 194 tambahan dan 67 penghapusan
  1. 40 41
      sample/samplers.go
  2. 0 0
      sample/testdata/logits.bin
  3. 57 26
      sample/transforms.go
  4. 97 0
      sample/transforms_test.go

+ 40 - 41
sample/samplers.go

@@ -1,11 +1,10 @@
 package sample
 
 import (
-	"errors"
 	"math"
-	"math/rand/v2"
-	"slices"
+	"math/rand"
 	"sync"
+	"time"
 
 	"github.com/ollama/ollama/llama"
 )
@@ -87,53 +86,53 @@ func (s *Sampler) sample(tokens []token) (token, error) {
 	// topK also sorts the tokens in descending order of logits
 	tokens = topK(tokens, s.topK)
 
-	// token logit values are updated to probabilities
-	tokens = temperature(tokens, s.temperature)
-
 	tokens = topP(tokens, s.topP)
 	tokens = minP(tokens, s.minP)
 
-	// TODO: this should fall back to greedy sampling
-	// or topP, topK values etc should be such that
-	// there are always tokens to sample from
-	if len(tokens) == 0 {
-		return token{}, errors.New("no tokens to sample from")
-	}
-
-	var r float32
-	if s.rng != nil {
-		r = s.rng.Float32()
-	} else {
-		r = rand.Float32()
-	}
-
-	// Calculate cumulative sum of probabilities
-	var sum float32
-	for i := range tokens {
-		sum += tokens[i].value
-		tokens[i].value = sum
-	}
-	r *= tokens[len(tokens)-1].value
-
-	idx, _ := slices.BinarySearchFunc(tokens, r, func(token token, target float32) int {
-		if token.value < target {
-			return -1
-		}
-		return 1
-	})
-
-	return tokens[idx], nil
+	// token logit values are updated to probabilities
+	temperature(tokens, s.temperature)
+	softmax(tokens)
+	return tokens[dist(tokens, s.rng.Int63())], nil
+
+	// // TODO: this should fall back to greedy sampling
+	// // or topP, topK values etc should be such that
+	// // there are always tokens to sample from
+	// if len(tokens) == 0 {
+	// 	return token{}, errors.New("no tokens to sample from")
+	// }
+
+	// var r float32
+	// if s.rng != nil {
+	// 	r = s.rng.Float32()
+	// } else {
+	// 	r = rand.Float32()
+	// }
+
+	// // Calculate cumulative sum of probabilities
+	// var sum float32
+	// for i := range tokens {
+	// 	sum += tokens[i].value
+	// 	tokens[i].value = sum
+	// }
+	// r *= tokens[len(tokens)-1].value
+
+	// idx, _ := slices.BinarySearchFunc(tokens, r, func(token token, target float32) int {
+	// 	if token.value < target {
+	// 		return -1
+	// 	}
+	// 	return 1
+	// })
+
+	// return tokens[idx], nil
 }
 
 // TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
 func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
 	var rng *rand.Rand
 	if seed != -1 {
-		// PCG requires two parameters: sequence and stream
-		// Use original seed for sequence
-		sequence := uint64(seed)
-		// Use golden ratio hash to generate statistically independent seeds
-		rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9))
+		rng = rand.New(rand.NewSource(int64(seed)))
+	} else {
+		rng = rand.New(rand.NewSource(time.Now().UnixNano()))
 	}
 	if temperature < 0.0 {
 		temperature = 0.0

File diff ditekan karena terlalu besar
+ 0 - 0
sample/testdata/logits.bin


+ 57 - 26
sample/transforms.go

@@ -3,6 +3,7 @@ package sample
 import (
 	"container/heap"
 	"math"
+	"math/rand"
 	"slices"
 )
 
@@ -25,32 +26,6 @@ func (h *tokenHeap) Pop() any {
 	return x
 }
 
-// temperature applies scaling and softmax to the logits
-func temperature(ts []token, temp float32) []token {
-	// Find max logit for numerical stability
-	maxLogit := float32(math.Inf(-1))
-	for _, t := range ts {
-		if t.value > maxLogit {
-			maxLogit = t.value
-		}
-	}
-
-	// Apply temperature and compute exp(x - max)
-	temp = max(temp, 1e-7)
-	var sum float32
-	for i, v := range ts {
-		ts[i].value = float32(math.Exp(float64((v.value - maxLogit) / temp)))
-		sum += ts[i].value
-	}
-
-	// Normalize
-	for i := range ts {
-		ts[i].value /= sum
-	}
-
-	return ts
-}
-
 // topK limits the number of tokens considered to the k highest logits
 func topK(ts []token, k int) []token {
 	if k >= len(ts) || k <= 0 {
@@ -134,3 +109,59 @@ func minP(ts []token, p float32) []token {
 	ts = validTokens
 	return ts
 }
+
+func temperature(ts []token, temp float32) {
+	for i := range ts {
+		ts[i].value /= temp
+	}
+}
+
+func softmax(ts []token) {
+	if len(ts) == 0 {
+		return
+	}
+
+	// Find max logit for numerical stability
+	maxLogit := ts[0].value
+	for _, t := range ts {
+		if t.value > maxLogit {
+			maxLogit = t.value
+		}
+	}
+
+	// Compute exp(logit - maxLogit) and sum them
+	var sumExp float32
+	for i, t := range ts {
+		expVal := float32(math.Exp(float64(t.value - maxLogit)))
+		ts[i].value = expVal
+		sumExp += expVal
+	}
+
+	// Normalize probabilities
+	for i := range ts {
+		ts[i].value /= sumExp
+	}
+}
+
+// applyDist selects a token based on probabilities and seed
+func dist(ts []token, seed int64) int {
+	rng := rand.New(rand.NewSource(seed))
+
+	cdf := make([]float32, len(ts))
+	var cumSum float32
+	for i, t := range ts {
+		cumSum += t.value
+		cdf[i] = cumSum
+	}
+
+	r := rng.Float32() * cumSum
+
+	// Select token based on CDF
+	for i, probSum := range cdf {
+		if r < probSum {
+			return i
+		}
+	}
+
+	return len(ts) - 1
+}

+ 97 - 0
sample/transforms_test.go

@@ -1,8 +1,13 @@
 package sample
 
 import (
+	"encoding/binary"
+	"errors"
 	"math"
 	"math/rand/v2"
+	"os"
+	"path/filepath"
+	"runtime"
 	"testing"
 )
 
@@ -143,6 +148,98 @@ func TestSortLogits(t *testing.T) {
 	compareLogits(t, "sortLogits", want, tokens)
 }
 
+// TestSortLogitsWithRealData tests sorting behavior using real model logit distributions
+func TestSortLogitsWithRealData(t *testing.T) {
+	// This will be populated from testdata/logits.bin
+	// Format: 32-bit float array in binary format
+	logits, err := loadTestLogits(t)
+	if err != nil {
+		t.Skipf("Skipping real logit test: %v", err)
+		return
+	}
+
+	tokens := toTokens(logits)
+	sortLogits(tokens)
+
+	// Calculate n for verification
+	n := int(math.Sqrt(float64(len(tokens)))) + 1
+	if n > 1000 {
+		n = 1000
+	} else if n < 100 {
+		n = 100
+	}
+
+	t.Logf("Testing with %d tokens, partial sorting top %d", len(tokens), n)
+
+	// Only verify the top n elements are sorted (which is what we guarantee)
+	// This is much faster than checking the entire array
+	topN := tokens[:n]
+	for i := 1; i < len(topN); i++ {
+		if topN[i].value > topN[i-1].value {
+			t.Fatalf("top %d tokens not properly sorted at index %d: %.15f > %.15f",
+				n, i, topN[i].value, topN[i-1].value)
+		}
+	}
+
+	// Verify we didn't lose any high value tokens by checking that
+	// all tokens after position n are <= the nth token
+	// Do this in chunks to avoid timeouts on large arrays
+	nthValue := tokens[n-1].value
+	const chunkSize = 1000
+
+	for start := n; start < len(tokens); start += chunkSize {
+		end := min(start+chunkSize, len(tokens))
+		for i := start; i < end; i++ {
+			if tokens[i].value > nthValue {
+				t.Fatalf("found higher value token after position %d: tokens[%d].value = %.15f > %.15f",
+					n, i, tokens[i].value, nthValue)
+			}
+		}
+	}
+}
+
+// loadTestLogits loads logit test data from testdata/logits.bin
+func loadTestLogits(t *testing.T) ([]float32, error) {
+	t.Helper()
+
+	_, currFile, _, ok := runtime.Caller(0)
+	if !ok {
+		return nil, errors.New("could not determine test file path")
+	}
+	testDataPath := filepath.Join(filepath.Dir(currFile), "testdata", "logits.bin")
+
+	file, err := os.Open(testDataPath)
+	if err != nil {
+		return nil, err
+	}
+	defer file.Close()
+
+	stat, err := file.Stat()
+	if err != nil {
+		return nil, err
+	}
+
+	numFloats := stat.Size() / 4 // each float32 is 4 bytes
+	if numFloats*4 != stat.Size() {
+		return nil, errors.New("logits.bin has invalid size: not a multiple of 4 bytes")
+	}
+
+	logits := make([]float32, numFloats)
+	for i := range logits {
+		var val uint32
+		if err := binary.Read(file, binary.LittleEndian, &val); err != nil {
+			return nil, err
+		}
+		logits[i] = math.Float32frombits(val)
+	}
+
+	if len(logits) == 0 {
+		return nil, errors.New("logits.bin is empty")
+	}
+
+	return logits, nil
+}
+
 func BenchmarkTransforms(b *testing.B) {
 	// Generate random logits
 	tokens := make([]token, 1<<16)

Beberapa file tidak ditampilkan karena terlalu banyak file yang berubah dalam diff ini