Parcourir la source

attention: Remove unnecessary contiguous operations

Prior to performing attention, we need to permute query, key
and value. Currently we call Contiguous after each of these
permutations, which is correct but expensive. Avoiding the
3 calls to Contiguous increases performance by over 20%.

The permutations of query and key do not violate the continuity
rules for mulmat and the Contiguous call can be simply removed.

Value requires a different permutation and does require Contiguous.
However, we can use the copy into the cache as a way to perform this
without further overhead.

To support this and avoid unexpected tensor shapes that are seen by
models, we need tighter integration between attention, cache
and backend. Future optimization will also likely need this structure
 - for example, flash attention has special padding requirements in
the cache and other backends may have their own needs.

This further contains the operations that go into attention so that
these and other optimizations can be handled transparently. Models
that have special requirements for attention can still implement
their own version of it.
Jesse Gross il y a 2 mois
Parent
commit
854a9195f3

+ 11 - 0
kvcache/cache.go

@@ -29,6 +29,17 @@ type Cache interface {
 	// cache implementation used.
 	Put(ctx ml.Context, key, value ml.Tensor)
 
+	// SetConfig controls optimizations (mostly backend-specific) that may transform
+	// the output of the cache to work better with specific kernels. If not called,
+	// the backend settings will be used. This works well when calling Attention.
+	//
+	// The config can be overridden by models, especially if they require vanilla
+	// output when implementing their own version of attention. To do this, pass
+	// an empty ml.CacheConfig.
+	//
+	// Most models will not need to use this.
+	SetConfig(ml.CacheConfig)
+
 	// ** cache management **
 
 	// Init sets up runtime parameters

+ 137 - 37
kvcache/causal.go

@@ -22,6 +22,9 @@ type Causal struct {
 	Capacity   int32
 	windowSize int32
 
+	// config controls mostly backend-specific optimizations
+	config *ml.CacheConfig
+
 	// ** current forward pass **
 
 	// the active layer for Get and Put
@@ -75,14 +78,34 @@ func NewSWACache(windowSize int32, shift shiftFn) *Causal {
 }
 
 func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
+	if c.config == nil {
+		var config ml.CacheConfig
+		if cc, ok := backend.(ml.BackendCacheConfig); ok {
+			config = cc.CacheConfig()
+		}
+		c.config = &config
+	}
+
+	if c.config.CachePadding == 0 {
+		c.config.CachePadding = 1
+	}
+
 	c.DType = dtype
-	c.Capacity = capacity
-	c.cells = make([]cacheCell, capacity)
+	c.Capacity = int32(roundUp(int(capacity), c.config.CachePadding))
+	c.cells = make([]cacheCell, c.Capacity)
 	c.cellRanges = make(map[int]cellRange)
 	c.backend = backend
 	c.cacheCtx = backend.NewContext()
 }
 
+func (c *Causal) SetConfig(config ml.CacheConfig) {
+	if c.config != nil {
+		panic("config cannot be changed after being previously set, either by the model or backend")
+	}
+
+	c.config = &config
+}
+
 func (c *Causal) Close() {
 	c.cacheCtx.Close()
 }
@@ -157,36 +180,73 @@ func (c *Causal) findStartLoc() (int, error) {
 	return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity)
 }
 
+func roundDown(length, pad int) int {
+	return (length / pad) * pad
+}
+
+func roundUp(length, pad int) int {
+	return ((length + pad - 1) / pad) * pad
+}
+
 // Builds a mask of history x batch indicating whether for each token in the batch the
 // token in the history should apply. This is based on both the sequence and causality (the
 // position of the history is not ahead of the token in the batch).
 func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Tensor, error) {
-	// TODO(jessegross): This does not do padding, which is required for flash attention
-	len := c.curCellRange.max - c.curCellRange.min + 1
-	mask := make([]float32, c.curBatchSize*len)
+	// TODO(jessegross): This does not do mask padding, which is required for flash attention
+	// Align and pad the cache range as required by the backend
+	c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
+	c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
+
+	length := c.curCellRange.max - c.curCellRange.min + 1
+	mask := make([]float32, c.curBatchSize*length)
 
 	for i := range c.curBatchSize {
 		for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
 			if !slices.Contains(c.cells[j].sequences, seqs[i]) || c.cells[j].pos > positions[i] ||
 				c.cells[j].pos < positions[i]-c.windowSize {
-				mask[i*len+(j-c.curCellRange.min)] = float32(math.Inf(-1))
+				mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
 			}
 		}
 	}
 
