Browse Source

kvcache: Support non-causal attention

Models can disable causality for all or part of their processing
while continuing to store data in the KV cache.
Jesse Gross 1 month ago
parent
commit
6da8b6a879
1 changed files with 36 additions and 4 deletions
  1. 36 4
      kvcache/causal.go

+ 36 - 4
kvcache/causal.go

@@ -20,6 +20,7 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e
 type Causal struct {
 	DType      ml.DType
 	Capacity   int32
+	causal     bool
 	windowSize int32
 
 	// config controls mostly backend-specific optimizations
@@ -42,6 +43,12 @@ type Causal struct {
 	// locations in the cache that are needed for this batch
 	curCellRange cellRange
 
+	// curSequences is the sequences corresponding to this pass's entries in the cache
+	curSequences []int
+
+	// curPositions is the positions corresponding to this pass's entries in the cache
+	curPositions []int32
+
 	// ** cache metadata **
 
 	// for each possible location in the cache, stores the position and set of sequences
@@ -71,6 +78,7 @@ type cellRange struct {
 
 func NewCausalCache(shift shiftFn) *Causal {
 	return &Causal{
+		causal:     true,
 		windowSize: math.MaxInt32,
 		shiftFn:    shift,
 		ctxs:       make(map[int]ml.Context),
@@ -81,6 +89,7 @@ func NewCausalCache(shift shiftFn) *Causal {
 
 func NewSWACache(windowSize int32, shift shiftFn) *Causal {
 	return &Causal{
+		causal:     true,
 		windowSize: windowSize,
 		shiftFn:    shift,
 		ctxs:       make(map[int]ml.Context),
@@ -133,6 +142,8 @@ func (c *Causal) Close() {
 
 func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
 	c.curBatchSize = len(positions)
+	c.curSequences = seqs
+	c.curPositions = positions
 
 	var err error
 	c.curLoc, err = c.findStartLoc()
@@ -171,7 +182,7 @@ func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) err
 		c.cellRanges[seq] = seqRange
 	}
 
-	c.curMask, err = c.buildMask(ctx, positions, seqs)
+	c.curMask, err = c.buildMask(ctx)
 
 	return err
 }
@@ -212,7 +223,7 @@ func roundUp(length, pad int) int {
 // Builds a mask of history x batch indicating whether for each token in the batch the
 // token in the history should apply. This is based on both the sequence and causality (the
 // position of the history is not ahead of the token in the batch).
-func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Tensor, error) {
+func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
 	// Align and pad the two dimensions as required by the backend
 	batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
 
@@ -224,8 +235,9 @@ func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Te
 
 	for i := range c.curBatchSize {
 		for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
-			if !slices.Contains(c.cells[j].sequences, seqs[i]) || c.cells[j].pos > positions[i] ||
-				c.cells[j].pos < positions[i]-c.windowSize {
+			if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
+				(c.causal && c.cells[j].pos > c.curPositions[i]) ||
+				c.cells[j].pos < c.curPositions[i]-c.windowSize {
 				mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
 			}
 		}
@@ -391,6 +403,26 @@ func (c *Causal) SetLayer(layer int) {
 	c.curLayer = layer
 }
 
+// SetCausal enables or disables causal mask generation for subsequent calls to Get.
+// This state carries over to future forward passes. The default value is true.
+//
+// ctx may be set to nil if this is called from outside of a forward pass, for
+// example, when initializing the cache.
+func (c *Causal) SetCausal(ctx ml.Context, causal bool) {
+	if c.causal != causal {
+		c.causal = causal
+
+		if ctx != nil {
+			var err error
+			c.curMask, err = c.buildMask(ctx)
+			if err != nil {
+				// This error should never occur because we have previously built a mask with the same shape
+				panic(fmt.Errorf("SetCausal: %w", err))
+			}
+		}
+	}
+}
+
 func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
 	key := c.keys[c.curLayer]
 	value := c.values[c.curLayer]