瀏覽代碼

pr feedback

Bruce MacDonald 2 月之前
父節點
當前提交
376ecde481
共有 2 個文件被更改,包括 31 次插入12 次删除
  1. 24 11
      runner/llamarunner/cache.go
  2. 7 1
      runner/llamarunner/runner.go

+ 24 - 11
runner/llamarunner/cache.go

@@ -213,6 +213,16 @@ func (c *InputCache) ShiftDiscard(inputLen int, numKeep int) int {
 	return discard
 }
 
+type ErrReprocessInputs struct {
+	Inputs []input
+	SlotId int
+}
+
+func (e *ErrReprocessInputs) Error() string {
+	return fmt.Sprintf("kv cache shift not supported, inputs need reprocessing (slot: %v, input count: %v)",
+		e.SlotId, len(e.Inputs))
+}
+
 // ShiftCacheSlot frees up space in the KV cache by deleting the oldest half of history
 // and shifting the newest half into that space (saving numKeep inputs at the beginning).
 //
@@ -251,17 +261,20 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
 
 		// Update the slot.Inputs to match what would happen with a shift operation
 		// Keep the first numKeep tokens and the tokens after the discard
-		keepInputs := make([]input, numKeep)
-		copy(keepInputs, slot.Inputs[:numKeep])
-
-		afterDiscardInputs := make([]input, len(slot.Inputs)-(numKeep+discard))
-		copy(afterDiscardInputs, slot.Inputs[numKeep+discard:])
-
-		// Update the inputs to match what would happen after a shift
-		newInputs := make([]input, 0, numKeep+len(afterDiscardInputs))
-		newInputs = append(newInputs, keepInputs...)
-		newInputs = append(newInputs, afterDiscardInputs...)
-		slot.Inputs = newInputs
+		newInputs := make([]input, numKeep+len(slot.Inputs)-(numKeep+discard))
+		copy(newInputs[:numKeep], slot.Inputs[:numKeep])
+		copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:])
+
+		// Update the slot.Inputs to be empty since we've cleared the cache
+		// The transformer will rebuild these as the inputs are processed
+		slot.Inputs = []input{}
+
+		// Return the inputs that need to be reprocessed
+		// The caller will need to prepend these to the sequence's inputs queue
+		return &ErrReprocessInputs{
+			Inputs: newInputs,
+			SlotId: slot.Id,
+		}
 	}
 
 	return nil

+ 7 - 1
runner/llamarunner/runner.go

@@ -388,7 +388,13 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 				if len(seq.pendingInputs) == 0 {
 					err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
 					if err != nil {
-						return err
+						if inr, ok := err.(*ErrReprocessInputs); ok {
+							// Prepend these inputs to the sequence's inputs queue for reprocessing
+							seq.inputs = append(inr.Inputs, seq.inputs...)
+							// Continue processing as normal
+						} else {
+							return err
+						}
 					}
 				} else {
 					break