Pārlūkot izejas kodu

sample: temporarily use grammars for constrained generation in new engine (#9586)

Jeffrey Morgan 1 mēnesi atpakaļ
vecāks
revīzija
e093db92c4

+ 68 - 0
llama/llama.go

@@ -245,6 +245,20 @@ func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
 	return &m, nil
 }
 
+func LoadVocabFromFile(path string) (*Vocab, error) {
+	mp := C.CString(path)
+	defer C.free(unsafe.Pointer(mp))
+	v := Vocab{c: C.llama_load_vocab_from_file(mp)}
+	if v.c == nil {
+		return nil, fmt.Errorf("unable to load vocab: %s", path)
+	}
+	return &v, nil
+}
+
+func FreeVocab(vocab *Vocab) {
+	C.llama_free_vocab(vocab.c)
+}
+
 func FreeModel(model *Model) {
 	C.llama_model_free(model.c)
 }
@@ -293,6 +307,10 @@ func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float
 	return nil
 }
 
+type Vocab struct {
+	c *C.struct_llama_vocab
+}
+
 func (m *Model) Vocab() *C.struct_llama_vocab {
 	return C.llama_model_get_vocab(m.c)
 }
@@ -669,3 +687,53 @@ func SchemaToGrammar(schema []byte) []byte {
 	}
 	return buf[:n]
 }
+
+type Sampler struct {
+	c *C.struct_llama_sampler
+}
+
+func NewGrammarSampler(vocab *Vocab, grammar string) *Sampler {
+	cGrammar := C.CString(grammar)
+	cRoot := C.CString("root")
+	defer C.free(unsafe.Pointer(cGrammar))
+	defer C.free(unsafe.Pointer(cRoot))
+
+	sampler := &Sampler{c: C.llama_sampler_init_grammar(vocab.c, cGrammar, cRoot)}
+
+	return sampler
+}
+
+func (s *Sampler) Accept(token int32) {
+	C.llama_sampler_accept(s.c, C.llama_token(token))
+}
+
+type TokenData struct {
+	Id    int32
+	Logit float32
+}
+
+func (s *Sampler) Apply(tokens []TokenData) {
+	tds := make([]C.struct_llama_token_data, len(tokens))
+	for i, token := range tokens {
+		tds[i] = C.struct_llama_token_data{
+			id:    C.int32_t(token.Id),
+			logit: C.float(token.Logit),
+			p:     C.float(0.0),
+		}
+	}
+	tda := &C.llama_token_data_array{
+		data:     (*C.struct_llama_token_data)(unsafe.Pointer(&tds[0])),
+		size:     C.size_t(len(tokens)),
+		selected: C.int64_t(-1),
+		sorted:   C.bool(false),
+	}
+
+	var pinner runtime.Pinner
+	pinner.Pin(&tds[0])
+	defer pinner.Unpin()
+
+	C.llama_sampler_apply(s.c, tda)
+	for i := range tokens {
+		tokens[i].Logit = float32(tds[i].logit)
+	}
+}

+ 22 - 0
llama/sampling_ext.cpp

@@ -2,6 +2,9 @@
 #include "sampling.h"
 #include "sampling_ext.h"
 #include "json-schema-to-grammar.h"
+#include "llama.h"
+#include "llama-model.h"
+#include "llama-model-loader.h"
 
 struct common_sampler *common_sampler_cinit(const struct llama_model *model, struct common_sampler_cparams *params) {
     try {
@@ -64,3 +67,22 @@ int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len)
         return 0;
     }
 }
+
+struct llama_vocab * llama_load_vocab_from_file(const char * fname) {
+    llama_vocab * vocab = new llama_vocab();
+    try {
+        const auto kv = LLM_KV(LLM_ARCH_UNKNOWN);
+        std::vector<std::string> splits = {};
+        llama_model_loader ml(std::string(fname), splits, false, false, nullptr);
+        vocab->load(ml, kv);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what());
+        return nullptr;
+    }
+
+    return vocab;
+}
+
+void llama_free_vocab(struct llama_vocab * vocab) {
+    delete vocab;
+}

