Prechádzať zdrojové kódy

kvcache: Optimize sliding window attention

Currently sliding window attention allocates and uses the full
context size and just masks out any tokens that are outside of the
window. However, we really only need (roughly) the sliding window
size.

At large context sizes this improves two things:
 - Memory allocated - since the fully context size is allocated up front,
   memory requirements drop substantially. On Gemma3:4b with a 32k
   context window, total memory usage (including weights and non-sliding
   layers) drops from ~20GB to ~8GB.
 - Computation - ranges that are completely outside of the sliding
   window are now removed from the tensors that are returned from the
   cache rather than simply being masked out. This results in more
   efficient processing, scaling with the size of the context that
   has actually been used.

Notable, this does not update the scheduler for any model to be aware of
the smaller memory requirements. This is difficult for Gemma3 because
the layers are heterogeneous between sliding and non-sliding attention.
As a result, while actual memory consumption will be reduced, the
scheduler will over-estimate the requirements of the model. This means
that splitting between GPUs or GPUs and CPUs will still be suboptimal.

Bug #9730
Jesse Gross 1 mesiac pred
rodič
commit
2d6eac9084
2 zmenil súbory, kde vykonal 64 pridanie a 3 odobranie
  1. 52 1
      kvcache/causal.go
  2. 12 2
      kvcache/causal_test.go

+ 52 - 1
kvcache/causal.go

@@ -118,7 +118,12 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity
 		c.config.MaskDType = ml.DTypeF32
 	}
 
-	cacheSize := maxSequences * capacity
+	var cacheSize int
+	if c.windowSize == math.MaxInt32 || capacity < int(c.windowSize)+maxBatch {
+		cacheSize = maxSequences * capacity
+	} else {
+		cacheSize = maxSequences * (int(c.windowSize) + maxBatch)
+	}
 	cacheSize = roundUp(cacheSize, c.config.CachePadding)
 	c.cells = make([]cacheCell, cacheSize)
 
@@ -147,6 +152,8 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch) error {
 	c.curPositions = batch.Positions
 	c.opts.Except = nil
 
+	c.updateSlidingWindow()
+
 	var err error
 	c.curLoc, err = c.findStartLoc()
 	if errors.Is(err, ErrKvCacheFull) {
@@ -214,6 +221,50 @@ func (c *Causal) findStartLoc() (int, error) {
 	return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, len(c.cells))
 }
 
+func (c *Causal) updateSlidingWindow() {
+	if c.windowSize == math.MaxInt32 {
+		return
+	}
+
+	// create a map of unique sequences to the lowest position in that sequence
+	lowestPos := make(map[int]int32)
+	for i := range c.curPositions {
+		seq := c.curSequences[i]
+
+		pos, ok := lowestPos[seq]
+		if !ok {
+			pos = c.curPositions[i]
+		} else if c.curPositions[i] < pos {
+			pos = c.curPositions[i]
+		}
+
+		lowestPos[seq] = pos
+	}
+
+	// delete any entries that are beyond the window of the oldest position in the sequence
+	for seq, pos := range lowestPos {
+		oldRange, ok := c.cellRanges[seq]
+		if !ok {
+			continue
+		}
+
+		newRange := newRange()
+
+		for i := oldRange.min; i <= oldRange.max; i++ {
+			if slices.Contains(c.cells[i].sequences, seq) {
+				if c.cells[i].pos < pos-c.windowSize {
+					c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
+				} else {
+					newRange.min = min(newRange.min, i)
+					newRange.max = max(newRange.max, i)
+				}
+			}
+		}
+
+		c.cellRanges[seq] = newRange
+	}
+}
+
 func roundDown(length, pad int) int {
 	return (length / pad) * pad
 }

+ 12 - 2
kvcache/causal_test.go

@@ -58,11 +58,11 @@ func TestSWA(t *testing.T) {
 	cache := NewSWACache(1, nil)
 	defer cache.Close()
 
-	cache.Init(backend, ml.DTypeF32, 1, 16, 16)
+	cache.Init(backend, ml.DTypeF16, 1, 16, 16)
 
 	tests := []testCase{
 		{
-			name:          "SlidingWindow",
+			name:          "FirstBatch",
 			in:            []float32{1, 2, 3, 4},
 			inShape:       []int{1, 1, 4},
 			seqs:          []int{0, 0, 0, 0},
@@ -71,6 +71,16 @@ func TestSWA(t *testing.T) {
 			expectedShape: []int{1, 1, 4},
 			expectedMask:  []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
 		},
+		{
+			name:          "SecondBatch",
+			in:            []float32{5, 6},
+			inShape:       []int{1, 1, 2},
+			seqs:          []int{0, 0},
+			pos:           []int32{4, 5},
+			expected:      []float32{5, 6, 3, 4},
+			expectedShape: []int{1, 1, 4},
+			expectedMask:  []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1))},
+		},
 	}
 
 	testCache(t, backend, cache, tests)