-	return ctx.FromFloatSlice(mask, len, c.curBatchSize)
+	return ctx.FromFloatSlice(mask, length, c.curBatchSize)
 }
 
-func moveCell(ctx ml.Context, objs []ml.Tensor, src, dst, len int) {
-	for _, obj := range objs {
-		if obj == nil {
+func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
+	for i := range c.keys {
+		if c.keys[i] == nil {
 			continue
 		}
 
-		srcView := obj.View(ctx, obj.Stride(2)*src, obj.Dim(0)*obj.Dim(1)*len)
-		dstView := obj.View(ctx, obj.Stride(2)*dst, obj.Dim(0)*obj.Dim(1)*len)
+		key := c.keys[i]
+
+		kHeadDim := key.Dim(0)
+		numKVHeads := key.Dim(1)
+		rowSize := key.Stride(2)
+
+		kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*len)
+		kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*len)
+
+		value := c.values[i]
+		var vSrcView, vDstView ml.Tensor
+		if c.config.PermutedV {
+			vHeadDim := value.Dim(1)
+			elemSize := value.Stride(0)
+
+			vSrcView = value.View(ctx, elemSize*src, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
+			vDstView = value.View(ctx, elemSize*dst, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
+		} else {
+			vHeadDim := value.Dim(0)
+			rowSize := value.Stride(2)
 
-		ctx.Forward(srcView.Copy(ctx, dstView))
+			vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*len)
+			vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*len)
+		}
+
+		ctx.Forward(
+			kSrcView.Copy(ctx, kDstView),
+			vSrcView.Copy(ctx, vDstView),
+		)
 	}
 }
 
@@ -238,8 +298,7 @@ func (c *Causal) defrag() {
 							pendingLen++
 							break
 						} else {
-							moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
-							moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
+							c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
 							moves++
 						}
 					}
@@ -263,8 +322,7 @@ func (c *Causal) defrag() {
 	}
 
 	if pendingLen > 0 {
-		moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
-		moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
+		c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
 		moves++
 	}
 
@@ -305,35 +363,73 @@ func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
 	key := c.keys[c.curLayer]
 	value := c.values[c.curLayer]
 
-	key = key.View(ctx, key.Stride(2)*c.curCellRange.min,
-		key.Dim(0), key.Stride(1),
-		key.Dim(1), key.Stride(2),
-		c.curMask.Dim(0),
-	)
+	kHeadDim := key.Dim(0)
+	numKVHeads := key.Dim(1)
+	rowSize := key.Stride(2)
+	cachedSize := c.curMask.Dim(0)
 
-	value = value.View(ctx, key.Stride(2)*c.curCellRange.min,
-		value.Dim(0), value.Stride(1),
-		value.Dim(1), value.Stride(2),
-		c.curMask.Dim(0),
+	key = key.View(ctx, rowSize*c.curCellRange.min,
+		kHeadDim, key.Stride(1),
+		numKVHeads, key.Stride(2),
+		cachedSize,
 	)
 
+	if c.config.PermutedV {
+		vHeadDim := value.Dim(1)
+		elemSize := value.Stride(0)
+
+		value = value.View(ctx, elemSize*c.curCellRange.min,
+			cachedSize, value.Stride(1),
+			vHeadDim, value.Stride(2),
+			numKVHeads,
+		)
+	} else {
+		vHeadDim := value.Dim(0)
+		rowSize := value.Stride(2)
+
+		value = value.View(ctx, rowSize*c.curCellRange.min,
+			vHeadDim, value.Stride(1),
+			numKVHeads, value.Stride(2),
+			cachedSize,
+		)
+	}
+
 	return key, value, c.curMask
 }
 
 func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
-	if c.curBatchSize != key.Dim(2) {
-		panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, key.Dim(2)))
+	kHeadDim := key.Dim(0)
+	vHeadDim := value.Dim(0)
+	numKVHeads := key.Dim(1)
+	batchSize := key.Dim(2)
+
+	if c.curBatchSize != batchSize {
+		panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize))
 	}
 
 	if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
