Преглед на файлове

sample: improve ollama engine sampler performance (#9374)

This change brings in various interface cleanups and greatly improves the performance of the sampler.

Tested with llama3.2 on a local machine.
Improves performance from ~70 tokens/s to 135 tokens/s with topK(40) enabled.
Without topK, performance is ~110 tokens/s.
Parth Sareen преди 1 месец
родител
ревизия
0682dae027
променени са 7 файла, в които са добавени 548 реда и са изтрити 307 реда
  1. 1 1
      go.mod
  2. 9 1
      runner/ollamarunner/runner.go
  3. 91 65
      sample/samplers.go
  4. 104 0
      sample/samplers_benchmark_test.go
  5. 35 120
      sample/samplers_test.go
  6. 157 74
      sample/transforms.go
  7. 151 46
      sample/transforms_test.go

+ 1 - 1
go.mod

@@ -25,7 +25,6 @@ require (
 	github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
 	golang.org/x/image v0.22.0
 	golang.org/x/tools v0.30.0
-	gonum.org/v1/gonum v0.15.0
 )
 
 require (
@@ -45,6 +44,7 @@ require (
 	github.com/xtgo/set v1.0.0 // indirect
 	go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect
 	golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
+	gonum.org/v1/gonum v0.15.0 // indirect
 	gorgonia.org/vecf32 v0.9.0 // indirect
 	gorgonia.org/vecf64 v0.9.0 // indirect
 )

+ 9 - 1
runner/ollamarunner/runner.go

@@ -589,11 +589,19 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
+	sampler := sample.NewSampler(
+		req.Temperature,
+		req.TopK,
+		req.TopP,
+		req.MinP,
+		req.Seed,
+	)
+
 	seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
 		numPredict: req.NumPredict,
 		stop:       req.Stop,
 		numKeep:    int32(req.NumKeep),
-		sampler:    sample.Greedy(), // TODO: add support for different samplers when performance is optimized
+		sampler:    sampler,
 		embedding:  false,
 	})
 	if err != nil {

+ 91 - 65
sample/samplers.go

@@ -2,76 +2,103 @@ package sample
 
 import (
 	"errors"
-	"math"
-
-	"golang.org/x/exp/rand"
-	"gonum.org/v1/gonum/stat/sampleuv"
+	"math/rand/v2"
+	"slices"
 )
 
+// Sampler is not thread-safe. Each goroutine should have its own instance
 type Sampler interface {
 	Sample([]float32) (int32, error)
 }
 
+// logit represents information about a single token during sampling
+type logit struct {
+	id    int32   // The token's unique identifier
+	value float32 // The raw logit or probability from the model
+}
+
 type weighted struct {
-	src        rand.Source
-	transforms []Transform
+	rng         *rand.Rand
+	tokens      []logit
+	topK        int
+	topP        float32
+	minP        float32
+	temperature float32
 }
 
-// TODO(parthsareen): remove uv sample dependency https://github.com/ollama/ollama/issues/9279
-func Weighted(seed *uint64, transforms ...Transform) Sampler {
-	var src rand.Source
-	if seed != nil {
-		src = rand.NewSource(*seed)
+func (s *weighted) Sample(logits []float32) (int32, error) {
+	if len(s.tokens) < len(logits) {
+		s.tokens = make([]logit, len(logits))
 	}
-	return weighted{src: src, transforms: transforms}
-}
 
-func (s weighted) Sample(logits []float32) (int32, error) {
-	logits64 := make([]float64, len(logits))
+	tokens := s.tokens[:len(logits)]
+
 	for i, v := range logits {
-		logits64[i] = float64(v)
+		tokens[i].id = int32(i)
+		tokens[i].value = v
 	}
 
-	for _, t := range s.transforms {
-		logits64 = t.Apply(logits64)
+	// Tokens are sorted by logits in TopK or SortTokens
+	if s.topK > 0 {
+		tokens = topK(tokens, s.topK)
+	} else {
+		sortLogits(tokens)
 	}
 
-	logitsCopy := make([]float64, 0, len(logits))
-	indices := make([]int, 0, len(logits))
-	for i, logit := range logits64 {
-		if !math.IsInf(logit, -1) {
-			logitsCopy = append(logitsCopy, logit)
-			indices = append(indices, i)
-		}
+	tokens = temperature(tokens, s.temperature)
+	tokens = softmax(tokens)
+
+	tokens = topP(tokens, s.topP)
+	tokens = minP(tokens, s.minP)
+
+	if len(tokens) == 0 {
+		return -1, errors.New("no valid logits found for weighted sampling")
 	}
 
-	if len(logitsCopy) == 0 {
-		return -1, errors.New("no valid logits found for weighed sampling")
+	var r float32
+	if s.rng != nil {
+		r = s.rng.Float32()
+	} else {
+		r = rand.Float32()
 	}
 
-	probs := softmax(logitsCopy)
-	w := sampleuv.NewWeighted(probs, s.src)
-	if idx, ok := w.Take(); ok {
-		return int32(indices[idx]), nil
+	// Calculate cumulative sum of probabilities
+	var sum float32
+	for i := range tokens {
+		sum += tokens[i].value
+		tokens[i].value = sum
 	}
-	return -1, errors.New("weighted sampler failed, no valid token found")
-}
+	r *= tokens[len(tokens)-1].value
 
-type greedy struct{}
+	idx, _ := slices.BinarySearchFunc(tokens, r, func(token logit, target float32) int {
+		// Compare cumulative probabilities
+		if token.value < target {
+			return -1
+		}
+		// First token that exceeds target
+		return 1
+	})
 
-func Greedy() Sampler {
-	return greedy{}
+	if idx >= len(tokens) {
+		idx = len(tokens) - 1
+	}
+
+	return tokens[idx].id, nil
 }
 
-// Sample returns the index of the maximum value in logits.
+type greedy struct{}
+
+// Greedy sample returns the index of the maximum value in logits.
 func (s greedy) Sample(logits []float32) (int32, error) {
 	if len(logits) == 0 {
 		return -1, errors.New("no logits provided for greedy sampling")
 	}
 
 	maxIdx := 0
-	for i := range logits {
-		if logits[i] > logits[maxIdx] {
+	maxVal := logits[0]
+	for i := 1; i < len(logits); i++ {
+		if logits[i] > maxVal {
+			maxVal = logits[i]
 			maxIdx = i
 		}
 	}
@@ -80,41 +107,40 @@ func (s greedy) Sample(logits []float32) (int32, error) {
 }
 
 // TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
-func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int) (Sampler, error) {
+func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int) Sampler {
 	if temperature == 0 {
-		return Greedy(), nil
+		return &greedy{}
 	}
 
-	if temperature < 0 || temperature > 2 {
-		return nil, errors.New("temperature must be between 0 and 2")
+	var rng *rand.Rand
+	if seed != -1 {
+		// PCG requires two parameters: sequence and stream
+		// Use original seed for sequence
+		sequence := uint64(seed)
+		// Use golden ratio hash to generate statistically independent seeds
+		rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9))
 	}
+	temperature = max(temperature, 1)
 
-	transforms := []Transform{Temperature(temperature)}
-
-	if topK != 0 {
-		if topK <= 0 {
-			return nil, errors.New("topK must be greater than 0")
-		}
-		transforms = append(transforms, TopK(topK))
+	if topP < 0.0 {
+		topP = 0.0
 	}
-
-	if topP != 0 {
-		if topP < 0 || topP >= 1 {
-			return nil, errors.New("topP must be between 0 and 1")
-		}
-		transforms = append(transforms, TopP(topP))
+	if topP >= 1.0 {
+		topP = 1.0
 	}
 
-	if minP != 0 {
-		if minP < 0 || minP >= 1 {
-			return nil, errors.New("minP must be between 0 and 1")
-		}
-		transforms = append(transforms, MinP(minP))
+	if minP < 0.0 {
+		minP = 0.0
+	}
+	if minP >= 1.0 {
+		minP = 1.0
 	}
 
-	if seed >= 0 {
-		seed64 := uint64(seed)
-		return Weighted(&seed64, transforms...), nil
+	return &weighted{
+		rng:         rng,
+		topK:        topK,
+		topP:        topP,
+		minP:        minP,
+		temperature: temperature,
 	}
-	return Weighted(nil, transforms...), nil
 }

+ 104 - 0
sample/samplers_benchmark_test.go

@@ -0,0 +1,104 @@
+package sample
+
+import (
+	"fmt"
+	"math/rand"
+	"testing"
+)
+
+func BenchmarkWeightedSampler(b *testing.B) {
+	sizes := []int{10, 100, 1000, 10000}
+
+	for _, size := range sizes {
+		b.Run(fmt.Sprintf("Size %d", size), func(b *testing.B) {
+			logits := make([]float32, size)
+			for i := range logits {
+				logits[i] = float32(rand.Float64()*10 - 5)
+			}
+
+			sampler := NewSampler(0.8, 0, 0, 0, 42)
+			b.ResetTimer()
+			for b.Loop() {
+				_, err := sampler.Sample(logits)
+				if err != nil {
+					b.Fatalf("Sampling failed: %v", err)
+				}
+			}
+		})
+	}
+
+	configs := []struct {
+		name        string
+		temperature float32
+		topK        int
+		topP        float32
+		minP        float32
+		seed        int
+	}{
+		{"Greedy", 0, -1, 0, 0, -1},
+		{"Temperature", 0.8, -1, 0, 0, -1},
+		{"TopK", 0.8, 50, 0, 0, -1},
+		{"TopP", 0.8, -1, 0.9, 0, -1},
+		{"MinP", 0.8, -1, 0, 0.05, -1},
+		{"WithSeed", 0.8, 50, 0, 0, 42},
+	}
+
+	// Fixed size for common vocab size
+	size := 128000
+	logits := make([]float32, size)
+	for i := range logits {
+		logits[i] = float32(rand.Float64()*10 - 5)
+	}
+
+	for _, tc := range configs {
+		b.Run("Config"+tc.name, func(b *testing.B) {
+			sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed)
+			sampler.Sample(logits)
+
+			b.ResetTimer()
+
+			for b.Loop() {
+				_, err := sampler.Sample(logits)
+				if err != nil {
+					b.Fatalf("Sampling failed: %v", err)
+				}
+			}
+		})
+	}
+
+	// Test with combined transforms separately - topK influences performance greatly
+	b.Run("TransformCombined", func(b *testing.B) {
+		sampler := NewSampler(0.8, 50, 0.9, 0.05, 42)
+		b.ResetTimer()
+
+		for b.Loop() {
+			_, err := sampler.Sample(logits)
+			if err != nil {
+				b.Fatalf("Sampling failed: %v", err)
+			}
+		}
+	})
+}
+
+func BenchmarkGreedySampler(b *testing.B) {
+	sizes := []int{10, 100, 1000, 10000, 100000}
+
+	for _, size := range sizes {
+		b.Run(fmt.Sprintf("Size %d", size), func(b *testing.B) {
+			logits := make([]float32, size)
+			for i := range logits {
+				logits[i] = float32(rand.Float64()*10 - 5)
+			}
+
+			sampler := NewSampler(0, -1, 0, 0, -1)
+			b.ResetTimer()
+
+			for b.Loop() {
+				_, err := sampler.Sample(logits)
+				if err != nil {
+					b.Fatalf("Sampling failed: %v", err)
+				}
+			}
+		})
+	}
+}

+ 35 - 120
sample/samplers_test.go

@@ -1,15 +1,14 @@
 package sample
 
 import (
-	"math"
 	"math/rand/v2"
 	"testing"
-
-	"github.com/google/go-cmp/cmp"
 )
 
 func TestWeighted(t *testing.T) {
-	got, err := Weighted(nil).Sample([]float32{float32(math.Inf(-1)), 2, float32(math.Inf(-1)), float32(math.Inf(-1))})
+	logits := []float32{-10, 3, -10, -10}
+	sampler := NewSampler(0, 0, 0, 0, 0)
+	got, err := sampler.Sample(logits)
 	if err != nil {
 		t.Error(err)
 		return
@@ -19,64 +18,19 @@ func TestWeighted(t *testing.T) {
 		t.Errorf("index mismatch: want %d, got %d", want, got)
 	}
 
-	got, err = Weighted(nil).Sample([]float32{float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1))})
-	if err == nil {
-		t.Error("expected error for no valid tokens, got index", got)
-	}
-
-	seed := uint64(42)
-	got, err = Weighted(&seed).Sample([]float32{1, 2, 3, 4})
+	logits = []float32{-100, -10, 0, 10}
+	sampler = NewSampler(0, 0, 0, 0, 0)
+	got, err = sampler.Sample(logits)
 	if err != nil {
 		t.Error(err)
 		return
 	}
-	// With seed 42, we expect a consistent sample
-	want = int32(3) // This will be deterministic due to the seed
+	want = int32(3) // Should pick highest probability with this r value
 	if want != got {
 		t.Errorf("index mismatch: want %d, got %d", want, got)
 	}
 }
 
-type testTransform struct {
-	id        int
-	callOrder *[]int
-}
-
-func (ts *testTransform) Apply(logits []float64) []float64 {
-	if ts.callOrder != nil {
-		*ts.callOrder = append(*ts.callOrder, ts.id)
-	}
-	return logits
-}
-
-func TestSample(t *testing.T) {
-	input := []float32{1, 2, 3, 4}
-
-	var callOrder []int
-	mock1 := &testTransform{
-		id:        1,
-		callOrder: &callOrder,
-	}
-	mock2 := &testTransform{
-		id:        2,
-		callOrder: &callOrder,
-	}
-	mock3 := &testTransform{
-		id:        3,
-		callOrder: &callOrder,
-	}
-
-	_, err := Weighted(nil, mock1, mock2, mock3).Sample(input)
-	if err != nil {
-		t.Error(err)
-		return
-	}
-	wantOrder := []int{1, 2, 3}
-	if diff := cmp.Diff(wantOrder, callOrder); diff != "" {
-		t.Errorf("call order mismatch (-want +got):\n%s", diff)
-	}
-}
-
 func TestNewSampler(t *testing.T) {
 	tests := []struct {
 		name        string
@@ -85,75 +39,41 @@ func TestNewSampler(t *testing.T) {
 		topP        float32
 		minP        float32
 		seed        int
-		wantErr     bool
+		wantGreedy  bool // Instead of wantErr, check if we get greedy sampler
 	}{
-		{
-			name: "no transforms",
-			// temperature is 0, so greedy should be used
-			wantErr: false,
-		},
 		{
 			name:        "temperature",
 			temperature: 0.5,
-			wantErr:     false,
+			wantGreedy:  false,
 		},
 		{
-			name:        "invalid temperature negative",
-			temperature: -1,
-			wantErr:     true,
-		},
-		{
-			name:        "invalid temperature too high",
-			temperature: 2.1,
-			wantErr:     true,
+			name:        "zero temperature - greedy",
+			temperature: 0,
+			wantGreedy:  true,
 		},
 		{
 			name:        "top k",
+			temperature: 0.1,
 			topK:        10,
-			temperature: 0.8,
-			wantErr:     false,
-		},
-		{
-			name:        "invalid top k negative",
-			topK:        -1,
-			temperature: 0.8,
-			wantErr:     true,
+			wantGreedy:  false,
 		},
 		{
 			name:        "top p",
+			temperature: 0.1,
 			topP:        0.9,
-			temperature: 0.8,
-			wantErr:     false,
-		},
-		{
-			name:        "invalid top p negative",
-			topP:        -0.1,
-			temperature: 0.8,
-			wantErr:     true,
-		},
-		{
-			name:        "invalid top p one",
-			topP:        1.0,
-			temperature: 0.8,
-			wantErr:     true,
+			wantGreedy:  false,
 		},
 		{
 			name:        "min p",
+			temperature: 0.1,
 			minP:        0.2,
-			temperature: 0.8,
-			wantErr:     false,
-		},
-		{
-			name:        "invalid min p negative",
-			minP:        -0.1,
-			temperature: 0.8,
-			wantErr:     true,
+			wantGreedy:  false,
 		},
 		{
-			name:        "invalid min p one",
-			minP:        1.0,
-			temperature: 0.8,
-			wantErr:     true,
+			name:        "seed - weighted",
+			temperature: 0.1,
+			seed:        42,
+			wantGreedy:  false,
 		},
 		{
 			name:        "default values",
@@ -162,16 +82,16 @@ func TestNewSampler(t *testing.T) {
 			topP:        0.9,
 			minP:        0.0,
 			seed:        0,
-			wantErr:     false,
+			wantGreedy:  false,
 		},
 		{
-			name:        "all zeroes",
+			name:        "all zeroes - greedy",
 			temperature: 0.0,
 			topK:        0,
 			topP:        0.0,
 			minP:        0.0,
 			seed:        0,
-			wantErr:     false, // all zeroes means no transforms
+			wantGreedy:  true,
 		},
 		{
 			name:        "all transforms",
@@ -180,33 +100,28 @@ func TestNewSampler(t *testing.T) {
 			topP:        0.95,
 			minP:        0.1,
 			seed:        42,
-			wantErr:     false,
+			wantGreedy:  false,
 		},
 	}
-
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
-			_, err := NewSampler(tt.temperature, tt.topK, tt.topP, tt.minP, tt.seed)
-			if (err != nil) != tt.wantErr {
-				t.Errorf("NewSampler() error = %v, wantErr %v", err, tt.wantErr)
+			sampler := NewSampler(tt.temperature, tt.topK, tt.topP, tt.minP, tt.seed)
+			_, isGreedy := sampler.(*greedy)
+			if isGreedy != tt.wantGreedy {
+				t.Errorf("NewSampler() got greedy = %v, want %v", isGreedy, tt.wantGreedy)
 			}
 		})
 	}
 }
 
 func BenchmarkSample(b *testing.B) {
-	transforms := []Transform{
-		Temperature(0.5),
-		TopK(10),
-		TopP(0.9),
-		MinP(0.2),
-	}
-
+	weighted := NewSampler(0.5, 10, 0.9, 0.2, -1)
 	samplers := map[string]Sampler{
-		"Greedy":   Greedy(),
-		"Weighted": Weighted(nil, transforms...),
+		"Greedy":   NewSampler(0, 0, 0, 0, 0), // Use NewSampler with temp=0 for greedy
+		"Weighted": weighted,
 	}
 
+	// Generate random logits for benchmarking
 	logits := make([]float32, 1<<16)
 	for i := range logits {
 		logits[i] = rand.Float32()
@@ -215,7 +130,7 @@ func BenchmarkSample(b *testing.B) {
 	for name, s := range samplers {
 		b.Run(name, func(b *testing.B) {
 			b.ResetTimer()
-			for range b.N {
+			for b.Loop() {
 				if _, err := s.Sample(logits); err != nil {
 					b.Error(err)
 				}

+ 157 - 74
sample/transforms.go

@@ -1,120 +1,203 @@
 package sample
 
 import (
-	"cmp"
 	"math"
 	"slices"
-
-	pq "github.com/emirpasic/gods/v2/queues/priorityqueue"
 )
 
-type Transform interface {
-	Apply([]float64) []float64
-}
-
-// TODO(parthsareen): potentially cache softmax values
-func softmax(logits []float64) []float64 {
-	var sum float64
-	probs := make([]float64, len(logits))
-	for i, v := range logits {
-		probs[i] = math.Exp(v)
-		sum += probs[i]
+func softmax(ts []logit) []logit {
+	var sum float32
+	for i, v := range ts {
+		ts[i].value = float32(math.Exp(float64(v.value)))
+		sum += ts[i].value
 	}
 
-	for i := range probs {
-		probs[i] /= sum
+	for i := range ts {
+		ts[i].value /= sum
 	}
 
-	return probs
+	return ts
 }
 
-type Temperature float64
+func temperature(ti []logit, t float32) []logit {
+	if t == 1 {
+		return ti
+	}
 
-func (t Temperature) Apply(logits []float64) []float64 {
-	temp := math.Max(float64(t), 1e-7)
+	temp := max(t, 1e-7)
+	maxLogit := float32(math.Inf(-1))
+	for _, token := range ti {
+		if token.value > maxLogit {
+			maxLogit = token.value
+		}
+	}
 
 	// subtracting max logit to avoid under/overflow
-	maxLogit := slices.Max(logits)
-	for i := range logits {
-		logits[i] = (logits[i] - maxLogit) / temp
+	for i := range ti {
+		ti[i].value = (ti[i].value - maxLogit) / temp
 	}
 
-	return logits
+	return ti
 }
 
-type logitMap struct {
-	index int
-	logit float64
+// siftDown maintains a min-heap property by iteratively moving larger elements down the heap.
+//
+// The heap is represented as an array where for any node at index i:
+// - Left child is at index 2i + 1
+// - Right child is at index 2i + 2
+// - Parent is at index (i-1)/2
+//
+// The function compares a node with its children and:
+// 1. Finds the smallest value between the node and its children
+// 2. If the node is not the smallest, swaps it with its smallest child
+// 3. Continues this process down the affected path until the min-heap property is restored
+func siftDown(data []logit, start, end int) {
+	root := start
+	for {
+		child := 2*root + 1
+		if child >= end {
+			break
+		}
+		// Find smaller child (we want min heap)
+		if child+1 < end && data[child+1].value < data[child].value {
+			child++
+		}
+		// Exit if root is already smaller than children
+		if data[root].value <= data[child].value {
+			break
+		}
+		// Swap with smaller child and continue
+		data[root], data[child] = data[child], data[root]
+		root = child
+	}
 }
 
-type TopK int
-
-// TODO(parthsareen): avoid having to check all logits after this transform
-func (k TopK) Apply(logits []float64) []float64 {
-	if int(k) >= len(logits) {
-		return logits
+// topK limits the number of tokens considered to the k highest logits
+func topK(ts []logit, k int) []logit {
+	if k >= len(ts) {
+		return ts
+	}
+	// Heapify + siftDown - O(nlog(k))
+	// Build min-heap of first k elements
+	heap := ts[:k]
+	for i := k/2 - 1; i >= 0; i-- {
+		siftDown(heap, i, k)
 	}
-	q := pq.NewWith(func(a, b logitMap) int {
-		return -cmp.Compare(a.logit, b.logit)
-	})
 
-	for i, logit := range logits {
-		q.Enqueue(logitMap{index: i, logit: logit})
+	// Process remaining elements - if larger than heap root, replace root
+	for i := k; i < len(ts); i++ {
+		if ts[i].value > heap[0].value {
+			heap[0] = ts[i]
+			siftDown(heap, 0, k)
+		}
 	}
 
-	validLogits := make(map[int]float64)
-	for range k {
-		logitMap, _ := q.Dequeue()
-		validLogits[logitMap.index] = logitMap.logit
+	slices.Reverse(heap)
+
+	ts = heap
+	return ts
+}
+
+// topP limits tokens to those with cumulative probability p
+func topP(ts []logit, p float32) []logit {
+	if p == 1.0 {
+		return ts
 	}
 
-	for i := range logits {
-		if _, ok := validLogits[i]; !ok {
-			logits[i] = math.Inf(-1)
+	// Find cutoff index where cumulative sum exceeds p
+	var sum float32
+	for i, t := range ts {
+		sum += t.value
+		if sum > float32(p) {
+			ts = ts[:i+1]
+			return ts
 		}
 	}
 
-	return logits
+	return ts
 }
 
-type TopP float64
+// minP limits tokens to those whose probability is at least p times the highest token probability
+func minP(ts []logit, p float32) []logit {
+	if p == 1.0 {
+		return ts
+	}
 
-func (p TopP) Apply(logits []float64) []float64 {
-	probs := softmax(logits)
-	indices := make([]int, len(probs))
-	for i := range indices {
-		indices[i] = i
+	maxProb := float32(math.Inf(-1))
+	for _, token := range ts {
+		if token.value > maxProb {
+			maxProb = token.value
+		}
 	}
 
-	// sort in descending order
-	slices.SortFunc(indices, func(i, j int) int {
-		return cmp.Compare(probs[j], probs[i])
-	})
+	threshold := maxProb * float32(p)
 
-	var sum float64
-	for i, idx := range indices {
-		sum += probs[idx]
-		if sum > float64(p) {
-			for _, idx := range indices[i+1:] {
-				logits[idx] = math.Inf(-1)
-			}
-			break
+	// Filter tokens in-place
+	validTokens := ts[:0]
+	for i, token := range ts {
+		if token.value >= threshold {
+			validTokens = append(validTokens, ts[i])
 		}
 	}
-	return logits
-}
 
-type MinP float64
+	ts = validTokens
+	return ts
+}
 
-func (p MinP) Apply(logits []float64) []float64 {
-	probs := softmax(logits)
-	threshold := slices.Max(probs) * float64(p)
+// TODO(parthsareen): possibly replace with simpler implementation https://github.com/ollama/ollama/issues/9584
+// Counting sort implementation to sort tokens by logits
+func sortLogits(tokens []logit) {
+	if len(tokens) <= 1 {
+		return
+	}
 
-	for i, prob := range probs {
-		if prob < threshold {
-			logits[i] = math.Inf(-1)
+	// Find max/min in a single pass
+	minLogit, maxLogit := tokens[0].value, tokens[0].value
+	for _, t := range tokens[1:] {
+		if t.value < minLogit {
+			minLogit = t.value
+		} else if t.value > maxLogit {
+			maxLogit = t.value
 		}
 	}
 
-	return logits
+	// Calculate scaling to map to uint32 range
+	logitRange := maxLogit - minLogit
+	if logitRange < 1e-6 {
+		return // All values effectively equal
+	}
+
+	// Count frequencies directly from tokens
+	const maxInt = (1 << 24) - 1 // Use 24 bits for good granularity
+	var counts [256]int          // For first byte
+
+	// First pass: count frequencies
+	for _, t := range tokens {
+		// Map to [0, maxInt] range
+		score := min(uint32((t.value-minLogit)*float32(maxInt)/logitRange), maxInt)
+		counts[score>>16]++
+	}
+
+	// Calculate offsets
+	var offset int
+	for i := range counts {
+		count := counts[i]
+		counts[i] = offset
+		offset += count
+	}
+
+	// Second pass: place elements in correct position
+	output := make([]logit, len(tokens))
+	// Track current positions
+	countsCopy := counts
+
+	for i, t := range tokens {
+		score := min(uint32((t.value-minLogit)*float32(maxInt)/logitRange), maxInt)
+
+		pos := countsCopy[score>>16]
+		countsCopy[score>>16]++
+		output[len(tokens)-1-pos] = tokens[i]
+	}
+
+	copy(tokens, output)
 }

+ 151 - 46
sample/transforms_test.go

@@ -4,77 +4,182 @@ import (
 	"math"
 	"math/rand/v2"
 	"testing"
-
-	"github.com/google/go-cmp/cmp"
 )
 
-func TestTemperature(t *testing.T) {
-	got := Temperature(0.5).Apply([]float64{2, -1, 4, -3, 1, -2, 0})
-	want := []float64{-4, -10, 0, -14, -6, -12, -8}
-	if diff := cmp.Diff(want, got); diff != "" {
-		t.Errorf("logits mismatch (-want +got):\n%s", diff)
+// Helper to convert float64 slice to logit slice
+func toLogits(values []float64) []logit {
+	tokens := make([]logit, len(values))
+	for i, v := range values {
+		tokens[i] = logit{
+			id:    int32(i),
+			value: float32(v),
+		}
+	}
+	return tokens
+}
+
+// Helper to compare logit slices
+func compareLogits(t *testing.T, name string, want []float64, got []logit) {
+	t.Helper()
+	if len(want) != len(got) {
+		t.Errorf("%s: length mismatch: want %d, got %d", name, len(want), len(got))
+		return
+	}
+	for i := range want {
+		if math.Abs(float64(got[i].value)-want[i]) > 1e-6 {
+			t.Errorf("%s: index %d: want %f, got %f", name, i, want[i], got[i].value)
+		}
 	}
 }
 
+func TestTemperature(t *testing.T) {
+	input := []float64{2, -1, 4, -3, 1, -2, 0}
+	want := []float64{-4, -10, 0, -14, -6, -12, -8} // (logit - max logit) / temp
+
+	got := temperature(toLogits(input), 0.5)
+	compareLogits(t, "Temperature", want, got)
+}
+
 func TestSoftmax(t *testing.T) {
-	got := softmax([]float64{-3, -2, -1, 0, 1, 2, 4})
+	input := []float64{-3, -2, -1, 0, 1, 2, 4}
+	got := softmax(toLogits(input))
 
-	want := []float64{0.000751406628089903, 0.0020425349829204676, 0.005552185728064613, 0.015092405572827691, 0.04102541181635154, 0.11151863144543739, 0.8240174238263085}
-	if diff := cmp.Diff(want, got); diff != "" {
-		t.Errorf("probs mismatch (-want +got):\n%s", diff)
+	// Check probabilities sum to 1
+	var sum float32
+	for _, token := range got {
+		sum += token.value
+	}
+	if math.Abs(float64(sum)-1.0) > 1e-6 {
+		t.Errorf("probabilities don't sum to 1: got %f", sum)
 	}
-}
 
-func TestTopK(t *testing.T) {
-	got := TopK(3).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
-	want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 1, 2, 4}
-	if diff := cmp.Diff(want, got); diff != "" {
-		t.Errorf("logits mismatch (-want +got):\n%s", diff)
+	// Check relative ordering is preserved
+	for i := 1; i < len(got); i++ {
+		if got[i].value < got[i-1].value {
+			t.Errorf("probability ordering not preserved at index %d", i)
+		}
 	}
+}
 
-	got = TopK(10).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
+func TestTopK(t *testing.T) {
+	input := []float64{-3, -2, -1, 0, 1, 2, 4}
 
-	want = []float64{-3, -2, -1, 0, 1, 2, 4}
-	if diff := cmp.Diff(want, got); diff != "" {
-		t.Errorf("logits mismatch (-want +got):\n%s", diff)
+	// Test k=3
+	got := topK(toLogits(input), 3)
+	if len(got) != 3 {
+		t.Errorf("topK(3): wrong length: want 3, got %d", len(got))
 	}
+	// Should keep highest 3 values: 4, 2, 1
+	want := []float64{4, 2, 1}
+	compareLogits(t, "topK(3)", want, got)
+
+	// Test k > len
+	got = topK(toLogits(input), 10)
+	compareLogits(t, "topK(10)", input, got)
 }
 
 func TestTopP(t *testing.T) {
-	got := TopP(0.9).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
-	want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 2, 4}
-	if diff := cmp.Diff(want, got); diff != "" {
-		t.Errorf("logits mismatch (-want +got):\n%s", diff)
+	input := []float64{-3, -2, -1, 0, 1, 2, 4}
+	tokens := toLogits(input)
+
+	// First apply temperature and softmax to get probabilities
+	tokens = temperature(tokens, 1)
+	tokens = softmax(tokens)
+	sortLogits(tokens)
+
+	// Then apply topP
+	got := topP(tokens, 0.95)
+
+	// Should keep tokens until cumsum > 0.95
+	if len(got) > 3 {
+		t.Errorf("topP(0.95): kept too many tokens: got %d", len(got))
+		t.Logf("got: %v", got)
 	}
 }
 
 func TestMinP(t *testing.T) {
-	got := MinP(0.2).Apply([]float64{-3, -2, -1, 0, 1, 2, 4, 3})
-	want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 4, 3}
-	if diff := cmp.Diff(want, got); diff != "" {
-		t.Errorf("logits mismatch (-want +got):\n%s", diff)
+	input := []float64{-3, -2, -1, 0, 1, 2, 4, 3}
+	tokens := toLogits(input)
+
+	// First apply temperature and softmax
+	tokens = temperature(tokens, 1)
+	tokens = softmax(tokens)
+
+	// Then apply minP
+	got := minP(tokens, 0.2)
+
+	// Should keep tokens with prob >= 0.2 * max_prob
+	if len(got) > 3 {
+		t.Errorf("minP(0.2): kept too many tokens: got %d", len(got))
 	}
 }
 
-func BenchmarkTransform(b *testing.B) {
-	transforms := map[string]Transform{
-		"Temperature": Temperature(0.5),
-		"TopK":        TopK(10),
-		"TopP":        TopP(0.9),
-		"MinP":        MinP(0.2),
-	}
+func TestSortLogits(t *testing.T) {
+	input := []float64{3, 1, 4, 2, -1, 0, -2}
+	tokens := toLogits(input)
+
+	sortLogits(tokens)
 
-	logits := make([]float64, 1<<16)
-	for i := range logits {
-		logits[i] = rand.Float64()
+	for i := 1; i < len(tokens); i++ {
+		if tokens[i].value > tokens[i-1].value {
+			t.Errorf("sortLogits: tokens not sorted in descending order at index %d: %f > %f",
+				i, tokens[i].value, tokens[i-1].value)
+		}
 	}
 
-	for name, transform := range transforms {
-		b.Run(name, func(b *testing.B) {
-			b.ResetTimer()
-			for range b.N {
-				transform.Apply(logits)
-			}
-		})
+	want := []float64{4, 3, 2, 1, 0, -1, -2}
+	compareLogits(t, "sortLogits", want, tokens)
+}
+
+func BenchmarkTransforms(b *testing.B) {
+	// Generate random logits
+	tokens := make([]logit, 1<<16)
+	for i := range tokens {
+		tokens[i] = logit{
+			id:    int32(i),
+			value: rand.Float32(),
+		}
 	}
+
+	tokensCopy := make([]logit, len(tokens))
+
+	b.Run("Temperature", func(b *testing.B) {
+		b.ResetTimer()
+		for b.Loop() {
+			copy(tokensCopy, tokens)
+			temperature(tokensCopy, 0.5)
+		}
+	})
+
+	b.Run("TopK", func(b *testing.B) {
+		b.ResetTimer()
+		for b.Loop() {
+			copy(tokensCopy, tokens)
+			topK(tokensCopy, 10)
+		}
+	})
+
+	b.Run("TopP", func(b *testing.B) {
+		b.ResetTimer()
+		for b.Loop() {
+			copy(tokensCopy, tokens)
+			topP(tokensCopy, 0.9)
+		}
+	})
+
+	b.Run("MinP", func(b *testing.B) {
+		b.ResetTimer()
+		for b.Loop() {
+			copy(tokensCopy, tokens)
+			minP(tokensCopy, 0.2)
+		}
+	})
+
+	b.Run("SortTokens", func(b *testing.B) {
+		b.ResetTimer()
+		for b.Loop() {
+			copy(tokensCopy, tokens)
+			sortLogits(tokensCopy)
+		}
+	})
 }