소스 검색

kvcache: Account for source tensors in defrag operation count

Defragging the KV cache can generate a lot of operations, so we
need to be careful that we don't overflow the number that the graph
can support. We currently account for all of the nodes that we add
to the graph for each move but we also need to include the original
cache tensors as well.

Fixes #9904
Jesse Gross 1 개월 전
부모
커밋
d3e9ca3eda
1개의 변경된 파일3개의 추가작업 그리고 2개의 파일을 삭제
  1. 3 2
      kvcache/causal.go

+ 3 - 2
kvcache/causal.go

@@ -321,7 +321,8 @@ func (c *Causal) defrag() {
 	ctx := c.backend.NewContext()
 
 	// For every move, 6 tensors are required per layer (2 views and a
-	// copy for each of k and v).
+	// copy for each of k and v). We also need to refer to the original
+	// k and v cache tensors - once per layer, not per move.
 	layers := 0
 	for _, key := range c.keys {
 		if key == nil {
@@ -330,7 +331,7 @@ func (c *Causal) defrag() {
 		layers++
 	}
 
-	maxMoves := ctx.MaxGraphNodes() / (6 * layers)
+	maxMoves := (ctx.MaxGraphNodes() - 2*layers) / (6 * layers)
 	moves := 0
 
 	var pendingSrc, pendingDst, pendingLen int