-		c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, key.Dim(0), key.Dim(1), int(c.Capacity))
-		c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, value.Dim(0), value.Dim(1), int(c.Capacity))
+		c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity))
+
+		if c.config.PermutedV {
+			c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads)
+		} else {
+			c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity))
+		}
 	}
 
-	ctx.Forward(
-		key.Copy(ctx, c.keys[c.curLayer].View(ctx, c.keys[c.curLayer].Stride(2)*c.curLoc, key.Dim(0)*key.Dim(1)*key.Dim(2))),
-		value.Copy(ctx, c.values[c.curLayer].View(ctx, c.values[c.curLayer].Stride(2)*c.curLoc, value.Dim(0)*value.Dim(1)*value.Dim(2))),
-	)
+	rowSize := c.keys[c.curLayer].Stride(2)
+	ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, kHeadDim*numKVHeads*batchSize)))
+
+	if c.config.PermutedV {
+		elemSize := c.values[c.curLayer].Stride(0)
+
+		value = value.Permute(ctx, 1, 2, 0, 3)
+		ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)))
+	} else {
+		rowSize := c.values[c.curLayer].Stride(2)
+
+		ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, vHeadDim*numKVHeads*batchSize)))
+	}
 }
 
 func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
@@ -389,9 +485,13 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
 			continue
 		}
 
-		key = key.View(ctx, key.Stride(2)*seqRange.min,
-			key.Dim(0), key.Stride(1),
-			key.Dim(1), key.Stride(2),
+		kHeadDim := key.Dim(0)
+		numKVHeads := key.Dim(1)
+		rowSize := key.Stride(2)
+
+		key = key.View(ctx, rowSize*seqRange.min,
+			kHeadDim, key.Stride(1),
+			numKVHeads, key.Stride(2),
 			size,
 		)
 

+ 29 - 0
kvcache/encoder.go

@@ -1,6 +1,8 @@
 package kvcache
 
 import (
+	"fmt"
+
 	"github.com/ollama/ollama/ml"
 )
 
@@ -11,6 +13,9 @@ import (
 //
 // Not currently safe for multiple sequences
 type EncoderCache struct {
+	// config controls mostly backend-specific optimizations
+	config *ml.CacheConfig
+
 	// ** current forward pass **
 
 	// the active layer for Get and Put
@@ -40,9 +45,29 @@ func NewEncoderCache() *EncoderCache {
 }
 
 func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
+	if c.config == nil {
+		var config ml.CacheConfig
+		if cc, ok := backend.(ml.BackendCacheConfig); ok {
+			config = cc.CacheConfig()
+		}
+		c.config = &config
+	}
+
+	if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
+		panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
+	}
+
 	c.cacheCtx = backend.NewContext()
 }
 
+func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
+	if c.config != nil {
+		panic("config cannot be changed after being previously set, either by the model or backend")
+	}
+
+	c.config = &config
+}
+
 func (c *EncoderCache) Close() {
 	c.cacheCtx.Close()
 }
@@ -75,6 +100,10 @@ func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
 	c.encoderPos = c.curPos
 	c.encoderCached = true
 
+	if c.config.PermutedV {
+		value = value.Permute(ctx, 1, 2, 0, 3)
+	}
+
 	if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
 		c.keys[c.curLayer] = c.cacheCtx.Zeros(key.DType(), key.Shape()...)
 		c.values[c.curLayer] = c.cacheCtx.Zeros(value.DType(), value.Shape()...)

+ 6 - 0
kvcache/wrapper.go

@@ -28,6 +28,12 @@ func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, capacity int32)
 	}
 }
 
