@@ -50,8 +50,9 @@ type Sequence struct {
 	// inputs that have been added to a batch but not yet submitted to Decode
 	pendingInputs []input
 
+	// TODO: update this comment
 	// tokens that have been generated but not returned yet (e.g. for stop sequences)
-	pendingResponses []string
+	pendingResponses []CompletionResponse
 
 	// input cache being used by this sequence
 	cache *InputCacheSlot
@@ -87,6 +88,9 @@ type Sequence struct {
 
 	logits []float32
 
+	// number of logprobs to return with the completion response
+	logprobs int
+
 	// Metrics
 	startProcessingTime time.Time
 	startGenerationTime time.Time
@@ -152,7 +156,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
 		numPromptInputs:     len(inputs),
 		startProcessingTime: startTime,
 		numPredict:          params.numPredict,
-		pendingResponses:    make([]string, 0),
+		pendingResponses:    make([]CompletionResponse, 0),
 		responses:           make(chan CompletionResponse, 100),
 		quit:                make(chan bool, 1),
 		embedding:           make(chan []float32, 1),
@@ -281,8 +285,11 @@ func flushPending(seq *Sequence) bool {
 	if len(seq.pendingResponses) == 0 {
 		return true
 	}
-	content := strings.Join(seq.pendingResponses, "")
-	seq.pendingResponses = []string{}
+	content := ""
+	for _, resp := range seq.pendingResponses {
+		content += resp.Content
+	}
+	seq.pendingResponses = []CompletionResponse{}
 
 	// Check if there are any partial UTF-8 characters remaining.
 	// We already check and queue as we are generating but some may
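Aside, not part of the diff: the new += loop in flushPending concatenates the same pieces the old strings.Join call did. If the per-iteration allocations ever mattered, the same result could be built with a strings.Builder; a minimal sketch assuming only the seq.pendingResponses slice and Content field shown above:

	// Equivalent concatenation with strings.Builder (illustrative only).
	var sb strings.Builder
	for _, resp := range seq.pendingResponses {
		sb.WriteString(resp.Content)
	}
	content := sb.String()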
@@ -362,27 +369,27 @@ func (s *Server) run(ctx context.Context) {
 	}
 }
 
-// TokenData represents probability information for a token
-type TokenData struct {
+// TokenProbs represents probability information for a token
+type TokenProbs struct {
 	TokenID int
 	Logit   float32
 	Prob    float32
 	LogProb float32
 }
 
-// getTokenProbabilities returns sorted token probabilities for a specific token index
-func (s *Server) getTokenProbabilities(seq *Sequence) []TokenData {
+// probs returns sorted token probabilities for a specific token index
+func (s *Server) probs(seq *Sequence) []TokenProbs {
 	// Get logits for the specific token index
 	logits := s.lc.GetLogits()
 	seq.logits = make([]float32, len(logits))
 	copy(seq.logits, logits)
 
 	vocabSize := s.model.NumVocab()
-	probs := make([]TokenData, vocabSize)
+	probs := make([]TokenProbs, vocabSize)
 
 	// Initialize token data with logits
 	for i := 0; i < vocabSize; i++ {
-		probs[i] = TokenData{
+		probs[i] = TokenProbs{
			TokenID: i,
			Logit:   logits[i],
		}
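The hunk above ends before Prob and LogProb are filled in. A standalone sketch (illustrative only, not taken from the PR) of the likely remaining steps, a numerically stable softmax over the raw logits followed by a descending sort, using only the TokenProbs fields shown in the diff:

	// Illustrative sketch; not part of the diff above.
	package main

	import (
		"fmt"
		"math"
		"sort"
	)

	// TokenProbs mirrors the struct introduced in the diff.
	type TokenProbs struct {
		TokenID int
		Logit   float32
		Prob    float32
		LogProb float32
	}

	// softmaxAndSort fills Prob and LogProb from the raw logits and returns the
	// slice ordered from most to least probable token.
	func softmaxAndSort(probs []TokenProbs) []TokenProbs {
		// Subtract the max logit before exponentiating for numerical stability.
		maxLogit := math.Inf(-1)
		for _, p := range probs {
			if float64(p.Logit) > maxLogit {
				maxLogit = float64(p.Logit)
			}
		}

		sum := 0.0
		exps := make([]float64, len(probs))
		for i, p := range probs {
			exps[i] = math.Exp(float64(p.Logit) - maxLogit)
			sum += exps[i]
		}

		for i := range probs {
			p := exps[i] / sum
			probs[i].Prob = float32(p)
			probs[i].LogProb = float32(math.Log(p))
		}

		// Highest-probability tokens first, matching "returns sorted token probabilities".
		sort.Slice(probs, func(i, j int) bool { return probs[i].Prob > probs[j].Prob })
		return probs
	}

	func main() {
		probs := softmaxAndSort([]TokenProbs{
			{TokenID: 0, Logit: 1.0},
			{TokenID: 1, Logit: 3.0},
			{TokenID: 2, Logit: 2.0},
		})
		for _, p := range probs {
			fmt.Printf("token=%d prob=%.3f logprob=%.3f\n", p.TokenID, p.Prob, p.LogProb)
		}
	}

Sorting all vocabSize entries is O(V log V); if only the top seq.logprobs entries are ever needed, a partial selection of the top k would avoid sorting the full vocabulary.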
@@ -546,10 +553,9 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 
 		seq.numPredicted++
 
-		// TODO: only do this when flag specified
-		probs := s.getTokenProbabilities(seq)
-		for i := range 10 {
-			slog.Debug("top 10 tokens", "token", probs[i].TokenID, "prob", probs[i].Prob, "logit", probs[i].Logit, "piece", s.model.TokenToPiece(probs[i].TokenID))
+		if seq.logprobs > 0 {
+			// TODO: return selected token in logprobs always
+			// probs := s.probs(seq)
 		}
 
 		// if it's an end of sequence token, break
@@ -564,8 +570,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 
 		seq.inputs = []input{{token: token}}
 
-		seq.pendingResponses = append(seq.pendingResponses, piece)
-		sequence := strings.Join(seq.pendingResponses, "")
+		// TODO: add probs here
+		seq.pendingResponses = append(seq.pendingResponses, CompletionResponse{Content: piece})
+		var sequence string
+		for _, r := range seq.pendingResponses {
+			sequence += r.Content
+		}
 
 		if ok, stop := findStop(sequence, seq.stop); ok {
 			slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
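Two TODOs remain open in the hunks above: always including the selected token in the returned logprobs, and attaching the probabilities to the CompletionResponse appended in processBatch. One possible shape for the first, as a hypothetical helper (name and signature are illustrative, reusing the TokenProbs type from the diff and assuming its input is already sorted by probability):

	// topLogprobs is a hypothetical helper, not part of the diff: given the
	// sorted output of s.probs(seq), keep the top n entries and make sure the
	// sampled token is always present, per the "return selected token in
	// logprobs always" TODO above.
	func topLogprobs(sorted []TokenProbs, selected, n int) []TokenProbs {
		if n > len(sorted) {
			n = len(sorted)
		}
		top := append([]TokenProbs(nil), sorted[:n]...)
		for _, p := range top {
			if p.TokenID == selected {
				return top
			}
		}
		// The sampled token fell outside the top n; append its entry so
		// callers can still report its logprob alongside the top candidates.
		for _, p := range sorted[n:] {
			if p.TokenID == selected {
				top = append(top, p)
				break
			}
		}
		return top
	}

How the result would be carried on the pending CompletionResponse is left open by the "add probs here" TODO, so that wiring is not sketched here.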