|
@@ -8,12 +8,14 @@ import (
|
|
|
"fmt"
|
|
|
"log"
|
|
|
"log/slog"
|
|
|
+ "math"
|
|
|
"net"
|
|
|
"net/http"
|
|
|
"os"
|
|
|
"path/filepath"
|
|
|
"regexp"
|
|
|
"runtime"
|
|
|
+ "sort"
|
|
|
"strconv"
|
|
|
"strings"
|
|
|
"sync"
|
|
@@ -59,7 +61,7 @@ type Sequence struct {
|
|
|
crossAttention bool
|
|
|
|
|
|
// channel to send responses over
|
|
|
- responses chan string
|
|
|
+ responses chan CompletionResponse
|
|
|
|
|
|
// channel to stop decoding (such as if the remote connection is closed)
|
|
|
quit chan bool
|
|
@@ -88,6 +90,15 @@ type Sequence struct {
|
|
|
startGenerationTime time.Time
|
|
|
numDecoded int
|
|
|
numPromptInputs int
|
|
|
+
|
|
|
+ // set when the request asked for logits to be returned with the generated content
|
|
|
+ returnLogits bool
|
|
|
+
|
|
|
+ // logits captured in processBatch; flushPending attaches them to a response and then clears them
|
|
|
+ logits []float32
|
|
|
+
|
|
|
+ // buffered channel allocated in NewSequence - NOTE(review): no sender or receiver is visible in this change, confirm it is needed
|
|
|
+ logitsOut chan []float32
|
|
|
}
|
|
|
|
|
|
type NewSequenceParams struct {
|
|
@@ -96,6 +107,7 @@ type NewSequenceParams struct {
|
|
|
numKeep int
|
|
|
samplingParams *llama.SamplingParams
|
|
|
embedding bool
|
|
|
+ returnLogits bool
|
|
|
}
|
|
|
|
|
|
func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
|
|
@@ -149,13 +161,15 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
|
|
|
startProcessingTime: startTime,
|
|
|
numPredict: params.numPredict,
|
|
|
pendingResponses: make([]string, 0),
|
|
|
- responses: make(chan string, 100),
|
|
|
+ responses: make(chan CompletionResponse, 100),
|
|
|
quit: make(chan bool, 1),
|
|
|
embedding: make(chan []float32, 1),
|
|
|
samplingCtx: sc,
|
|
|
embeddingOnly: params.embedding,
|
|
|
stop: params.stop,
|
|
|
numKeep: params.numKeep,
|
|
|
+ returnLogits: params.returnLogits,
|
|
|
+ logitsOut: make(chan []float32, 100),
|
|
|
}, nil
|
|
|
}
|
|
|
|
|
@@ -274,25 +288,36 @@ func (s *Server) allNil() bool {
|
|
|
}
|
|
|
|
|
|
func flushPending(seq *Sequence) bool {
|
|
|
- joined := strings.Join(seq.pendingResponses, "")
|
|
|
- seq.pendingResponses = []string{}
|
|
|
+ if len(seq.pendingResponses) == 0 {
|
|
|
+ return true
|
|
|
+ }
|
|
|
|
|
|
+ content := strings.Join(seq.pendingResponses, "")
|
|
|
// Check if there are any partial UTF-8 characters remaining.
|
|
|
// We already check and queue as we are generating but some may
|
|
|
// still make it here:
|
|
|
// - Sequence is ending, e.g. generation limit has been hit
|
|
|
// - Invalid characters in the middle of a string
|
|
|
// This is a stricter check to ensure we never output invalid Unicode.
|
|
|
- for !utf8.ValidString(joined) {
|
|
|
- joined = joined[:len(joined)-1]
|
|
|
+ for !utf8.ValidString(content) {
|
|
|
+ content = content[:len(content)-1]
|
|
|
}
|
|
|
+ seq.pendingResponses = nil
|
|
|
|
|
|
- if len(joined) == 0 {
|
|
|
- return true
|
|
|
+ resp := CompletionResponse{
|
|
|
+ Content: content,
|
|
|
}
|
|
|
|
|
|
+ // Attach the captured logits only when the request asked for them and
|
|
|
+ // some were actually captured, then clear them so the same slice is not
|
|
|
+ // attached to a later response. The length guard matters: resp.Logits is
|
|
|
+ // nil for ordinary requests, so indexing element 0 outside this branch
|
|
|
+ // panics with an index-out-of-range error.
|
|
|
+ if seq.returnLogits && len(seq.logits) > 0 {
|
|
|
+ slog.Info("returning logits - flushPending", "logits", seq.logits[0])
|
|
|
+ resp.Logits = seq.logits
|
|
|
+ seq.logits = nil
|
|
|
+ }
|
|
|
+
|
|
|
select {
|
|
|
- case seq.responses <- joined:
|
|
|
+ case seq.responses <- resp:
|
|
|
return true
|
|
|
case <-seq.quit:
|
|
|
return false
|
|
@@ -476,7 +501,14 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
|
|
continue
|
|
|
}
|
|
|
|
|
|
- // sample a token
|
|
|
+ // capture logits before sampling when the request asked for them; flushPending attaches them to a response
|
|
|
+ if seq.returnLogits { // copied from NewSequenceParams.returnLogits in NewSequence
|
|
|
+ slog.Info("returning logits")
|
|
|
+ seq.logits = s.lc.GetLogits() // NOTE(review): call takes no batch index, unlike Sample(s.lc, seq.iBatch) below - confirm these are this sequence's logits
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
+ // Then sample token
|
|
|
token := seq.samplingCtx.Sample(s.lc, seq.iBatch)
|
|
|
seq.samplingCtx.Accept(token, true)
|
|
|
piece := s.model.TokenToPiece(token)
|
|
@@ -572,10 +604,11 @@ type ImageData struct {
|
|
|
}
|
|
|
|
|
|
type CompletionRequest struct {
|
|
|
- Prompt string `json:"prompt"`
|
|
|
- Images []ImageData `json:"image_data"`
|
|
|
- Grammar string `json:"grammar"`
|
|
|
- CachePrompt bool `json:"cache_prompt"`
|
|
|
+ Prompt string `json:"prompt"`
|
|
|
+ Images []ImageData `json:"image_data"`
|
|
|
+ Grammar string `json:"grammar"`
|
|
|
+ CachePrompt bool `json:"cache_prompt"`
|
|
|
+ ReturnLogits bool `json:"return_logits"` // forwarded to NewSequenceParams.returnLogits by the completion handler
|
|
|
|
|
|
Options
|
|
|
}
|
|
@@ -588,8 +621,10 @@ type Timings struct {
|
|
|
}
|
|
|
|
|
|
type CompletionResponse struct {
|
|
|
- Content string `json:"content"`
|
|
|
- Stop bool `json:"stop"`
|
|
|
+ Content string `json:"content"`
|
|
|
+ Logits []float32 `json:"logits,omitempty"` // filled by flushPending when the sequence has returnLogits set and logits were captured
|
|
|
+ Tokens []string `json:"tokens,omitempty"` // NOTE(review): nothing visible in this change populates Tokens - confirm it is used
|
|
|
+ Stop bool `json:"stop"`
|
|
|
|
|
|
Model string `json:"model,omitempty"`
|
|
|
Prompt string `json:"prompt,omitempty"`
|
|
@@ -637,12 +672,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|
|
samplingParams.Seed = uint32(req.Seed)
|
|
|
samplingParams.Grammar = req.Grammar
|
|
|
|
|
|
+ slog.Info("completion request", "return_logits", req.ReturnLogits)
|
|
|
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
|
|
numPredict: req.NumPredict,
|
|
|
stop: req.Stop,
|
|
|
numKeep: req.NumKeep,
|
|
|
samplingParams: &samplingParams,
|
|
|
embedding: false,
|
|
|
+ returnLogits: req.ReturnLogits,
|
|
|
})
|
|
|
if err != nil {
|
|
|
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
|
|
@@ -691,10 +728,10 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|
|
close(seq.quit)
|
|
|
return
|
|
|
case content, ok := <-seq.responses:
|
|
|
+ slog.Info("logits in last chan", "content", content.Logits[0])
|
|
|
if ok {
|
|
|
- if err := json.NewEncoder(w).Encode(&CompletionResponse{
|
|
|
- Content: content,
|
|
|
- }); err != nil {
|
|
|
+ slog.Info("content", "content", content.Content)
|
|
|
+ if err := json.NewEncoder(w).Encode(&content); err != nil {
|
|
|
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
|
|
close(seq.quit)
|
|
|
return
|