@@ -8,12 +8,14 @@ import (
 	"fmt"
 	"log"
 	"log/slog"
+	"math"
 	"net"
 	"net/http"
 	"os"
 	"path/filepath"
 	"regexp"
 	"runtime"
+	"sort"
 	"strconv"
 	"strings"
 	"sync"
@@ -83,6 +85,8 @@ type Sequence struct {
 
 	doneReason string
 
+	logits []float32
+
 	// Metrics
 	startProcessingTime time.Time
 	startGenerationTime time.Time
@@ -274,6 +278,9 @@ func (s *Server) allNil() bool {
 }
 
 func flushPending(seq *Sequence) bool {
+	if len(seq.pendingResponses) == 0 {
+		return true
+	}
 	joined := strings.Join(seq.pendingResponses, "")
 	seq.pendingResponses = []string{}
 
@@ -287,8 +294,11 @@ func flushPending(seq *Sequence) bool {
 		joined = joined[:len(joined)-1]
 	}
 
-	if len(joined) == 0 {
-		return true
+	// Add logits if requested and available
+	wantLogits := true
+	if wantLogits && seq.logits != nil {
+		// resp.Logits = seq.logits
+		seq.logits = nil
 	}
 
 	select {
@@ -350,6 +360,57 @@ func (s *Server) run(ctx context.Context) {
 	}
 }
 
+// TokenData represents probability information for a token
+type TokenData struct {
+	TokenID int
+	Logit   float32
+	Prob    float32
+	LogProb float32
+}
+
+// getTokenProbabilities returns softmax probabilities over the current logits, sorted in descending order
+func (s *Server) getTokenProbabilities(seq *Sequence) []TokenData {
+	// Copy the current logits so they can be attached to the sequence
+	logits := s.lc.GetLogits()
+	seq.logits = make([]float32, len(logits))
+	copy(seq.logits, logits)
+
+	vocabSize := s.model.NumVocab()
+	probs := make([]TokenData, vocabSize)
+
+	// Initialize token data with logits
+	for i := 0; i < vocabSize; i++ {
+		probs[i] = TokenData{
+			TokenID: i,
+			Logit:   logits[i],
+		}
+	}
+
+	// Sort tokens by logits in descending order
+	sort.Slice(probs, func(i, j int) bool {
+		return probs[i].Logit > probs[j].Logit
+	})
+
+	// Apply softmax
+	maxLogit := probs[0].Logit
+	var sum float32 = 0.0
+
+	for i := range probs {
+		p := float32(math.Exp(float64(probs[i].Logit - maxLogit)))
+		probs[i].Prob = p
+		sum += p
+	}
+
+	// Normalize probabilities and calculate log probs
+	for i := range probs {
+		prob := probs[i].Prob / sum
+		probs[i].Prob = prob
+		probs[i].LogProb = float32(math.Log(float64(prob)))
+	}
+
+	return probs
+}
+
 // TODO (jmorganca): processBatch should be simplified, removing:
 // * sampling
 // * stop token checking
@@ -483,6 +544,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 
 		seq.numPredicted++
 
+		// TODO: only do this when flag specified
+		probs := s.getTokenProbabilities(seq)
+		for i := range 10 {
+			slog.Debug("top 10 tokens", "token", probs[i].TokenID, "prob", probs[i].Prob, "logit", probs[i].Logit, "piece", s.model.TokenToPiece(probs[i].TokenID))
+		}
+
 		// if it's an end of sequence token, break
 		if s.model.TokenIsEog(token) {
 			// TODO (jmorganca): we should send this back
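
Note: getTokenProbabilities above sorts the entire vocabulary (often tens of thousands of entries) for every generated token, even though only the top few entries are logged. The sketch below is illustrative only, not part of this change: it shows how the same numerically stable softmax can be restricted to a top-k selection, computing log-probabilities directly in log space. The topK helper and TokenProb type are hypothetical names; the only assumption is a plain []float32 of logits.

package main

import (
	"fmt"
	"math"
	"sort"
)

// TokenProb is a hypothetical holder for a token id and its probability;
// it mirrors, but is not, the TokenData type introduced in the diff above.
type TokenProb struct {
	ID      int
	Prob    float64
	LogProb float64
}

// topK returns the k highest-probability tokens for a logit vector using a
// numerically stable softmax, selecting the top k instead of sorting the
// whole vocabulary.
func topK(logits []float32, k int) []TokenProb {
	if len(logits) == 0 || k <= 0 {
		return nil
	}
	if k > len(logits) {
		k = len(logits)
	}

	// Max logit and log-sum-exp denominator over all tokens (stable softmax).
	maxLogit := float64(logits[0])
	for _, l := range logits {
		if float64(l) > maxLogit {
			maxLogit = float64(l)
		}
	}
	var sum float64
	for _, l := range logits {
		sum += math.Exp(float64(l) - maxLogit)
	}
	logSum := math.Log(sum)

	// Keep the indices of the k largest logits (linear scan; a heap would
	// avoid the repeated minimum search for large k).
	top := make([]int, 0, k)
	for i := range logits {
		if len(top) < k {
			top = append(top, i)
			continue
		}
		minPos := 0
		for j := 1; j < len(top); j++ {
			if logits[top[j]] < logits[top[minPos]] {
				minPos = j
			}
		}
		if logits[i] > logits[top[minPos]] {
			top[minPos] = i
		}
	}
	sort.Slice(top, func(a, b int) bool { return logits[top[a]] > logits[top[b]] })

	out := make([]TokenProb, len(top))
	for i, id := range top {
		lp := float64(logits[id]) - maxLogit - logSum // log-softmax
		out[i] = TokenProb{ID: id, Prob: math.Exp(lp), LogProb: lp}
	}
	return out
}

func main() {
	logits := []float32{1.2, -0.3, 3.4, 0.9, 3.1}
	for _, t := range topK(logits, 3) {
		fmt.Printf("token %d: prob=%.4f logprob=%.4f\n", t.ID, t.Prob, t.LogProb)
	}
}

Computing LogProb as logit - max - logSumExp keeps the value finite even when the corresponding probability underflows to zero, which the exp-then-log round trip in the diff's normalization loop does not guarantee.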