Browse Source

print logprobs

Bruce MacDonald 2 tháng trước cách đây
mục cha
commit
7d16ec8fe8
2 tập tin đã thay đổi với 95 bổ sung3 xóa
  1. 26 1
      llama/llama.go
  2. 69 2
      llama/runner/runner.go

+ 26 - 1
llama/llama.go

@@ -50,7 +50,7 @@ import (
 	_ "github.com/ollama/ollama/llama/llama.cpp/common"
 	_ "github.com/ollama/ollama/llama/llama.cpp/common"
 	_ "github.com/ollama/ollama/llama/llama.cpp/examples/llava"
 	_ "github.com/ollama/ollama/llama/llama.cpp/examples/llava"
 	_ "github.com/ollama/ollama/llama/llama.cpp/src"
 	_ "github.com/ollama/ollama/llama/llama.cpp/src"
-	"github.com/ollama/ollama/ml/backend/ggml/ggml/src"
+	ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
 )
 )
 
 
 func BackendInit() {
 func BackendInit() {
@@ -220,6 +220,31 @@ func (c *Context) GetEmbeddingsIth(i int) []float32 {
 	return embeddings
 	return embeddings
 }
 }
 
 
+// GetLogits returns the logits from the last decode operation.
+// The returned slice has length equal to the vocabulary size.
+func (c *Context) GetLogits() []float32 {
+	logits := unsafe.Pointer(C.llama_get_logits(c.c))
+	if logits == nil {
+		return nil
+	}
+
+	// Get the number of vocabulary tokens to determine array size
+	vocabSize := c.Model().NumVocab()
+	return unsafe.Slice((*float32)(logits), vocabSize)
+}
+
+func (m *Model) Detokenize(tokens []int) (string, error) {
+	var text string
+	for _, token := range tokens {
+		piece := m.TokenToPiece(token)
+		if piece == "" {
+			return "", fmt.Errorf("failed to convert token %d to piece", token)
+		}
+		text += piece
+	}
+	return text, nil
+}
+
 type ModelParams struct {
 type ModelParams struct {
 	NumGpuLayers int
 	NumGpuLayers int
 	MainGpu      int
 	MainGpu      int

+ 69 - 2
llama/runner/runner.go

@@ -8,12 +8,14 @@ import (
 	"fmt"
 	"fmt"
 	"log"
 	"log"
 	"log/slog"
 	"log/slog"
+	"math"
 	"net"
 	"net"
 	"net/http"
 	"net/http"
 	"os"
 	"os"
 	"path/filepath"
 	"path/filepath"
 	"regexp"
 	"regexp"
 	"runtime"
 	"runtime"
+	"sort"
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
 	"sync"
 	"sync"
@@ -83,6 +85,8 @@ type Sequence struct {
 
 
 	doneReason string
 	doneReason string
 
 
+	logits []float32
+
 	// Metrics
 	// Metrics
 	startProcessingTime time.Time
 	startProcessingTime time.Time
 	startGenerationTime time.Time
 	startGenerationTime time.Time
@@ -274,6 +278,9 @@ func (s *Server) allNil() bool {
 }
 }
 
 
 func flushPending(seq *Sequence) bool {
 func flushPending(seq *Sequence) bool {
+	if len(seq.pendingResponses) == 0 {
+		return true
+	}
 	joined := strings.Join(seq.pendingResponses, "")
 	joined := strings.Join(seq.pendingResponses, "")
 	seq.pendingResponses = []string{}
 	seq.pendingResponses = []string{}
 
 
@@ -287,8 +294,11 @@ func flushPending(seq *Sequence) bool {
 		joined = joined[:len(joined)-1]
 		joined = joined[:len(joined)-1]
 	}
 	}
 
 
-	if len(joined) == 0 {
-		return true
+	// Add logits if requested and available
+	wantLogits := true
+	if wantLogits && seq.logits != nil {
+		// resp.Logits = seq.logits
+		seq.logits = nil
 	}
 	}
 
 
 	select {
 	select {
@@ -350,6 +360,57 @@ func (s *Server) run(ctx context.Context) {
 	}
 	}
 }
 }
 
 
+// TokenData represents probability information for a token
+type TokenData struct {
+	TokenID int
+	Logit   float32
+	Prob    float32
+	LogProb float32
+}
+
+// getTokenProbabilities returns sorted token probabilities for a specific token index
+func (s *Server) getTokenProbabilities(seq *Sequence) []TokenData {
+	// Get logits for the specific token index
+	logits := s.lc.GetLogits()
+	seq.logits = make([]float32, len(logits))
+	copy(seq.logits, logits)
+
+	vocabSize := s.model.NumVocab()
+	probs := make([]TokenData, vocabSize)
+
+	// Initialize token data with logits
+	for i := 0; i < vocabSize; i++ {
+		probs[i] = TokenData{
+			TokenID: i,
+			Logit:   logits[i],
+		}
+	}
+
+	// Sort tokens by logits in descending order
+	sort.Slice(probs, func(i, j int) bool {
+		return probs[i].Logit > probs[j].Logit
+	})
+
+	// Apply softmax
+	maxLogit := probs[0].Logit
+	var sum float32 = 0.0
+
+	for i := range probs {
+		p := float32(math.Exp(float64(probs[i].Logit - maxLogit)))
+		probs[i].Prob = p
+		sum += p
+	}
+
+	// Normalize probabilities and calculate log probs
+	for i := range probs {
+		prob := probs[i].Prob / sum
+		probs[i].Prob = prob
+		probs[i].LogProb = float32(math.Log(float64(prob)))
+	}
+
+	return probs
+}
+
 // TODO (jmorganca): processBatch should be simplified, removing:
 // TODO (jmorganca): processBatch should be simplified, removing:
 // * sampling
 // * sampling
 // * stop token checking
 // * stop token checking
@@ -483,6 +544,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 
 
 		seq.numPredicted++
 		seq.numPredicted++
 
 
+		// TODO: only do this when flag specified
+		probs := s.getTokenProbabilities(seq)
+		for i := range 10 {
+			slog.Debug("top 10 tokens", "token", probs[i].TokenID, "prob", probs[i].Prob, "logit", probs[i].Logit, "piece", s.model.TokenToPiece(probs[i].TokenID))
+		}
+
 		// if it's an end of sequence token, break
 		// if it's an end of sequence token, break
 		if s.model.TokenIsEog(token) {
 		if s.model.TokenIsEog(token) {
 			// TODO (jmorganca): we should send this back
 			// TODO (jmorganca): we should send this back