log probs working

ParthSareen 3 months ago
parent
commit
afa2e855d4
5 changed files with 126 additions and 73 deletions
  1. api/client.go (+1 -1)
  2. api/types.go (+7 -10)
  3. llama/llama.go (+12 -0)
  4. llama/runner/runner.go (+46 -49)
  5. server/routes.go (+60 -13)

+ 1 - 1
api/client.go

@@ -129,7 +129,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
 	return nil
 }
 
-const maxBufferSize = 512 * format.KiloByte
+const maxBufferSize = 1024 * format.KiloByte
 
 func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
 	var buf *bytes.Buffer

+ 7 - 10
api/types.go

@@ -189,11 +189,12 @@ func (t *ToolFunction) String() string {
 // ChatResponse is the response returned by [Client.Chat]. Its fields are
 // similar to [GenerateResponse].
 type ChatResponse struct {
-	Model      string    `json:"model"`
-	CreatedAt  time.Time `json:"created_at"`
-	Message    Message   `json:"message"`
-	DoneReason string    `json:"done_reason,omitempty"`
-	Logits     []float32 `json:"logits"`
+	Model       string         `json:"model"`
+	CreatedAt   time.Time      `json:"created_at"`
+	Message     Message        `json:"message"`
+	DoneReason  string         `json:"done_reason,omitempty"`
+	Logits      []float32      `json:"logits"`
+	TopLogprobs []TokenLogprob `json:"top_logprobs"`
 
 	Done bool `json:"done"`
 
@@ -210,14 +211,10 @@ type Metrics struct {
 }
 
 type TokenLogprob struct {
-	Token   string  `json:"token"`
+	Text    string  `json:"text"`
 	Logprob float32 `json:"logprob"`
 }
 
-type LogProbs struct {
-	TopLogprobs []TokenLogprob `json:"top_logprobs"`
-}
-
 // Options specified in [GenerateRequest].  If you add a new option here, also
 // add it to the API docs.
 type Options struct {
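
Not part of the diff, but for orientation: a minimal sketch of reading the new TopLogprobs field from the public Go client's streaming callback. It assumes the existing api.Client and api.ChatRequest types from this repository; the handler body is illustrative only.

import (
	"context"
	"fmt"

	"github.com/ollama/ollama/api"
)

// Sketch: print each streamed chunk's top log probabilities.
func printTopLogprobs(ctx context.Context, client *api.Client, req *api.ChatRequest) error {
	return client.Chat(ctx, req, func(resp api.ChatResponse) error {
		for _, lp := range resp.TopLogprobs {
			// Text is the detokenized candidate, Logprob its log probability.
			fmt.Printf("%q: %.4f\n", lp.Text, lp.Logprob)
		}
		return nil
	})
}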

+ 12 - 0
llama/llama.go

@@ -273,6 +273,18 @@ func (c *Context) GetLogits() []float32 {
 	return unsafe.Slice((*float32)(logits), vocabSize)
 }
 
+func (m *Model) Detokenize(tokens []int) (string, error) {
+	var text string
+	for _, token := range tokens {
+		piece := m.TokenToPiece(token)
+		if piece == "" {
+			return "", fmt.Errorf("failed to convert token %d to piece", token)
+		}
+		text += piece
+	}
+	return text, nil
+}
+
 type ModelParams struct {
 	NumGpuLayers int
 	MainGpu      int
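
Not from the commit, just a usage sketch of the new helper: Detokenize concatenates the TokenToPiece output for each ID and returns an error if any ID yields an empty piece. It assumes an already-loaded *llama.Model; the token IDs below are placeholders, not real vocabulary entries.

import (
	"fmt"
	"log/slog"

	"github.com/ollama/ollama/llama"
)

// Sketch: decode a handful of token IDs back into text.
func decodeTokens(m *llama.Model) {
	text, err := m.Detokenize([]int{101, 2023, 102}) // placeholder IDs
	if err != nil {
		slog.Error("detokenize failed", "error", err)
		return
	}
	fmt.Println(text)
}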

+ 46 - 49
llama/runner/runner.go

@@ -15,7 +15,6 @@ import (
 	"path/filepath"
 	"regexp"
 	"runtime"
-	"sort"
 	"strconv"
 	"strings"
 	"sync"
@@ -503,9 +502,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 		if seq.returnLogits { // New flag we need to add to Sequence struct
 			logits := s.lc.GetLogits()
 			seq.logits = make([]float32, len(logits))
-			slog.Info("copying logits")
 			copy(seq.logits, logits)
-			slog.Info("copying logits success")
 		}
 
 		// Then sample token
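
One note on the copy above: GetLogits (shown in the llama/llama.go hunk earlier) returns an unsafe.Slice view directly over the context's logits buffer, and llama.cpp reuses that buffer on subsequent decodes, so the values must be copied into seq.logits if they are to survive past this batch.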
@@ -608,7 +605,7 @@ type CompletionRequest struct {
 	Images       []ImageData `json:"image_data"`
 	Grammar      string      `json:"grammar"`
 	CachePrompt  bool        `json:"cache_prompt"`
-	ReturnLogits bool        `json:"return_logits"`
+	ReturnLogits bool        `json:"return_logits,omitempty"` // defaults to false
 
 	Options
 }
@@ -729,7 +726,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 			return
 		case content, ok := <-seq.responses:
 			if ok {
-				slog.Info("content", "content", content.Content)
+				// slog.Info("content", "content", content.Content)
 				if err := json.NewEncoder(w).Encode(&content); err != nil {
 					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 					close(seq.quit)
@@ -1040,50 +1037,50 @@ func Execute(args []string) error {
 	return nil
 }
 
-// Helper function to get top K logits and convert to log probabilities
-func getTopLogits(logits []float32, k int, model *llama.Model) []api.LogProbs {
-	if k <= 0 {
-		return nil
-	}
-
-	// Convert logits to probabilities using softmax
-	probs := softmax(logits)
-
-	// Create slice of index/probability pairs
-	pairs := make([]struct {
-		token int
-		prob  float32
-	}, len(probs))
-
-	for i, p := range probs {
-		pairs[i] = struct {
-			token int
-			prob  float32
-		}{i, p}
-	}
-
-	// Sort by probability (descending)
-	sort.Slice(pairs, func(i, j int) bool {
-		return pairs[i].prob > pairs[j].prob
-	})
-
-	// Take top K
-	k = min(k, len(pairs))
-	result := make([]api.LogProbs, k)
-
-	for i := 0; i < k; i++ {
-		result[i] = api.LogProbs{
-			TopLogprobs: []api.TokenLogprob{
-				{
-					Token:   model.TokenToPiece(pairs[i].token),
-					Logprob: float32(math.Log(float64(pairs[i].prob))),
-				},
-			},
-		}
-	}
-
-	return result
-}
+// // Helper function to get top K logits and convert to log probabilities
+// func getTopLogits(logits []float32, k int, model *llama.Model) []api.LogProbs {
+// 	if k <= 0 {
+// 		return nil
+// 	}
+
+// 	// Convert logits to probabilities using softmax
+// 	probs := softmax(logits)
+
+// 	// Create slice of index/probability pairs
+// 	pairs := make([]struct {
+// 		token int
+// 		prob  float32
+// 	}, len(probs))
+
+// 	for i, p := range probs {
+// 		pairs[i] = struct {
+// 			token int
+// 			prob  float32
+// 		}{i, p}
+// 	}
+
+// 	// Sort by probability (descending)
+// 	sort.Slice(pairs, func(i, j int) bool {
+// 		return pairs[i].prob > pairs[j].prob
+// 	})
+
+// 	// Take top K
+// 	k = min(k, len(pairs))
+// 	result := make([]api.LogProbs, k)
+
+// 	for i := 0; i < k; i++ {
+// 		result[i] = api.LogProbs{
+// 			TopLogprobs: []api.TokenLogprob{
+// 				{
+// 					Token:   model.TokenToPiece(pairs[i].token),
+// 					Logprob: float32(math.Log(float64(pairs[i].prob))),
+// 				},
+// 			},
+// 		}
+// 	}
+
+// 	return result
+// }
 
 // Helper function to compute softmax
 func softmax(logits []float32) []float32 {

+ 60 - 13
server/routes.go

@@ -19,6 +19,7 @@ import (
 	"os/signal"
 	"path/filepath"
 	"slices"
+	"sort"
 	"strings"
 	"syscall"
 	"time"
@@ -299,7 +300,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 			Images:       images,
 			Format:       req.Format,
 			Options:      opts,
-			ReturnLogits: req.ReturnLogits,
+			ReturnLogits: false,
 		}, func(cr llm.CompletionResponse) {
 			res := api.GenerateResponse{
 				Model:      req.Model,
@@ -1554,23 +1555,27 @@ func (s *Server) ChatHandler(c *gin.Context) {
 			Format:       req.Format,
 			Options:      opts,
 			ReturnLogits: true,
-		}, func(r llm.CompletionResponse) {
+		}, func(cr llm.CompletionResponse) {
 			res := api.ChatResponse{
 				Model:      req.Model,
 				CreatedAt:  time.Now().UTC(),
-				Message:    api.Message{Role: "assistant", Content: r.Content},
-				Done:       r.Done,
-				DoneReason: r.DoneReason,
-				Logits:     r.Logits,
+				Message:    api.Message{Role: "assistant", Content: cr.Content},
+				Done:       cr.Done,
+				DoneReason: cr.DoneReason,
+				Logits:     []float32{},
 				Metrics: api.Metrics{
-					PromptEvalCount:    r.PromptEvalCount,
-					PromptEvalDuration: r.PromptEvalDuration,
-					EvalCount:          r.EvalCount,
-					EvalDuration:       r.EvalDuration,
+					PromptEvalCount:    cr.PromptEvalCount,
+					PromptEvalDuration: cr.PromptEvalDuration,
+					EvalCount:          cr.EvalCount,
+					EvalDuration:       cr.EvalDuration,
 				},
 			}
 
-			if r.Done {
+			topK := int(3)
+			logits := make([]float32, len(cr.Logits))
+			copy(logits, cr.Logits)
+			res.TopLogprobs = getTopKLogProbs(c.Request.Context(), r, logits, topK)
+			if cr.Done {
 				res.TotalDuration = time.Since(checkpointStart)
 				res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
 			}
@@ -1586,7 +1591,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 			// Streaming tool calls:
 			// If tools are recognized, use a flag to track the sending of a tool downstream
 			// This ensures that content is cleared from the message on the last chunk sent
-			sb.WriteString(r.Content)
+			sb.WriteString(cr.Content)
 			if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
 				res.Message.ToolCalls = toolCalls
 				for i := range toolCalls {
@@ -1599,7 +1604,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 				return
 			}
 
-			if r.Done {
+			if cr.Done {
 				// Send any remaining content if no tool calls were detected
 				if toolCallIndex == 0 {
 					res.Message.Content = sb.String()
@@ -1649,6 +1654,48 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	streamResponse(c, ch)
 }
 
+func getTopKLogProbs(ctx context.Context, s llm.LlamaServer, logits []float32, topK int) []api.TokenLogprob {
+	// Calculate softmax denominator first (log sum exp trick for numerical stability)
+	maxLogit := float32(math.Inf(-1))
+	for _, logit := range logits {
+		if logit > maxLogit {
+			maxLogit = logit
+		}
+	}
+
+	var sumExp float32
+	for _, logit := range logits {
+		sumExp += float32(math.Exp(float64(logit - maxLogit)))
+	}
+	logSumExp := float32(math.Log(float64(sumExp))) + maxLogit
+
+	// Calculate log probs and track top K
+	logProbs := make([]api.TokenLogprob, len(logits))
+	for i, logit := range logits {
+		text, err := s.Detokenize(ctx, []int{i})
+		if err != nil {
+			slog.Error("detokenize error for logprob", "error", err)
+			continue
+		}
+
+		logProbs[i] = api.TokenLogprob{
+			Text:    text,
+			Logprob: logit - logSumExp,
+		}
+	}
+
+	// Sort by logprob descending and take top K
+	sort.Slice(logProbs, func(i, j int) bool {
+		return logProbs[i].Logprob > logProbs[j].Logprob
+	})
+
+	if len(logProbs) > topK {
+		logProbs = logProbs[:topK]
+	}
+
+	return logProbs
+}
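
As a quick sanity check of the normalization above (standalone sketch, not part of the commit): for logits {1, 2, 3}, logSumExp is about 3.4076, the resulting log probs are roughly {-2.41, -1.41, -0.41}, and exponentiating them sums back to 1.

import (
	"fmt"
	"math"
)

// Standalone check that logit - logSumExp yields normalized log probabilities.
func logSumExpCheck() {
	logits := []float64{1, 2, 3}
	maxLogit := logits[0]
	for _, l := range logits {
		if l > maxLogit {
			maxLogit = l
		}
	}
	var sumExp float64
	for _, l := range logits {
		sumExp += math.Exp(l - maxLogit) // subtract max for numerical stability
	}
	logSumExp := math.Log(sumExp) + maxLogit

	var total float64
	for _, l := range logits {
		total += math.Exp(l - logSumExp) // exp(logprob) = probability
	}
	fmt.Printf("sum of probabilities: %.6f\n", total) // ~1.000000
}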
+
 func handleScheduleError(c *gin.Context, name string, err error) {
 	switch {
 	case errors.Is(err, errCapabilities), errors.Is(err, errRequired):