Forráskód Böngészése

sample: add sampling package for new engine (#8410)

Parth Sareen 2 hónapja
szülő
commit
0b7e1676eb
7 módosított fájl, 599 hozzáadás és 126 törlés
  1. 22 39
      runner/ollamarunner/runner.go
  2. 0 13
      sample/greedy.go
  3. 0 74
      sample/sample.go
  4. 139 0
      sample/samplers.go
  5. 238 0
      sample/samplers_test.go
  6. 120 0
      sample/transforms.go
  7. 80 0
      sample/transforms_test.go

+ 22 - 39
runner/ollamarunner/runner.go

@@ -65,8 +65,8 @@ type Sequence struct {
 	// number of tokens to predict
 	numPredict int
 
-	// set of samplers to run on generated logits
-	samplers []sample.Sampler
+	// sampler with transforms to run on generated logits
+	sampler sample.Sampler
 
 	// channel to send back the embedding if embedding only
 	embedding chan []float32
@@ -93,7 +93,7 @@ type NewSequenceParams struct {
 	numPredict int
 	stop       []string
 	numKeep    int32
-	samplers   []sample.Sampler
+	sampler    sample.Sampler
 	embedding  bool
 }
 
@@ -136,7 +136,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
 		responses:           make(chan string, 100),
 		quit:                make(chan bool, 1),
 		embedding:           make(chan []float32, 1),
-		samplers:            params.samplers,
+		sampler:             params.sampler,
 		embeddingOnly:       params.embedding,
 		stop:                params.stop,
 		numKeep:             params.numKeep,
@@ -393,13 +393,7 @@ func (s *Server) processBatch() error {
 		return fmt.Errorf("failed to decode batch: %w", err)
 	}
 
-	f32s := modelOutput.Floats()
-
-	// TODO(jessegross): This will no longer be necessary once the sampling interface takes f32s
-	logits := make([]float64, len(f32s))
-	for i, f32 := range f32s {
-		logits[i] = float64(f32)
-	}
+	logits := modelOutput.Floats()
 
 	for i, seq := range s.seqs {
 		if seq == nil {
@@ -433,15 +427,13 @@ func (s *Server) processBatch() error {
 		}
 
 		// sample a token
-		vocabSize := len(f32s) / len(options.Outputs)
-		tokens, err := sample.Sample(logits[seq.iBatch*vocabSize:(seq.iBatch+1)*vocabSize], seq.samplers...)
+		vocabSize := len(logits) / len(options.Outputs)
+
+		token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
 		if err != nil {
-			return err
+			return fmt.Errorf("failed to sample token: %w", err)
 		}
 
-		// TODO(jessegross): Sampler will output a single int32 in the future
-		token := int32(tokens[0])
-
 		// if it's an end of sequence token, break
 		if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
 			// TODO (jmorganca): we should send this back
@@ -565,27 +557,6 @@ type CompletionResponse struct {
 	Timings Timings `json:"timings"`
 }
 
-func getSamplers(_ CompletionRequest) []sample.Sampler {
-	// TODO(jessegross): Waiting for sampling code
-
-	/*samplingParams.TopK = req.TopK
-	samplingParams.TopP = req.TopP
-	samplingParams.MinP = req.MinP
-	samplingParams.TypicalP = req.TypicalP
-	samplingParams.Temp = req.Temperature
-	samplingParams.RepeatLastN = req.RepeatLastN
-	samplingParams.PenaltyRepeat = req.RepeatPenalty
-	samplingParams.PenaltyFreq = req.FrequencyPenalty
-	samplingParams.PenaltyPresent = req.PresencePenalty
-	samplingParams.Mirostat = req.Mirostat
-	samplingParams.MirostatTau = req.MirostatTau
-	samplingParams.MirostatEta = req.MirostatEta
-	samplingParams.Seed = uint32(req.Seed)
-	samplingParams.Grammar = req.Grammar*/
-
-	return []sample.Sampler{sample.Greedy()}
-}
-
 func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	var req CompletionRequest
 	req.Options = Options(api.DefaultOptions())
@@ -604,11 +575,23 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
+	sampler, err := sample.NewSampler(
+		req.Temperature,
+		req.TopK,
+		req.TopP,
+		req.MinP,
+		req.Seed,
+	)
+	if err != nil {
+		http.Error(w, fmt.Sprintf("Failed to create sampler: %v", err), http.StatusInternalServerError)
+		return
+	}
+
 	seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
 		numPredict: req.NumPredict,
 		stop:       req.Stop,
 		numKeep:    int32(req.NumKeep),
-		samplers:   getSamplers(req),
+		sampler:    sampler,
 		embedding:  false,
 	})
 	if err != nil {

+ 0 - 13
sample/greedy.go

@@ -1,13 +0,0 @@
-package sample
-
-import "gonum.org/v1/gonum/floats"
-
-type greedy struct{}
-
-func Greedy() Sampler {
-	return greedy{}
-}
-
-func (s greedy) Sample(t []float64) ([]float64, error) {
-	return []float64{float64(floats.MaxIdx(t))}, nil
-}

+ 0 - 74
sample/sample.go

@@ -1,74 +0,0 @@
-package sample
-
-import (
-	"slices"
-
-	"gonum.org/v1/gonum/floats"
-	"gonum.org/v1/gonum/stat/sampleuv"
-)
-
-type Sampler interface {
-	Sample([]float64) ([]float64, error)
-}
-
-type Temperature float64
-
-func (s Temperature) Sample(t []float64) ([]float64, error) {
-	floats.Div(t, slices.Repeat([]float64{float64(s)}, len(t)))
-	return t, nil
-}
-
-type softmax struct{}
-
-func Softmax() Sampler {
-	return softmax{}
-}
-
-func (softmax) Sample(t []float64) ([]float64, error) {
-	return t, nil
-}
-
-type TopK int
-
-func (s TopK) Sample(t []float64) ([]float64, error) {
-	return t, nil
-}
-
-type TopP float32
-
-func (s TopP) Sample(t []float64) ([]float64, error) {
-	return t, nil
-}
-
-type MinP float32
-
-func (s MinP) Sample(t []float64) ([]float64, error) {
-	return t, nil
-}
-
-type weighed struct{}
-
-func Weighed() Sampler {
-	return weighed{}
-}
-
-func (s weighed) Sample(t []float64) ([]float64, error) {
-	w := sampleuv.NewWeighted(t, nil)
-	if v, ok := w.Take(); ok {
-		return []float64{float64(v)}, nil
-	}
-
-	return t, nil
-}
-
-func Sample(floats []float64, samplers ...Sampler) ([]float64, error) {
-	var err error
-	for _, sampler := range samplers {
-		floats, err = sampler.Sample(floats)
-		if err != nil {
-			return nil, err
-		}
-	}
-
-	return floats, nil
-}

+ 139 - 0
sample/samplers.go

@@ -0,0 +1,139 @@
+package sample
+
+import (
+	"errors"
+	"math"
+
+	"golang.org/x/exp/rand"
+	"gonum.org/v1/gonum/stat/sampleuv"
+)
+
+type Sampler interface {
+	Sample([]float32) (int32, error)
+}
+
+type weighted struct {
+	src        rand.Source
+	transforms []Transform
+}
+
+// TODO(parthsareen): remove uv sample dependency https://github.com/ollama/ollama/issues/9279
+func Weighted(seed *uint64, transforms ...Transform) Sampler {
+	var src rand.Source
+	if seed != nil {
+		src = rand.NewSource(*seed)
+	}
+	return weighted{src: src, transforms: transforms}
+}
+
+func (s weighted) Sample(logits []float32) (int32, error) {
+	logits64 := make([]float64, len(logits))
+	for i, v := range logits {
+		logits64[i] = float64(v)
+	}
+
+	for _, t := range s.transforms {
+		logits64 = t.Apply(logits64)
+	}
+
+	logitsCopy := make([]float64, 0, len(logits))
+	indices := make([]int, 0, len(logits))
+	for i, logit := range logits64 {
+		if !math.IsInf(logit, -1) {
+			logitsCopy = append(logitsCopy, logit)
+			indices = append(indices, i)
+		}
+	}
+
+	if len(logitsCopy) == 0 {
+		return -1, errors.New("no valid logits found for weighed sampling")
+	}
+
+	probs := softmax(logitsCopy)
+	w := sampleuv.NewWeighted(probs, s.src)
+	if idx, ok := w.Take(); ok {
+		return int32(indices[idx]), nil
+	}
+	return -1, errors.New("weighed sampler failed, no valid token found")
+}
+
+type greedy struct {
+	transforms []Transform
+}
+
+func Greedy(transforms ...Transform) Sampler {
+	return greedy{transforms: transforms}
+}
+
+func (s greedy) Sample(logits []float32) (int32, error) {
+	logits64 := make([]float64, len(logits))
+	for i, v := range logits {
+		logits64[i] = float64(v)
+	}
+
+	for _, t := range s.transforms {
+		logits64 = t.Apply(logits64)
+	}
+
+	var maxIdx int
+	var maxLogit float64
+	for i, logit := range logits64 {
+		if logit > maxLogit {
+			maxLogit = logit
+			maxIdx = i
+		}
+	}
+
+	if maxLogit == math.Inf(-1) {
+		return -1, errors.New("no valid logits found for greedy sampling")
+	}
+
+	return int32(maxIdx), nil
+}
+
+// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
+func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int) (Sampler, error) {
+	transforms := []Transform{}
+	if temperature < 0 || temperature > 2 {
+		return nil, errors.New("temperature must be between 0 and 2")
+	}
+
+	if temperature != 0 {
+		transforms = append(transforms, Temperature(temperature))
+	}
+
+	if topK != 0 {
+		if topK <= 0 {
+			return nil, errors.New("topK must be greater than 0")
+		}
+		transforms = append(transforms, TopK(topK))
+	}
+
+	if topP != 0 {
+		if topP < 0 || topP >= 1 {
+			return nil, errors.New("topP must be between 0 and 1")
+		}
+		transforms = append(transforms, TopP(topP))
+	}
+
+	if minP != 0 {
+		if minP < 0 || minP >= 1 {
+			return nil, errors.New("minP must be between 0 and 1")
+		}
+		transforms = append(transforms, MinP(minP))
+	}
+
+	if len(transforms) == 0 {
+		return nil, errors.New("at least one transform is required")
+	}
+
+	if temperature == 0 {
+		return Greedy(transforms...), nil
+	}
+
+	if seed != 0 {
+		seed64 := uint64(seed)
+		return Weighted(&seed64, transforms...), nil
+	}
+	return Weighted(nil, transforms...), nil
+}

+ 238 - 0
sample/samplers_test.go

@@ -0,0 +1,238 @@
+package sample
+
+import (
+	"math"
+	"math/rand/v2"
+	"testing"
+
+	"github.com/google/go-cmp/cmp"
+)
+
+func TestWeighted(t *testing.T) {
+	got, err := Weighted(nil).Sample([]float32{float32(math.Inf(-1)), 2, float32(math.Inf(-1)), float32(math.Inf(-1))})
+	if err != nil {
+		t.Error(err)
+		return
+	}
+	want := int32(1)
+	if want != got {
+		t.Errorf("index mismatch: want %d, got %d", want, got)
+	}
+
+	got, err = Weighted(nil).Sample([]float32{float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1))})
+	if err == nil {
+		t.Error("expected error for no valid tokens, got index", got)
+	}
+
+	seed := uint64(42)
+	got, err = Weighted(&seed).Sample([]float32{1, 2, 3, 4})
+	if err != nil {
+		t.Error(err)
+		return
+	}
+	// With seed 42, we expect a consistent sample
+	want = int32(3) // This will be deterministic due to the seed
+	if want != got {
+		t.Errorf("index mismatch: want %d, got %d", want, got)
+	}
+}
+
+type testTransform struct {
+	id        int
+	callOrder *[]int
+}
+
+func (ts *testTransform) Apply(logits []float64) []float64 {
+	if ts.callOrder != nil {
+		*ts.callOrder = append(*ts.callOrder, ts.id)
+	}
+	return logits
+}
+
+func TestSample(t *testing.T) {
+	input := []float32{1, 2, 3, 4}
+
+	var callOrder []int
+	mock1 := &testTransform{
+		id:        1,
+		callOrder: &callOrder,
+	}
+	mock2 := &testTransform{
+		id:        2,
+		callOrder: &callOrder,
+	}
+	mock3 := &testTransform{
+		id:        3,
+		callOrder: &callOrder,
+	}
+
+	got, err := Greedy(mock1, mock2, mock3).Sample(input)
+	if err != nil {
+		t.Error(err)
+		return
+	}
+
+	want := int32(3) // Greedy sampler should pick highest logit
+	if want != got {
+		t.Errorf("index mismatch: want %d, got %d", want, got)
+	}
+	wantOrder := []int{1, 2, 3}
+	if diff := cmp.Diff(wantOrder, callOrder); diff != "" {
+		t.Errorf("call order mismatch (-want +got):\n%s", diff)
+	}
+
+	callOrder = nil
+
+	_, err = Weighted(nil, mock1, mock2, mock3).Sample(input)
+	if err != nil {
+		t.Error(err)
+		return
+	}
+	wantOrder = []int{1, 2, 3}
+	if diff := cmp.Diff(wantOrder, callOrder); diff != "" {
+		t.Errorf("call order mismatch (-want +got):\n%s", diff)
+	}
+}
+
+func TestNewSampler(t *testing.T) {
+	tests := []struct {
+		name        string
+		temperature float32
+		topK        int
+		topP        float32
+		minP        float32
+		seed        int
+		wantErr     bool
+	}{
+		{
+			name:    "no transforms",
+			wantErr: true,
+		},
+		{
+			name:        "temperature",
+			temperature: 0.5,
+			wantErr:     false,
+		},
+		{
+			name:        "invalid temperature negative",
+			temperature: -1,
+			wantErr:     true,
+		},
+		{
+			name:        "invalid temperature too high",
+			temperature: 2.1,
+			wantErr:     true,
+		},
+		{
+			name:    "top k",
+			topK:    10,
+			wantErr: false,
+		},
+		{
+			name:    "invalid top k negative",
+			topK:    -1,
+			wantErr: true,
+		},
+		{
+			name:    "top p",
+			topP:    0.9,
+			wantErr: false,
+		},
+		{
+			name:    "invalid top p negative",
+			topP:    -0.1,
+			wantErr: true,
+		},
+		{
+			name:    "invalid top p one",
+			topP:    1.0,
+			wantErr: true,
+		},
+		{
+			name:    "min p",
+			minP:    0.2,
+			wantErr: false,
+		},
+		{
+			name:    "invalid min p negative",
+			minP:    -0.1,
+			wantErr: true,
+		},
+		{
+			name:    "invalid min p one",
+			minP:    1.0,
+			wantErr: true,
+		},
+		{
+			name:    "seed",
+			seed:    42,
+			wantErr: true, // seed alone is not valid without other transforms
+		},
+		{
+			name:        "default values",
+			temperature: 0.8,
+			topK:        40,
+			topP:        0.9,
+			minP:        0.0,
+			seed:        0,
+			wantErr:     false,
+		},
+		{
+			name:        "all zeroes",
+			temperature: 0.0,
+			topK:        0,
+			topP:        0.0,
+			minP:        0.0,
+			seed:        0,
+			wantErr:     true, // all zeroes means no transforms
+		},
+		{
+			name:        "all transforms",
+			temperature: 0.8,
+			topK:        50,
+			topP:        0.95,
+			minP:        0.1,
+			seed:        42,
+			wantErr:     false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			_, err := NewSampler(tt.temperature, tt.topK, tt.topP, tt.minP, tt.seed)
+			if (err != nil) != tt.wantErr {
+				t.Errorf("NewSampler() error = %v, wantErr %v", err, tt.wantErr)
+			}
+		})
+	}
+}
+
+func BenchmarkSample(b *testing.B) {
+	transforms := []Transform{
+		Temperature(0.5),
+		TopK(10),
+		TopP(0.9),
+		MinP(0.2),
+	}
+
+	samplers := map[string]Sampler{
+		"Greedy":   Greedy(transforms...),
+		"Weighted": Weighted(nil, transforms...),
+	}
+
+	logits := make([]float32, 1<<16)
+	for i := range logits {
+		logits[i] = rand.Float32()
+	}
+
+	for name, s := range samplers {
+		b.Run(name, func(b *testing.B) {
+			b.ResetTimer()
+			for range b.N {
+				if _, err := s.Sample(logits); err != nil {
+					b.Error(err)
+				}
+			}
+		})
+	}
+}

