|
@@ -321,7 +321,8 @@ func (c *Causal) defrag() {
|
|
ctx := c.backend.NewContext()
|
|
ctx := c.backend.NewContext()
|
|
|
|
|
|
// For every move, 6 tensors are required per layer (2 views and a
|
|
// For every move, 6 tensors are required per layer (2 views and a
|
|
- // copy for each of k and v).
|
|
|
|
|
|
+ // copy for each of k and v). We also need to refer to the original
|
|
|
|
+ // k and v cache tensors - once per layer, not per move.
|
|
layers := 0
|
|
layers := 0
|
|
for _, key := range c.keys {
|
|
for _, key := range c.keys {
|
|
if key == nil {
|
|
if key == nil {
|
|
@@ -330,7 +331,7 @@ func (c *Causal) defrag() {
|
|
layers++
|
|
layers++
|
|
}
|
|
}
|
|
|
|
|
|
- maxMoves := ctx.MaxGraphNodes() / (6 * layers)
|
|
|
|
|
|
+ maxMoves := (ctx.MaxGraphNodes() - 2*layers) / (6 * layers)
|
|
moves := 0
|
|
moves := 0
|
|
|
|
|
|
var pendingSrc, pendingDst, pendingLen int
|
|
var pendingSrc, pendingDst, pendingLen int
|