+ 3 - 0
llama/sampling_ext.h

@@ -35,6 +35,9 @@ extern "C"
 
     int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len);
 
+    struct llama_vocab * llama_load_vocab_from_file(const char * fname);
+    void llama_free_vocab(struct llama_vocab * vocab);
+
 #ifdef __cplusplus
 }
 #endif

+ 16 - 21
llm/server.go

@@ -729,29 +729,24 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 	}
 
 	if len(req.Format) > 0 {
-		format := string(req.Format)
-		if format != `null` && format != `""` {
-			if s.textProcessor != nil {
-				// New engine handles this on the backend
-				request["format"] = req.Format
-			} else {
-				// old engine
-				switch format {
-				case `"json"`:
-					request["grammar"] = grammarJSON
-				default:
-					if req.Format[0] != '{' {
-						return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format)
-					}
+		switch string(req.Format) {
+		case `null`, `""`:
+			// Field was set, but "missing" a value. We accept
+			// these as "not set".
+			break
+		case `"json"`:
+			request["grammar"] = grammarJSON
+		default:
+			if req.Format[0] != '{' {
+				return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format)
+			}
 
-					// User provided a JSON schema
-					g := llama.SchemaToGrammar(req.Format)
-					if g == nil {
-						return fmt.Errorf("invalid JSON schema in format")
-					}
-					request["grammar"] = string(g)
-				}
+			// User provided a JSON schema
+			g := llama.SchemaToGrammar(req.Format)
+			if g == nil {
+				return fmt.Errorf("invalid JSON schema in format")
 			}
+			request["grammar"] = string(g)
 		}
 	}
 

+ 19 - 4
runner/ollamarunner/runner.go