+ 120 - 0
sample/transforms.go

@@ -0,0 +1,120 @@
+package sample
+
+import (
+	"cmp"
+	"math"
+	"slices"
+
+	pq "github.com/emirpasic/gods/v2/queues/priorityqueue"
+)
+
+type Transform interface {
+	Apply([]float64) []float64
+}
+
+// TODO(parthsareen): potentially cache softmax values
+func softmax(logits []float64) []float64 {
+	var sum float64
+	probs := make([]float64, len(logits))
+	for i, v := range logits {
+		probs[i] = math.Exp(v)
+		sum += probs[i]
+	}
+
+	for i := range probs {
+		probs[i] /= sum
+	}
+
+	return probs
+}
+
+type Temperature float64
+
+func (t Temperature) Apply(logits []float64) []float64 {
+	temp := math.Max(float64(t), 1e-7)
+
+	// subtracting max logit to avoid under/overflow
+	maxLogit := slices.Max(logits)
+	for i := range logits {
+		logits[i] = (logits[i] - maxLogit) / temp
+	}
+
+	return logits
+}
+
+type logitMap struct {
+	index int
+	logit float64
+}
+
+type TopK int
+
+// TODO(parthsareen): avoid having to check all logits after this transform
+func (k TopK) Apply(logits []float64) []float64 {
+	if int(k) >= len(logits) {
+		return logits
+	}
+	q := pq.NewWith(func(a, b logitMap) int {
+		return -cmp.Compare(a.logit, b.logit)
+	})
+
+	for i, logit := range logits {
+		q.Enqueue(logitMap{index: i, logit: logit})
+	}
+
+	validLogits := make(map[int]float64)
+	for range k {
+		logitMap, _ := q.Dequeue()
+		validLogits[logitMap.index] = logitMap.logit
+	}
+
+	for i := range logits {
+		if _, ok := validLogits[i]; !ok {
+			logits[i] = math.Inf(-1)
+		}
+	}
+
+	return logits
+}
+
+type TopP float64
+
+func (p TopP) Apply(logits []float64) []float64 {
+	probs := softmax(logits)
+	indices := make([]int, len(probs))
+	for i := range indices {
+		indices[i] = i
+	}
+
+	// sort in descending order
+	slices.SortFunc(indices, func(i, j int) int {
+		return cmp.Compare(probs[j], probs[i])
+	})
+
+	var sum float64
+	for i, idx := range indices {
+		sum += probs[idx]
+		if sum > float64(p) {
+			for _, idx := range indices[i+1:] {
+				logits[idx] = math.Inf(-1)
+			}
+			break
+		}
+	}
+	return logits
+}
+
+type MinP float64
+
+func (p MinP) Apply(logits []float64) []float64 {
+	probs := softmax(logits)
+	threshold := slices.Max(probs) * float64(p)
+
+	for i, prob := range probs {
+		if prob < threshold {
+			logits[i] = math.Inf(-1)
+		}
+	}
+
+	return logits
+}

