|
@@ -213,6 +213,16 @@ func (c *InputCache) ShiftDiscard(inputLen int, numKeep int) int {
|
|
|
return discard
|
|
|
}
|
|
|
|
|
|
+type ErrReprocessInputs struct {
|
|
|
+ Inputs []input
|
|
|
+ SlotId int
|
|
|
+}
|
|
|
+
|
|
|
+func (e *ErrReprocessInputs) Error() string {
|
|
|
+ return fmt.Sprintf("kv cache shift not supported, inputs need reprocessing (slot: %v, input count: %v)",
|
|
|
+ e.SlotId, len(e.Inputs))
|
|
|
+}
|
|
|
+
|
|
|
// ShiftCacheSlot frees up space in the KV cache by deleting the oldest half of history
|
|
|
// and shifting the newest half into that space (saving numKeep inputs at the beginning).
|
|
|
//
|
|
@@ -251,17 +261,20 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
|
|
|
|
|
|
// Update the slot.Inputs to match what would happen with a shift operation
|
|
|
// Keep the first numKeep tokens and the tokens after the discard
|
|
|
- keepInputs := make([]input, numKeep)
|
|
|
- copy(keepInputs, slot.Inputs[:numKeep])
|
|
|
-
|
|
|
- afterDiscardInputs := make([]input, len(slot.Inputs)-(numKeep+discard))
|
|
|
- copy(afterDiscardInputs, slot.Inputs[numKeep+discard:])
|
|
|
-
|
|
|
- // Update the inputs to match what would happen after a shift
|
|
|
- newInputs := make([]input, 0, numKeep+len(afterDiscardInputs))
|
|
|
- newInputs = append(newInputs, keepInputs...)
|
|
|
- newInputs = append(newInputs, afterDiscardInputs...)
|
|
|
- slot.Inputs = newInputs
|
|
|
+ newInputs := make([]input, numKeep+len(slot.Inputs)-(numKeep+discard))
|
|
|
+ copy(newInputs[:numKeep], slot.Inputs[:numKeep])
|
|
|
+ copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:])
|
|
|
+
|
|
|
+ // Update the slot.Inputs to be empty since we've cleared the cache
|
|
|
+ // The transformer will rebuild these as the inputs are processed
|
|
|
+ slot.Inputs = []input{}
|
|
|
+
|
|
|
+ // Return the inputs that need to be reprocessed
|
|
|
+ // The caller will need to prepend these to the sequence's inputs queue
|
|
|
+ return &ErrReprocessInputs{
|
|
|
+ Inputs: newInputs,
|
|
|
+ SlotId: slot.Id,
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
return nil
|