Browse Source

sample: remove transforms from greedy sampling (#9377)

Parth Sareen 2 tháng trước cách đây
mục cha
commit
c245b0406f
2 tập tin đã thay đổi với 54 bổ sung86 xóa
  1. 16 35
      sample/samplers.go
  2. 38 51
      sample/samplers_test.go

+ 16 - 35
sample/samplers.go

@@ -54,53 +54,42 @@ func (s weighted) Sample(logits []float32) (int32, error) {
 	if idx, ok := w.Take(); ok {
 		return int32(indices[idx]), nil
 	}
-	return -1, errors.New("weighed sampler failed, no valid token found")
+	return -1, errors.New("weighted sampler failed, no valid token found")
 }
 
-type greedy struct {
-	transforms []Transform
-}
+type greedy struct{}
 
-func Greedy(transforms ...Transform) Sampler {
-	return greedy{transforms: transforms}
+func Greedy() Sampler {
+	return greedy{}
 }
 
+// Sample returns the index of the maximum value in logits.
 func (s greedy) Sample(logits []float32) (int32, error) {
-	logits64 := make([]float64, len(logits))
-	for i, v := range logits {
-		logits64[i] = float64(v)
+	if len(logits) == 0 {
+		return -1, errors.New("no logits provided for greedy sampling")
 	}
 
-	for _, t := range s.transforms {
-		logits64 = t.Apply(logits64)
-	}
-
-	var maxIdx int
-	var maxLogit float64
-	for i, logit := range logits64 {
-		if logit > maxLogit {
-			maxLogit = logit
+	maxIdx := 0
+	for i := range logits {
+		if logits[i] > logits[maxIdx] {
 			maxIdx = i
 		}
 	}
 
-	if maxLogit == math.Inf(-1) {
-		return -1, errors.New("no valid logits found for greedy sampling")
-	}
-
 	return int32(maxIdx), nil
 }
 
 // TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
 func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int) (Sampler, error) {
-	transforms := []Transform{}
+	if temperature == 0 {
+		return Greedy(), nil
+	}
+
 	if temperature < 0 || temperature > 2 {
 		return nil, errors.New("temperature must be between 0 and 2")
 	}
 
-	if temperature != 0 {
-		transforms = append(transforms, Temperature(temperature))
-	}
+	transforms := []Transform{Temperature(temperature)}
 
 	if topK != 0 {
 		if topK <= 0 {
@@ -123,15 +112,7 @@ func NewSampler(temperature float32, topK int, topP float32, minP float32, seed
 		transforms = append(transforms, MinP(minP))
 	}
 
-	if len(transforms) == 0 {
-		return nil, errors.New("at least one transform is required")
-	}
-
-	if temperature == 0 {
-		return Greedy(transforms...), nil
-	}
-
-	if seed != 0 {
+	if seed >= 0 {
 		seed64 := uint64(seed)
 		return Weighted(&seed64, transforms...), nil
 	}

+ 38 - 51
sample/samplers_test.go

@@ -66,32 +66,15 @@ func TestSample(t *testing.T) {
 		callOrder: &callOrder,
 	}
 
-	got, err := Greedy(mock1, mock2, mock3).Sample(input)
+	_, err := Weighted(nil, mock1, mock2, mock3).Sample(input)
 	if err != nil {
 		t.Error(err)
 		return
 	}
-
-	want := int32(3) // Greedy sampler should pick highest logit
-	if want != got {
-		t.Errorf("index mismatch: want %d, got %d", want, got)
-	}
 	wantOrder := []int{1, 2, 3}
 	if diff := cmp.Diff(wantOrder, callOrder); diff != "" {
 		t.Errorf("call order mismatch (-want +got):\n%s", diff)
 	}
-
-	callOrder = nil
-
-	_, err = Weighted(nil, mock1, mock2, mock3).Sample(input)
-	if err != nil {
-		t.Error(err)
-		return
-	}
-	wantOrder = []int{1, 2, 3}
-	if diff := cmp.Diff(wantOrder, callOrder); diff != "" {
-		t.Errorf("call order mismatch (-want +got):\n%s", diff)
-	}
 }
 
 func TestNewSampler(t *testing.T) {
@@ -105,8 +88,9 @@ func TestNewSampler(t *testing.T) {
 		wantErr     bool
 	}{
 		{
-			name:    "no transforms",
-			wantErr: true,
+			name: "no transforms",
+			// temperature is 0, so greedy should be used
+			wantErr: false,
 		},
 		{
 			name:        "temperature",
@@ -124,49 +108,52 @@ func TestNewSampler(t *testing.T) {
 			wantErr:     true,
 		},
 		{
-			name:    "top k",
-			topK:    10,
-			wantErr: false,
-		},
-		{
-			name:    "invalid top k negative",
-			topK:    -1,
-			wantErr: true,
+			name:        "top k",
+			topK:        10,
+			temperature: 0.8,
+			wantErr:     false,
 		},
 		{
-			name:    "top p",
-			topP:    0.9,
-			wantErr: false,
+			name:        "invalid top k negative",
+			topK:        -1,
+			temperature: 0.8,
+			wantErr:     true,
 		},
 		{
-			name:    "invalid top p negative",
-			topP:    -0.1,
-			wantErr: true,
+			name:        "top p",
+			topP:        0.9,
+			temperature: 0.8,
+			wantErr:     false,
 		},
 		{
-			name:    "invalid top p one",
-			topP:    1.0,
-			wantErr: true,
+			name:        "invalid top p negative",
+			topP:        -0.1,
+			temperature: 0.8,
+			wantErr:     true,
 		},
 		{
-			name:    "min p",
-			minP:    0.2,
-			wantErr: false,
+			name:        "invalid top p one",
+			topP:        1.0,
+			temperature: 0.8,
+			wantErr:     true,
 		},
 		{
-			name:    "invalid min p negative",
-			minP:    -0.1,
-			wantErr: true,
+			name:        "min p",
+			minP:        0.2,
+			temperature: 0.8,
+			wantErr:     false,
 		},
 		{
-			name:    "invalid min p one",
-			minP:    1.0,
-			wantErr: true,
+			name:        "invalid min p negative",
+			minP:        -0.1,
+			temperature: 0.8,
+			wantErr:     true,
 		},
 		{
-			name:    "seed",
-			seed:    42,
-			wantErr: true, // seed alone is not valid without other transforms
+			name:        "invalid min p one",
+			minP:        1.0,
+			temperature: 0.8,
+			wantErr:     true,
 		},
 		{
 			name:        "default values",
@@ -184,7 +171,7 @@ func TestNewSampler(t *testing.T) {
 			topP:        0.0,
 			minP:        0.0,
 			seed:        0,
-			wantErr:     true, // all zeroes means no transforms
+			wantErr:     false, // all zeroes means no transforms
 		},
 		{
 			name:        "all transforms",
@@ -216,7 +203,7 @@ func BenchmarkSample(b *testing.B) {
 	}
 
 	samplers := map[string]Sampler{
-		"Greedy":   Greedy(transforms...),
+		"Greedy":   Greedy(),
 		"Weighted": Weighted(nil, transforms...),
 	}