+ 80 - 0
sample/transforms_test.go

@@ -0,0 +1,80 @@
+package sample
+
+import (
+	"math"
+	"math/rand/v2"
+	"testing"
+
+	"github.com/google/go-cmp/cmp"
+)
+
+func TestTemperature(t *testing.T) {
+	got := Temperature(0.5).Apply([]float64{2, -1, 4, -3, 1, -2, 0})
+	want := []float64{-4, -10, 0, -14, -6, -12, -8}
+	if diff := cmp.Diff(want, got); diff != "" {
+		t.Errorf("logits mismatch (-want +got):\n%s", diff)
+	}
+}
+
+func TestSoftmax(t *testing.T) {
+	got := softmax([]float64{-3, -2, -1, 0, 1, 2, 4})
+
+	want := []float64{0.000751406628089903, 0.0020425349829204676, 0.005552185728064613, 0.015092405572827691, 0.04102541181635154, 0.11151863144543739, 0.8240174238263085}
+	if diff := cmp.Diff(want, got); diff != "" {
+		t.Errorf("probs mismatch (-want +got):\n%s", diff)
+	}
+}
+
+func TestTopK(t *testing.T) {
+	got := TopK(3).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
+	want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 1, 2, 4}
+	if diff := cmp.Diff(want, got); diff != "" {
+		t.Errorf("logits mismatch (-want +got):\n%s", diff)
+	}
+
+	got = TopK(10).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
+
+	want = []float64{-3, -2, -1, 0, 1, 2, 4}
+	if diff := cmp.Diff(want, got); diff != "" {
+		t.Errorf("logits mismatch (-want +got):\n%s", diff)
+	}
+}
+
+func TestTopP(t *testing.T) {
+	got := TopP(0.9).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
+	want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 2, 4}
+	if diff := cmp.Diff(want, got); diff != "" {
+		t.Errorf("logits mismatch (-want +got):\n%s", diff)
+	}
+}
+
+func TestMinP(t *testing.T) {
+	got := MinP(0.2).Apply([]float64{-3, -2, -1, 0, 1, 2, 4, 3})
+	want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 4, 3}
+	if diff := cmp.Diff(want, got); diff != "" {
+		t.Errorf("logits mismatch (-want +got):\n%s", diff)
+	}
+}
+
+func BenchmarkTransform(b *testing.B) {
+	transforms := map[string]Transform{
+		"Temperature": Temperature(0.5),
+		"TopK":        TopK(10),
+		"TopP":        TopP(0.9),
+		"MinP":        MinP(0.2),
+	}
+
+	logits := make([]float64, 1<<16)
+	for i := range logits {
+		logits[i] = rand.Float64()
+	}
+
+	for name, transform := range transforms {
+		b.Run(name, func(b *testing.B) {
+			b.ResetTimer()
+			for range b.N {
+				transform.Apply(logits)
+			}
+		})
+	}
+}