@@ -254,6 +254,12 @@ type Server struct {
 	// multimodalHash generates hashes for comparing equality
 	// of non-text data
 	multimodalHash maphash.Hash
+
+	// vocab is a llama.cpp vocab required for gammar-based
+	// constrained generation (json mode, structured outputs)
+	// TODO: this is temporary until Ollama sampling supports
+	// constrained generation
+	vocab *sample.Vocab
 }
 
 func (s *Server) allNil() bool {
@@ -574,18 +580,25 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
+	var grammar *sample.Grammar
+	var err error
+	if req.Grammar != "" {
+		grammar, err = sample.NewGrammar(s.vocab, req.Grammar)
+		if err != nil {
+			http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
+			return
+		}
+	}
+
 	sampler := sample.NewSampler(
 		req.Temperature,
 		req.TopK,
 		req.TopP,
 		req.MinP,
 		req.Seed,
+		grammar,
 	)
 
-	if req.Grammar != "" {
-		panic("grammars are not yet supported")
-	}
-
 	seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
 		numPredict: req.NumPredict,
 		stop:       req.Stop,
@@ -797,6 +810,8 @@ func (s *Server) loadModel(
 		panic(err)
 	}
 
+	s.vocab = sample.NewVocab(mpath)
+
 	// TODO(jessegross): LoRA loading
 	if lpath.String() != "" {
 		panic("loras are not yet implemented")

+ 135 - 54
sample/samplers.go

@@ -2,43 +2,88 @@ package sample
 
 import (
 	"errors"
+	"math"
 	"math/rand/v2"
 	"slices"
-)
+	"sync"
 
-// Sampler is not thread-safe. Each goroutine should have its own instance
-type Sampler interface {
-	Sample([]float32) (int32, error)
-}
+	"github.com/ollama/ollama/llama"
+)
 
-// logit represents information about a single token during sampling
-type logit struct {
+// token represents information about a single token during sampling
+type token struct {
 	id    int32   // The token's unique identifier
 	value float32 // The raw logit or probability from the model
 }
 
-type weighted struct {
+type Sampler struct {
 	rng         *rand.Rand
-	tokens      []logit
 	topK        int
 	topP        float32
 	minP        float32
 	temperature float32
+	grammar     *Grammar
 }
 
-func (s *weighted) Sample(logits []float32) (int32, error) {
-	if len(s.tokens) < len(logits) {
-		s.tokens = make([]logit, len(logits))
+func (s *Sampler) Sample(logits []float32) (int32, error) {
+	tokens := make([]token, len(logits))
+	for i := range logits {
+		tokens[i].id = int32(i)
+		tokens[i].value = logits[i]
 	}
 
-	tokens := s.tokens[:len(logits)]
+	t, err := s.sample(tokens)
+	if err != nil {
+		return -1, err
+	}
 
-	for i, v := range logits {
-		tokens[i].id = int32(i)
-		tokens[i].value = v
+	if s.grammar != nil {
+		// optimization: first check if the max logit is accepted by the grammar
+		// if the max logit is rejected, apply the grammar to all logits (slower)
+		top := []token{t}
+		s.grammar.Apply(top)
+		if !math.IsInf(float64(top[0].value), -1) {
+			s.grammar.Accept(top[0].id)
+			return top[0].id, nil
+		}
+
+		// since .sample has side effects of modifying the tokens
+		// we need to reset them before applying the grammar and
+		// sampling again
+		for i := range logits {
+			tokens[i].id = int32(i)
+			tokens[i].value = logits[i]
+		}
+		s.grammar.Apply(tokens)
+		t, err = s.sample(tokens)
+		if err != nil {
+			return -1, err
+		}
+		s.grammar.Accept(t.id)
+	}
+
+	return t.id, nil
+}
+
+// greedy returns the highest probability token from the tokens
+func greedy(tokens []token) token {
+	max := tokens[0]
+	for i := 1; i < len(tokens); i++ {
+		if tokens[i].value > max.value {
+			max = tokens[i]
+		}
+	}
+
+	return max
+}
+
+// sample returns the highest probability token from the tokens
+// given sampler parameters. It also has side effects of modifying the tokens
+func (s *Sampler) sample(tokens []token) (token, error) {
+	if s.temperature == 0 {
+		return greedy(tokens), nil
 	}
 
-	// Tokens are sorted by logits in TopK or SortTokens
 	if s.topK > 0 {
 		tokens = topK(tokens, s.topK)
 	} else {
@@ -47,12 +92,14 @@ func (s *weighted) Sample(logits []float32) (int32, error) {
 
 	tokens = temperature(tokens, s.temperature)
 	tokens = softmax(tokens)
-
 	tokens = topP(tokens, s.topP)
 	tokens = minP(tokens, s.minP)
 
+	// TODO: this should fall back to greedy sampling
+	// or topP, topK values etc should be such that
+	// there are always tokens to sample from
 	if len(tokens) == 0 {
-		return -1, errors.New("no valid logits found for weighted sampling")
+		return token{}, errors.New("no tokens to sample from")
 	}
 
 	var r float32
@@ -70,48 +117,18 @@ func (s *weighted) Sample(logits []float32) (int32, error) {
 	}
 	r *= tokens[len(tokens)-1].value
 
-	idx, _ := slices.BinarySearchFunc(tokens, r, func(token logit, target float32) int {
-		// Compare cumulative probabilities
+	idx, _ := slices.BinarySearchFunc(tokens, r, func(token token, target float32) int {
 		if token.value < target {
 			return -1
 		}
-		// First token that exceeds target
 		return 1
 	})
 
-	if idx >= len(tokens) {
-		idx = len(tokens) - 1
-	}
-
-	return tokens[idx].id, nil
-}
-
-type greedy struct{}
-
-// Greedy sample returns the index of the maximum value in logits.
-func (s greedy) Sample(logits []float32) (int32, error) {
-	if len(logits) == 0 {
-		return -1, errors.New("no logits provided for greedy sampling")
-	}
-
-	maxIdx := 0
-	maxVal := logits[0]
-	for i := 1; i < len(logits); i++ {
-		if logits[i] > maxVal {
-			maxVal = logits[i]
-			maxIdx = i
-		}
-	}
-
-	return int32(maxIdx), nil
+	return tokens[idx], nil
 }
 
 // TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
-func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int) Sampler {
-	if temperature == 0 {
-		return &greedy{}
-	}
-
+func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
 	var rng *rand.Rand
 	if seed != -1 {
 		// PCG requires two parameters: sequence and stream
@@ -120,7 +137,9 @@ func NewSampler(temperature float32, topK int, topP float32, minP float32, seed
 		// Use golden ratio hash to generate statistically independent seeds
 		rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9))
 	}
-	temperature = max(temperature, 1)
+	if temperature < 0.0 {
+		temperature = 0.0
+	}
 
 	if topP < 0.0 {
 		topP = 0.0
@@ -136,11 +155,73 @@ func NewSampler(temperature float32, topK int, topP float32, minP float32, seed
 		minP = 1.0
 	}
 
-	return &weighted{
+	return Sampler{
 		rng:         rng,
 		topK:        topK,
 		topP:        topP,
 		minP:        minP,
 		temperature: temperature,
+		grammar:     grammar,
+	}
+}
+
+type Grammar struct {
+	vocab   *Vocab
+	grammar string
+	sampler *llama.Sampler
+}
+
+func NewGrammar(vocab *Vocab, grammar string) (*Grammar, error) {
+	v, err := vocab.Load()
+	if err != nil {
+		return nil, err
+	}
+
+	return &Grammar{
+		vocab:   vocab,
+		grammar: grammar,
+		sampler: llama.NewGrammarSampler(v, grammar),
+	}, nil
+}
+
+func (g *Grammar) Apply(tokens []token) {
+	tds := make([]llama.TokenData, len(tokens))
+	for i, token := range tokens {
+		tds[i].Id = token.id
+		tds[i].Logit = token.value
 	}
+
+	g.sampler.Apply(tds)
+
+	for i := range tokens {
+		tokens[i].value = tds[i].Logit
+	}
+}
+
+func (g *Grammar) Accept(token int32) {
+	g.sampler.Accept(token)
+}
+
+type Vocab struct {
+	once  sync.Once
+	vocab *llama.Vocab
+	err   error
+	path  string
+}
+
+func NewVocab(path string) *Vocab {
+	return &Vocab{path: path}
+}
+
+// Load returns the lazily-loaded vocabulary
+func (v *Vocab) Load() (*llama.Vocab, error) {
+	v.once.Do(func() {
+		vocab, err := llama.LoadVocabFromFile(v.path)
+		if err != nil {
+			v.err = err
+			return
+		}
+		v.vocab = vocab
+	})
+	return v.vocab, v.err
 }

+ 8 - 20
sample/samplers_benchmark_test.go

@@ -16,13 +16,10 @@ func BenchmarkWeightedSampler(b *testing.B) {
 				logits[i] = float32(rand.Float64()*10 - 5)
 			}
 
-			sampler := NewSampler(0.8, 0, 0, 0, 42)
+			sampler := NewSampler(0.8, 0, 0, 0, 42, nil)
 			b.ResetTimer()
 			for b.Loop() {
-				_, err := sampler.Sample(logits)
-				if err != nil {
-					b.Fatalf("Sampling failed: %v", err)
-				}
+				sampler.Sample(logits)
 			}
 		})
 	}
@@ -52,30 +49,24 @@ func BenchmarkWeightedSampler(b *testing.B) {
 
 	for _, tc := range configs {
 		b.Run("Config"+tc.name, func(b *testing.B) {
-			sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed)
+			sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed, nil)
 			sampler.Sample(logits)
 
 			b.ResetTimer()
 
 			for b.Loop() {
-				_, err := sampler.Sample(logits)
-				if err != nil {
-					b.Fatalf("Sampling failed: %v", err)
-				}
+				sampler.Sample(logits)
 			}
 		})
 	}
 
 	// Test with combined transforms separately - topK influences performance greatly
 	b.Run("TransformCombined", func(b *testing.B) {
-		sampler := NewSampler(0.8, 50, 0.9, 0.05, 42)
+		sampler := NewSampler(0.8, 50, 0.9, 0.05, 42, nil)
 		b.ResetTimer()
 
 		for b.Loop() {
-			_, err := sampler.Sample(logits)
-			if err != nil {
-				b.Fatalf("Sampling failed: %v", err)
-			}
+			sampler.Sample(logits)
 		}
 	})
 }
@@ -90,14 +81,11 @@ func BenchmarkGreedySampler(b *testing.B) {
 				logits[i] = float32(rand.Float64()*10 - 5)
 			}
 
-			sampler := NewSampler(0, -1, 0, 0, -1)
+			sampler := NewSampler(0, -1, 0, 0, -1, nil)
 			b.ResetTimer()
 
 			for b.Loop() {
-				_, err := sampler.Sample(logits)
-				if err != nil {
-					b.Fatalf("Sampling failed: %v", err)
-				}
+				sampler.Sample(logits)
 			}
 		})
 	}

+ 5 - 89
sample/samplers_test.go

@@ -7,7 +7,7 @@ import (
 
 func TestWeighted(t *testing.T) {
 	logits := []float32{-10, 3, -10, -10}
-	sampler := NewSampler(0, 0, 0, 0, 0)
+	sampler := NewSampler(0, 0, 0, 0, 0, nil)
 	got, err := sampler.Sample(logits)
 	if err != nil {
 		t.Error(err)
@@ -19,7 +19,7 @@ func TestWeighted(t *testing.T) {
 	}
 
 	logits = []float32{-100, -10, 0, 10}
-	sampler = NewSampler(0, 0, 0, 0, 0)
+	sampler = NewSampler(0, 0, 0, 0, 0, nil)
 	got, err = sampler.Sample(logits)
 	if err != nil {
 		t.Error(err)
@@ -31,94 +31,10 @@ func TestWeighted(t *testing.T) {
 	}
 }
 
-func TestNewSampler(t *testing.T) {
-	tests := []struct {
-		name        string
-		temperature float32
-		topK        int
-		topP        float32
-		minP        float32
-		seed        int
-		wantGreedy  bool // Instead of wantErr, check if we get greedy sampler
-	}{
-		{
-			name:        "temperature",
-			temperature: 0.5,
-			wantGreedy:  false,
-		},
-		{
-			name:        "zero temperature - greedy",
-			temperature: 0,
-			wantGreedy:  true,
-		},
-		{
-			name:        "top k",
-			temperature: 0.1,
-			topK:        10,
-			wantGreedy:  false,
-		},
-		{
-			name:        "top p",
-			temperature: 0.1,
-			topP:        0.9,
-			wantGreedy:  false,
-		},
-		{
-			name:        "min p",
-			temperature: 0.1,
-			minP:        0.2,
-			wantGreedy:  false,
-		},
-		{
-			name:        "seed - weighted",
-			temperature: 0.1,
-			seed:        42,
-			wantGreedy:  false,
-		},
-		{
-			name:        "default values",
-			temperature: 0.8,
-			topK:        40,
-			topP:        0.9,
-			minP:        0.0,
-			seed:        0,
-			wantGreedy:  false,
-		},
-		{
-			name:        "all zeroes - greedy",
-			temperature: 0.0,
-			topK:        0,
-			topP:        0.0,
-			minP:        0.0,
-			seed:        0,
-			wantGreedy:  true,
-		},
-		{
-			name:        "all transforms",
-			temperature: 0.8,
-			topK:        50,
-			topP:        0.95,
-			minP:        0.1,
-			seed:        42,
-			wantGreedy:  false,
-		},
-	}
-	for _, tt := range tests {
-		t.Run(tt.name, func(t *testing.T) {
-			sampler := NewSampler(tt.temperature, tt.topK, tt.topP, tt.minP, tt.seed)
-			_, isGreedy := sampler.(*greedy)
-			if isGreedy != tt.wantGreedy {
-				t.Errorf("NewSampler() got greedy = %v, want %v", isGreedy, tt.wantGreedy)
-			}
-		})
-	}
-}
-
 func BenchmarkSample(b *testing.B) {
-	weighted := NewSampler(0.5, 10, 0.9, 0.2, -1)
 	samplers := map[string]Sampler{
-		"Greedy":   NewSampler(0, 0, 0, 0, 0), // Use NewSampler with temp=0 for greedy
-		"Weighted": weighted,
+		"Greedy":   NewSampler(0, 0, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy
+		"Weighted": NewSampler(0.5, 10, 0.9, 0.2, -1, nil),
 	}
 
 	// Generate random logits for benchmarking
@@ -132,7 +48,7 @@ func BenchmarkSample(b *testing.B) {
 			b.ResetTimer()
 			for b.Loop() {
 				if _, err := s.Sample(logits); err != nil {
-					b.Error(err)
+					b.Fatalf("error sampling: %v", err)
 				}
 			}
 		})

+ 8 - 8
sample/transforms.go

@@ -5,7 +5,7 @@ import (
 	"slices"
 )
 
-func softmax(ts []logit) []logit {
+func softmax(ts []token) []token {
 	var sum float32
 	for i, v := range ts {
 		ts[i].value = float32(math.Exp(float64(v.value)))
@@ -19,7 +19,7 @@ func softmax(ts []logit) []logit {
 	return ts
 }
 
-func temperature(ti []logit, t float32) []logit {
+func temperature(ti []token, t float32) []token {
 	if t == 1 {
 		return ti
 	}
@@ -51,7 +51,7 @@ func temperature(ti []logit, t float32) []logit {
 // 1. Finds the smallest value between the node and its children
 // 2. If the node is not the smallest, swaps it with its smallest child
 // 3. Continues this process down the affected path until the min-heap property is restored
-func siftDown(data []logit, start, end int) {
+func siftDown(data []token, start, end int) {
 	root := start
 	for {
 		child := 2*root + 1
@@ -73,7 +73,7 @@ func siftDown(data []logit, start, end int) {
 }
 
 // topK limits the number of tokens considered to the k highest logits
-func topK(ts []logit, k int) []logit {
+func topK(ts []token, k int) []token {
 	if k >= len(ts) {
 		return ts
 	}
@@ -99,7 +99,7 @@ func topK(ts []logit, k int) []logit {
 }
 
 // topP limits tokens to those with cumulative probability p
-func topP(ts []logit, p float32) []logit {
+func topP(ts []token, p float32) []token {
 	if p == 1.0 {
 		return ts
 	}
@@ -118,7 +118,7 @@ func topP(ts []logit, p float32) []logit {
 }
 
 // minP limits tokens to those with cumulative probability p
-func minP(ts []logit, p float32) []logit {
+func minP(ts []token, p float32) []token {
 	if p == 1.0 {
 		return ts
 	}
@@ -146,7 +146,7 @@ func minP(ts []logit, p float32) []logit {
 
 // TODO(parthsareen): possibly replace with simpler implementation https://github.com/ollama/ollama/issues/9584
 // Conting sort implementation to sort tokens by logits
-func sortLogits(tokens []logit) {
+func sortLogits(tokens []token) {
 	if len(tokens) <= 1 {
 		return
 	}
@@ -187,7 +187,7 @@ func sortLogits(tokens []logit) {
 	}
 
 	// Second pass: place elements in correct position
-	output := make([]logit, len(tokens))
+	output := make([]token, len(tokens))
 	// Track current positions
 	countsCopy := counts
 

+ 14 - 14
sample/transforms_test.go

@@ -7,10 +7,10 @@ import (
 )
 
 // Helper to convert float64 slice to logit slice
-func toLogits(values []float64) []logit {
-	tokens := make([]logit, len(values))
+func toTokens(values []float64) []token {
+	tokens := make([]token, len(values))
 	for i, v := range values {
-		tokens[i] = logit{
+		tokens[i] = token{
 			id:    int32(i),
 			value: float32(v),
 		}
@@ -19,7 +19,7 @@ func toLogits(values []float64) []logit {
 }
 
 // Helper to compare logit slices
-func compareLogits(t *testing.T, name string, want []float64, got []logit) {
+func compareLogits(t *testing.T, name string, want []float64, got []token) {
 	t.Helper()
 	if len(want) != len(got) {
 		t.Errorf("%s: length mismatch: want %d, got %d", name, len(want), len(got))
@@ -36,13 +36,13 @@ func TestTemperature(t *testing.T) {
 	input := []float64{2, -1, 4, -3, 1, -2, 0}
 	want := []float64{-4, -10, 0, -14, -6, -12, -8} // (logit - max logit) / temp
 
-	got := temperature(toLogits(input), 0.5)
+	got := temperature(toTokens(input), 0.5)
 	compareLogits(t, "Temperature", want, got)
 }
 
 func TestSoftmax(t *testing.T) {
 	input := []float64{-3, -2, -1, 0, 1, 2, 4}
-	got := softmax(toLogits(input))
+	got := softmax(toTokens(input))
 
 	// Check probabilities sum to 1
 	var sum float32
@@ -65,7 +65,7 @@ func TestTopK(t *testing.T) {
 	input := []float64{-3, -2, -1, 0, 1, 2, 4}
 
 	// Test k=3
-	got := topK(toLogits(input), 3)
+	got := topK(toTokens(input), 3)
 	if len(got) != 3 {
 		t.Errorf("topK(3): wrong length: want 3, got %d", len(got))
 	}
@@ -74,13 +74,13 @@ func TestTopK(t *testing.T) {
 	compareLogits(t, "topK(3)", want, got)
 
 	// Test k > len
-	got = topK(toLogits(input), 10)
+	got = topK(toTokens(input), 10)
 	compareLogits(t, "topK(10)", input, got)
 }
 
 func TestTopP(t *testing.T) {
 	input := []float64{-3, -2, -1, 0, 1, 2, 4}
-	tokens := toLogits(input)
+	tokens := toTokens(input)
 
 	// First apply temperature and softmax to get probabilities
 	tokens = temperature(tokens, 1)
@@ -99,7 +99,7 @@ func TestTopP(t *testing.T) {
 
 func TestMinP(t *testing.T) {
 	input := []float64{-3, -2, -1, 0, 1, 2, 4, 3}
-	tokens := toLogits(input)
+	tokens := toTokens(input)
 
 	// First apply temperature and softmax
 	tokens = temperature(tokens, 1)
@@ -116,7 +116,7 @@ func TestMinP(t *testing.T) {
 
 func TestSortLogits(t *testing.T) {
 	input := []float64{3, 1, 4, 2, -1, 0, -2}
-	tokens := toLogits(input)
+	tokens := toTokens(input)
 
 	sortLogits(tokens)
 
@@ -133,15 +133,15 @@ func TestSortLogits(t *testing.T) {
 
 func BenchmarkTransforms(b *testing.B) {
 	// Generate random logits
-	tokens := make([]logit, 1<<16)
+	tokens := make([]token, 1<<16)
 	for i := range tokens {
-		tokens[i] = logit{
+		tokens[i] = token{
 			id:    int32(i),
 			value: rand.Float32(),
 		}
 	}
 
-	tokensCopy := make([]logit, len(tokens))
+	tokensCopy := make([]token, len(tokens))
 
 	b.Run("Temperature", func(b *testing.B) {
 		b.ResetTimer()