+func (c *WrapperCache) SetConfig(config ml.CacheConfig) {
+	for _, cache := range c.caches {
+		cache.SetConfig(config)
+	}
+}
+
 func (c *WrapperCache) Close() {
 	for _, cache := range c.caches {
 		cache.Close()

+ 25 - 0
ml/backend.go

@@ -27,6 +27,27 @@ type Backend interface {
 	SystemInfo() string
 }
 
+// BackendCacheConfig should be implemented by backends that need special output
+// from the cache to meet specific requirements. It is frequently implemented in
+// conjunction with ScaledDotProductAttention.
+type BackendCacheConfig interface {
+	CacheConfig() CacheConfig
+}
+
+// CacheConfig controls optimizations (mostly backend-specific) that may transform
+// the output the cache to work better with specific kernels.
+type CacheConfig struct {
+	// CachePadding specifies the multiple for the number of tokens of cache history
+	// that will be returned from cache Get for k, v and mask. The capacity of the
+	// cache itself will also be increased to a multiple of this size if needed.
+	CachePadding int
+
+	// PermutedV performs Permute(ctx, 1, 2, 0, 3) on v tensors stored via Put
+	// and return the permuted version via Get. This uses the cache copy operation
+	// to avoid a Contiguous call on the permuted tensor.
+	PermutedV bool
+}
+
 // BackendParams controls how the backend loads and executes models
 type BackendParams struct {
 	// NumThreads sets the number of threads to use if running on the CPU
@@ -116,6 +137,10 @@ type Tensor interface {
 // operation equivalent to following code on a tensor named
 // query:
 //
+// query = query.Permute(ctx, 0, 2, 1, 3)
+// key = key.Permute(ctx, 0, 2, 1, 3)
+// value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
+//
 // kq := key.MulmatFullPrec(ctx, query)
 //
 // kq = kq.Scale(ctx, scale)

+ 8 - 1
ml/backend/ggml/ggml.go

@@ -247,6 +247,10 @@ func (b *Backend) NewContext() ml.Context {
 	}
 }
 
+func (b *Backend) CacheConfig() ml.CacheConfig {
+	return ml.CacheConfig{CachePadding: 32, PermutedV: true}
+}
+
 type Context struct {
 	b       *Backend
 	ctx     *C.struct_ggml_context
@@ -661,7 +665,10 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.T
 		kqMask = mask.(*Tensor).t
 	}
 
-	kq := key.MulmatFullPrec(ctx, t)
+	query := t.Permute(ctx, 0, 2, 1, 3)
+	key = key.Permute(ctx, 0, 2, 1, 3)
+
+	kq := key.MulmatFullPrec(ctx, query)
 	kq = &Tensor{
 		t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
 	}

+ 31 - 20
ml/nn/attention.go

@@ -3,6 +3,7 @@ package nn
 import (
 	"fmt"
 
+	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 )
 
@@ -11,40 +12,50 @@ import (
 //
 // Parameters:
 //   - ctx: Context for tensor operations
-//   - query: Query tensor (Q) with shape [d_k, seq_len_q, heads]
-//   - key: Key tensor (K) with shape [d_k, seq_len_k, kv_heads]
-//   - value: Value tensor (V) with shape [seq_len_k, d_v, kv_heads]
-//   - mask: Optional attention mask that is added to the attention score. If
-//     provided, should broadcast to [seq_len_k, seq_len_q, heads]
+//   - query: Query tensor (Q) with shape [d_k, heads, seq_len_q]
+//   - key: Key tensor (K) with shape [d_k, kv_heads, seq_len_k], can be nil to read from cache only
+//   - value: Value tensor (V) with shape [d_v, kv_heads, seq_len_k], can be nil to read from cache only
 //   - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension
+//   - cache: KV cache to store key/value and get past history, can be nil to only use provided key/value
 //
 // Returns:
 //
 //	Attention output with shape [d_v, heads, seq_len_q]
-func Attention(ctx ml.Context, query, key, value, mask ml.Tensor, scale float64) ml.Tensor {
-	if query.Dim(0) != key.Dim(0) {
-		panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
-	}
+func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
+	if key != nil && value != nil {
+		if query.Dim(0) != key.Dim(0) {
+			panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
+		}
 
-	if mask != nil && query.Dim(1) != mask.Dim(1) {
-		panic(fmt.Errorf("seq_len_q in attention operation does not match between query(%v) and mask(%v)", query.Dim(1), mask.Dim(1)))
-	}
+		if key.Dim(1) != value.Dim(1) {
+			panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1)))
+		}
 
-	if key.Dim(1) != value.Dim(0) {
-		panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(0)))
-	}
+		if key.Dim(2) != value.Dim(2) {
+			panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
+		}
 
-	if mask != nil && key.Dim(1) != mask.Dim(0) {
-		panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and mask(%v)", key.Dim(1), mask.Dim(0)))
+		if cache != nil {
+			cache.Put(ctx, key, value)
+		}
+	} else if cache == nil {
+		panic("key & value tensors must be provided if cache is nil")
 	}
 
-	if key.Dim(2) != value.Dim(2) {
-		panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
+	var mask ml.Tensor
+	if cache != nil {
+		key, value, mask = cache.Get(ctx)
 	}
 
-	if sdpa, ok := query.(ml.ScaledDotProductAttention); ok {
+	// Only use the fast SDPA implementation if we have a cache, since that's what
+	// will do any expected backend-specific transformations for us
+	if sdpa, ok := query.(ml.ScaledDotProductAttention); ok && cache != nil {
 		return sdpa.ScaledDotProductAttention(ctx, key, value, mask, scale)
 	} else {
+		query = query.Permute(ctx, 0, 2, 1, 3)
+		key = key.Permute(ctx, 0, 2, 1, 3)
+		value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
+
 		kq := key.MulmatFullPrec(ctx, query)
 
 		kq = kq.Scale(ctx, scale)

+ 1 - 8
model/models/llama/model.go

@@ -81,15 +81,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
 
-	cache.Put(ctx, k, v)
-	k, v, mask := cache.Get(ctx)
-
-	q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
-
 	scaleFactor := 1.0 / math.Sqrt(float64(headDim))
-	kqv := nn.Attention(ctx, q, k, v, mask, scaleFactor)
+	kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
 	kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
 
 	return sa.Output.Forward(ctx, kqv)

+ 3 - 1
model/models/mllama/model.go

@@ -43,7 +43,9 @@ func New(c ml.Config) (model.Model, error) {
 		TextModel:      newTextModel(c),
 	}
 
-	m.Cache = kvcache.NewWrapperCache(kvcache.NewEncoderCache(), kvcache.NewCausalCache(m.TextModel.Shift))
+	encoderCache := kvcache.NewEncoderCache()
+	encoderCache.SetConfig(ml.CacheConfig{})
+	m.Cache = kvcache.NewWrapperCache(encoderCache, kvcache.NewCausalCache(m.TextModel.Shift))
 
 	return &m, nil
 }

+ 16 - 16
model/models/mllama/model_text.go

@@ -31,22 +31,15 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
 	value := sa.Value.Forward(ctx, hiddenState)
 	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
 
-	cache.Put(ctx, key, value)
-	key, value, mask := cache.Get(ctx)
-
-	query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
-
 	scaleFactor := 1.0 / math.Sqrt(float64(headDim))
-	attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)
+	attention := nn.Attention(ctx, query, key, value, scaleFactor, cache)
 	attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
 
 	return sa.Output.Forward(ctx, attention)
 }
 
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	// This will only get called for layers in the cache, which are just the self attention layers
+	// This will only get called for layers in the causal cache, which are just the self attention layers
 	return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
 }
 
@@ -107,7 +100,7 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
 	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
 	query = ca.QueryNorm.Forward(ctx, query, opts.eps)
 
-	var key, value, mask ml.Tensor
+	var key, value ml.Tensor
 	if crossAttentionStates != nil {
 		numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)
 
@@ -119,16 +112,23 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
 		value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
 
 		cache.Put(ctx, key, value)
-	} else {
-		key, value, mask = cache.Get(ctx)
 	}
 
-	query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
+	key, value, _ = cache.Get(ctx)
 
 	scaleFactor := 1.0 / math.Sqrt(float64(headDim))
-	attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)
+
+	query = query.Permute(ctx, 0, 2, 1, 3)
+	key = key.Permute(ctx, 0, 2, 1, 3)
+	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
+
+	kq := key.MulmatFullPrec(ctx, query)
+
+	kq = kq.Scale(ctx, scaleFactor)
+	kq = kq.Softmax(ctx)
+
+	kqv := value.Mulmat(ctx, kq)
+	attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
 
 	return ca.Output.Forward(ctx, attention)