Browse Source

runner.go: Don't add inputs to cache view until actually processed

We need to track which tokens are in the cache ourselves. We currently
add tokens to the cache tracker when we add them to batch but they are
not actually in the cache until we call Decode. This can cause
confusion when we are shifting the cache.

Avoids "could not find a KV slot for the batch" issues.

Bug #7545
Jesse Gross 5 months ago
parent
commit
c3ff916431
2 changed files with 31 additions and 18 deletions
  1. 12 4
      llama/runner/cache.go
  2. 19 14
      llama/runner/runner.go

+ 12 - 4
llama/runner/cache.go

@@ -203,7 +203,11 @@ func countCommonPrefix(a []input, b []input) int {
 // the newest half into that space (saving numKeep inputs at the beginning).
 //
 // Assumes that at least 1 entry can be freed up by shifting (i.e. numKeep < numCtx)
-func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) {
+func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
+	if numKeep >= c.numCtx {
+		return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx)
+	}
+
 	targetFree := (c.numCtx - numKeep) / 2
 	targetFree = max(targetFree, 1)
 
@@ -211,18 +215,22 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) {
 	discard := targetFree - currentFree
 
 	if discard <= 0 {
-		return
+		return nil
 	}
 
-	slog.Debug("context limit hit - shifting", "limit", c.numCtx, "input", len(slot.Inputs),
+	slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs),
 		"keep", numKeep, "discard", discard)
 
 	// TODO (jessegross): KV cache removal can fail for certain types of models
-	c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard)
+	if !c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard) {
+		return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v)", slot.Id, numKeep, discard)
+	}
 	c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, len(slot.Inputs), -discard)
 
 	for i := numKeep + discard; i < len(slot.Inputs); i++ {
 		slot.Inputs[i-discard] = slot.Inputs[i]
 	}
 	slot.Inputs = slot.Inputs[:len(slot.Inputs)-discard]
+
+	return nil
 }

+ 19 - 14
llama/runner/runner.go

@@ -45,6 +45,9 @@ type Sequence struct {
 	// prompt inputs left to evaluate
 	inputs []input
 
+	// inputs that have been added to a batch but not yet submitted to Decode
+	pendingInputs []input
+
 	// tokens that have been generated but not returned yet (e.g. for stop sequences)
 	pendingResponses []string
 
@@ -367,14 +370,13 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 			continue
 		}
 
-		var numInputsProcessed int
-		shifted := false
-
 		for i, input := range seq.inputs {
-			if len(seq.cache.Inputs)+1 > s.cache.numCtx {
-				if !shifted {
-					s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
-					shifted = true
+			if len(seq.cache.Inputs)+len(seq.pendingInputs)+1 > s.cache.numCtx {
+				if len(seq.pendingInputs) == 0 {
+					err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
+					if err != nil {
+						return err
+					}
 				} else {
 					break
 				}
@@ -403,15 +405,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 			}
 
 			crossAttention = seq.crossAttention
-			batch.Add(input.token, input.embed, len(seq.cache.Inputs), i+1 == len(seq.inputs), seq.cache.Id)
-			seq.cache.Inputs = append(seq.cache.Inputs, input)
-			numInputsProcessed++
-		}
-
-		if numInputsProcessed > 0 {
-			seq.inputs = seq.inputs[numInputsProcessed:]
+			batch.Add(input.token, input.embed, len(seq.cache.Inputs)+len(seq.pendingInputs), i+1 == len(seq.inputs), seq.cache.Id)
+			seq.pendingInputs = append(seq.pendingInputs, input)
 			seq.iBatch = batch.NumTokens() - 1
 		}
+
+		seq.inputs = seq.inputs[len(seq.pendingInputs):]
 	}
 
 	if batch == nil || batch.NumTokens() == 0 {
@@ -444,6 +443,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 			continue
 		}
 
+		// After calling Decode, pending inputs are now in the cache
+		if len(seq.pendingInputs) > 0 {
+			seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
+			seq.pendingInputs = []input{}
+		}
+
 		// don't sample prompt processing
 		if len(seq.inputs) != 0 {
 			continue