浏览代码

runner.go: Don't add inputs to cache view until actually processed

We need to track which tokens are in the cache ourselves. We currently
add tokens to the cache tracker when we add them to batch but they are
not actually in the cache until we call Decode. This can cause
confusion when we are shifting the cache.

Avoids "could not find a KV slot for the batch" issues.

Bug #7545
Jesse Gross 5 月之前
父节点
当前提交
c3ff916431
共有 2 个文件被更改,包括 31 次插入18 次删除
  1. 12 4
      llama/runner/cache.go
  2. 19 14
      llama/runner/runner.go

+ 12 - 4
llama/runner/cache.go

@@ -203,7 +203,11 @@ func countCommonPrefix(a []input, b []input) int {
 // the newest half into that space (saving numKeep inputs at the beginning).
 //
 // Assumes that at least 1 entry can be freed up by shifting (i.e. numKeep < numCtx)
-func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) {
+func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
+	if numKeep >= c.numCtx {
+		return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx)
+	}
+
 	targetFree := (c.numCtx - numKeep) / 2
 	targetFree = max(targetFree, 1)
 
@@ -211,18 +215,22 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) {
 	discard := targetFree - currentFree
 
 	if discard <= 0 {
-		return
+		return nil
 	}
 
-	slog.Debug("context limit hit - shifting", "limit", c.numCtx, "input", len(slot.Inputs),
+	slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs),
 		"keep", numKeep, "discard", discard)
 
 	// TODO (jessegross): KV cache removal can fail for certain types of models
-	c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard)
+	if !c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard) {
+		return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v)", slot.Id, numKeep, discard)
+	}
 	c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, len(slot.Inputs), -discard)
 
 	for i := numKeep + discard; i < len(slot.Inputs); i++ {
 		slot.Inputs[i-discard] = slot.Inputs[i]
 	}
 	slot.Inputs = slot.Inputs[:len(slot.Inputs)-discard]
+
+	return nil
 }

+ 19 - 14
llama/runner/runner.go

@@ -45,6 +45,9 @@ type Sequence struct {
 	// prompt inputs left to evaluate
 	inputs []input
 
+	// inputs that have been added to a batch but not yet submitted to Decode
+	pendingInputs []input
+
 	// tokens that have been generated but not returned yet (e.g. for stop sequences)
 	pendingResponses []string
 
@@ -367,14 +370,13 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 			continue
 		}
 
-		var numInputsProcessed int
-		shifted := false
-
 		for i, input := range seq.inputs {
-			if len(seq.cache.Inputs)+1 > s.cache.numCtx {
-				if !shifted {
-					s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
-					shifted = true
+			if len(seq.cache.Inputs)+len(seq.pendingInputs)+1 > s.cache.numCtx {
+				if len(seq.pendingInputs) == 0 {
+					err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
+					if err != nil {
+						return err
+					}
 				} else {
 					break
 				}
@@ -403,15 +405,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 			}
 
 			crossAttention = seq.crossAttention
-			batch.Add(input.token, input.embed, len(seq.cache.Inputs), i+1 == len(seq.inputs), seq.cache.Id)
-			seq.cache.Inputs = append(seq.cache.Inputs, input)
-			numInputsProcessed++
-		}
-
-		if numInputsProcessed > 0 {
-			seq.inputs = seq.inputs[numInputsProcessed:]
+			batch.Add(input.token, input.embed, len(seq.cache.Inputs)+len(seq.pendingInputs), i+1 == len(seq.inputs), seq.cache.Id)
+			seq.pendingInputs = append(seq.pendingInputs, input)
 			seq.iBatch = batch.NumTokens() - 1
 		}
+
+		seq.inputs = seq.inputs[len(seq.pendingInputs):]
 	}
 
 	if batch == nil || batch.NumTokens() == 0 {
@@ -444,6 +443,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 			continue
 		}
 
+		// After calling Decode, pending inputs are now in the cache
+		if len(seq.pendingInputs) > 0 {
+			seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
+			seq.pendingInputs = []input{}
+		}
+
 		// don't sample prompt processing
 		if len(seq.inputs) != 0 {
 			continue