|
@@ -19,6 +19,7 @@ import (
|
|
|
"os/signal"
|
|
|
"path/filepath"
|
|
|
"slices"
|
|
|
+ "sort"
|
|
|
"strings"
|
|
|
"syscall"
|
|
|
"time"
|
|
@@ -299,7 +300,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|
|
Images: images,
|
|
|
Format: req.Format,
|
|
|
Options: opts,
|
|
|
- ReturnLogits: req.ReturnLogits,
|
|
|
+ ReturnLogits: false,
|
|
|
}, func(cr llm.CompletionResponse) {
|
|
|
res := api.GenerateResponse{
|
|
|
Model: req.Model,
|
|
@@ -1554,23 +1555,27 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|
|
Format: req.Format,
|
|
|
Options: opts,
|
|
|
ReturnLogits: true,
|
|
|
- }, func(r llm.CompletionResponse) {
|
|
|
+ }, func(cr llm.CompletionResponse) {
|
|
|
res := api.ChatResponse{
|
|
|
Model: req.Model,
|
|
|
CreatedAt: time.Now().UTC(),
|
|
|
- Message: api.Message{Role: "assistant", Content: r.Content},
|
|
|
- Done: r.Done,
|
|
|
- DoneReason: r.DoneReason,
|
|
|
- Logits: r.Logits,
|
|
|
+ Message: api.Message{Role: "assistant", Content: cr.Content},
|
|
|
+ Done: cr.Done,
|
|
|
+ DoneReason: cr.DoneReason,
|
|
|
+ Logits: []float32{},
|
|
|
Metrics: api.Metrics{
|
|
|
- PromptEvalCount: r.PromptEvalCount,
|
|
|
- PromptEvalDuration: r.PromptEvalDuration,
|
|
|
- EvalCount: r.EvalCount,
|
|
|
- EvalDuration: r.EvalDuration,
|
|
|
+ PromptEvalCount: cr.PromptEvalCount,
|
|
|
+ PromptEvalDuration: cr.PromptEvalDuration,
|
|
|
+ EvalCount: cr.EvalCount,
|
|
|
+ EvalDuration: cr.EvalDuration,
|
|
|
},
|
|
|
}
|
|
|
|
|
|
- if r.Done {
|
|
|
+ topK := int(3)
|
|
|
+ logits := make([]float32, len(cr.Logits))
|
|
|
+ copy(logits, cr.Logits)
|
|
|
+ res.TopLogprobs = getTopKLogProbs(c.Request.Context(), r, logits, topK)
|
|
|
+ if cr.Done {
|
|
|
res.TotalDuration = time.Since(checkpointStart)
|
|
|
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
|
|
}
|
|
@@ -1586,7 +1591,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|
|
// Streaming tool calls:
|
|
|
// If tools are recognized, use a flag to track the sending of a tool downstream
|
|
|
// This ensures that content is cleared from the message on the last chunk sent
|
|
|
- sb.WriteString(r.Content)
|
|
|
+ sb.WriteString(cr.Content)
|
|
|
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
|
|
|
res.Message.ToolCalls = toolCalls
|
|
|
for i := range toolCalls {
|
|
@@ -1599,7 +1604,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- if r.Done {
|
|
|
+ if cr.Done {
|
|
|
// Send any remaining content if no tool calls were detected
|
|
|
if toolCallIndex == 0 {
|
|
|
res.Message.Content = sb.String()
|
|
@@ -1649,6 +1654,48 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|
|
streamResponse(c, ch)
|
|
|
}
|
|
|
|
|
|
+func getTopKLogProbs(ctx context.Context, s llm.LlamaServer, logits []float32, topK int) []api.TokenLogprob {
|
|
|
+ // Calculate softmax denominator first (log sum exp trick for numerical stability)
|
|
|
+ maxLogit := float32(math.Inf(-1))
|
|
|
+ for _, logit := range logits {
|
|
|
+ if logit > maxLogit {
|
|
|
+ maxLogit = logit
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ var sumExp float32
|
|
|
+ for _, logit := range logits {
|
|
|
+ sumExp += float32(math.Exp(float64(logit - maxLogit)))
|
|
|
+ }
|
|
|
+ logSumExp := float32(math.Log(float64(sumExp))) + maxLogit
|
|
|
+
|
|
|
+ // Calculate log probs and track top K
|
|
|
+ logProbs := make([]api.TokenLogprob, len(logits))
|
|
|
+ for i, logit := range logits {
|
|
|
+ text, err := s.Detokenize(ctx, []int{i})
|
|
|
+ if err != nil {
|
|
|
+ slog.Error("detokenize error for logprob", "error", err)
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ logProbs[i] = api.TokenLogprob{
|
|
|
+ Text: text,
|
|
|
+ Logprob: logit - logSumExp,
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // Sort by logprob descending and take top K
|
|
|
+ sort.Slice(logProbs, func(i, j int) bool {
|
|
|
+ return logProbs[i].Logprob > logProbs[j].Logprob
|
|
|
+ })
|
|
|
+
|
|
|
+ if len(logProbs) > topK {
|
|
|
+ logProbs = logProbs[:topK]
|
|
|
+ }
|
|
|
+
|
|
|
+ return logProbs
|
|
|
+}
|
|
|
+
|
|
|
func handleScheduleError(c *gin.Context, name string, err error) {
|
|
|
switch {
|
|
|
case errors.Is(err, errCapabilities), errors.Is(err, errRequired):
|