Bruce MacDonald hace 2 meses
padre
commit
b88489a87e
Se han modificado 2 ficheros con 70 adiciones y 7 borrados
  1. 12 7
      llama/runner/runner.go
  2. 58 0
      llama/runner/runner_test.go

+ 12 - 7
llama/runner/runner.go

@@ -378,13 +378,7 @@ type TokenProbs struct {
 }
 
 // probs returns sorted token probabilities for a specific token index
-func (s *Server) probs(seq *Sequence) []TokenProbs {
-	// Get logits for the specific token index
-	logits := s.lc.GetLogits()
-	seq.logits = make([]float32, len(logits))
-	copy(seq.logits, logits)
-
-	vocabSize := s.model.NumVocab()
+func probs(logits []float32, vocabSize int) []TokenProbs {
 	probs := make([]TokenProbs, vocabSize)
 
 	// Initialize token data with logits
@@ -420,6 +414,17 @@ func (s *Server) probs(seq *Sequence) []TokenProbs {
 	return probs
 }
 
+// probs returns sorted token probabilities for a specific token index
+func (s *Server) probs(seq *Sequence) []TokenProbs {
+	// Get logits for the specific token index
+	logits := s.lc.GetLogits()
+	seq.logits = make([]float32, len(logits))
+	copy(seq.logits, logits)
+
+	vocabSize := s.model.NumVocab()
+	return probs(logits, vocabSize)
+}
+
 // TODO (jmorganca): processBatch should be simplified, removing:
 // * sampling
 // * stop token checking

+ 58 - 0
llama/runner/runner_test.go

@@ -0,0 +1,58 @@
+package runner
+
+import (
+	"math"
+	"testing"
+)
+
+func TestProbs(t *testing.T) {
+	// Input test data
+	logits := []float32{1.0, 2.0, 0.5, -1.0}
+	vocabSize := 4
+	want := []TokenProbs{
+		{TokenID: 1, Logit: 2.0},  // Highest logit
+		{TokenID: 0, Logit: 1.0},  // Second highest
+		{TokenID: 2, Logit: 0.5},  // Third
+		{TokenID: 3, Logit: -1.0}, // Lowest
+	}
+
+	got := probs(logits, vocabSize)
+
+	// Test 1: Check sorting order
+	for i := 0; i < len(got)-1; i++ {
+		if got[i].Logit < got[i+1].Logit {
+			t.Errorf("probs not properly sorted: logit at pos %d (%f) < logit at pos %d (%f)",
+				i, got[i].Logit, i+1, got[i+1].Logit)
+		}
+	}
+
+	// Test 2: Check probability normalization
+	var sum float32
+	for _, p := range got {
+		sum += p.Prob
+	}
+	if math.Abs(float64(sum-1.0)) > 1e-6 {
+		t.Errorf("probabilities do not sum to 1: got %v", sum)
+	}
+
+	// Test 3: Check token IDs match expected order
+	for i, want := range want {
+		if got[i].TokenID != want.TokenID {
+			t.Errorf("wrong token ID at position %d: got %d, want %d",
+				i, got[i].TokenID, want.TokenID)
+		}
+		if got[i].Logit != want.Logit {
+			t.Errorf("wrong logit at position %d: got %f, want %f",
+				i, got[i].Logit, want.Logit)
+		}
+	}
+
+	// Test 4: Check log probs are correctly calculated
+	for i, p := range got {
+		expectedLogProb := float32(math.Log(float64(p.Prob)))
+		if math.Abs(float64(p.LogProb-expectedLogProb)) > 1e-6 {
+			t.Errorf("wrong log prob at position %d: got %f, want %f",
+				i, p.LogProb, expectedLogProb)
+		}
+	}
+}