|
@@ -213,8 +213,8 @@ func (c *InputCache) ShiftDiscard(inputLen int, numKeep int) int {
|
|
return discard
|
|
return discard
|
|
}
|
|
}
|
|
|
|
|
|
-// Frees up space in the KV cache by deleting the oldest half of history and shifting
|
|
|
|
-// the newest half into that space (saving numKeep inputs at the beginning).
|
|
|
|
|
|
+// ShiftCacheSlot frees up space in the KV cache by deleting the oldest half of history
|
|
|
|
+// and shifting the newest half into that space (saving numKeep inputs at the beginning).
|
|
//
|
|
//
|
|
// Assumes that at least 1 entry can be freed up by shifting (i.e. numKeep < numCtx)
|
|
// Assumes that at least 1 entry can be freed up by shifting (i.e. numKeep < numCtx)
|
|
func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
|
|
func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
|
|
@@ -231,16 +231,35 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
|
|
slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs),
|
|
slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs),
|
|
"keep", numKeep, "discard", discard)
|
|
"keep", numKeep, "discard", discard)
|
|
|
|
|
|
- // TODO (jessegross): KV cache removal can fail for certain types of models
|
|
|
|
- if !c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard) {
|
|
|
|
- return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v)", slot.Id, numKeep, discard)
|
|
|
|
- }
|
|
|
|
- c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, len(slot.Inputs), -discard)
|
|
|
|
|
|
+ if c.lc.KvCacheCanShift() {
|
|
|
|
+ if !c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard) {
|
|
|
|
+ return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v)", slot.Id, numKeep, discard)
|
|
|
|
+ }
|
|
|
|
+ c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, len(slot.Inputs), -discard)
|
|
|
|
+
|
|
|
|
+ for i := numKeep + discard; i < len(slot.Inputs); i++ {
|
|
|
|
+ slot.Inputs[i-discard] = slot.Inputs[i]
|
|
|
|
+ }
|
|
|
|
+ slot.Inputs = slot.Inputs[:len(slot.Inputs)-discard]
|
|
|
|
+ } else {
|
|
|
|
+ slog.Debug("kv cache cannot shift, clearing cache and truncating history")
|
|
|
|
+
|
|
|
|
+ // Remove all KV cache entries belonging to this slot's sequence
|
|
|
|
+ if !c.lc.KvCacheSeqRm(slot.Id, 0, -1) {
|
|
|
|
+ return fmt.Errorf("unable to remove kv cache entries (id: %v)", slot.Id)
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // Rebuild slot.Inputs so that it matches the result a shift operation would have produced:
|
|
|
|
+ // keep the first numKeep inputs, followed by the inputs that come after the discarded range
|
|
|
|
+ keepInputs := make([]input, numKeep)
|
|
|
|
+ copy(keepInputs, slot.Inputs[:numKeep])
|
|
|
|
+
|
|
|
|
+ afterDiscardInputs := make([]input, len(slot.Inputs)-(numKeep+discard))
|
|
|
|
+ copy(afterDiscardInputs, slot.Inputs[numKeep+discard:])
|
|
|
|
|
|
- for i := numKeep + discard; i < len(slot.Inputs); i++ {
|
|
|
|
- slot.Inputs[i-discard] = slot.Inputs[i]
|
|
|
|
|
|
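+ // Update the inputs to match what would happen after a shift. NOTE(review): the slot's KV cache entries were removed above, so these inputs are no longer backed by the cache - confirm that they get reprocessed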
|
|
|
|
+ slot.Inputs = append(keepInputs, afterDiscardInputs...)
|
|
}
|
|
}
|
|
- slot.Inputs = slot.Inputs[:len(slot.Inputs)-discard]
|
|
|
|
|
|
|
|
return nil
|
|
return nil
|
|
}
|
|
}
|