Browse Source

ml: Abstract attention out of model definitions

There are two benefits to doing this:
 - Provide a library function that models can use, reducing code for
   each model implementation
 - Enables a single place to drop in optimized implementations of
   attention based on the backend or other factors. One is provided for
   GGML.

On CUDA this improves token generation rate by about 3%. It does not
have a significant effect on Metal.

Co-authored-by: Daniel Hiltgen <daniel@ollama.com>
Jesse Gross 2 months ago
parent
commit
f53f4198c3
5 changed files with 102 additions and 22 deletions
  1. 20 0
      ml/backend.go
  2. 15 0
      ml/backend/ggml/ggml.go
  3. 59 0
      ml/nn/attention.go
  4. 2 7
      model/models/llama/model.go
  5. 6 15
      model/models/mllama/model_text.go

+ 20 - 0
ml/backend.go

@@ -111,6 +111,26 @@ type Tensor interface {
 	Copy(ctx Context, t2 Tensor) Tensor
 }
 
+// ScaledDotProductAttention implements a fused attention
+// operation equivalent to following code on a tensor named
+// query:
+//
+// kq := key.MulmatFullPrec(ctx, query)
+//
+// kq = kq.Scale(ctx, scale)
+//
+//	if mask != nil {
+//		kq = kq.Add(ctx, mask)
+//	}
+//
+// kq = kq.Softmax(ctx)
+//
+// kqv := value.Mulmat(ctx, kq)
+// return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+type ScaledDotProductAttention interface {
+	ScaledDotProductAttention(ctx Context, key, value, mask Tensor, scale float64) Tensor
+}
+
 type number interface {
 	~int | ~int8 | ~int16 | ~int32 | ~int64 |
 		~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |

+ 15 - 0
ml/backend/ggml/ggml.go

@@ -651,6 +651,21 @@ func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int
 	}
 }
 
+func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.Tensor, scale float64) ml.Tensor {
+	var kqMask *C.struct_ggml_tensor
+	if mask != nil {
+		kqMask = mask.(*Tensor).t
+	}
+
+	kq := key.MulmatFullPrec(ctx, t)
+	kq = &Tensor{
+		t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
+	}
+
+	kqv := value.Mulmat(ctx, kq)
+	return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+}
+
 func (b *Backend) SystemInfo() string {
 	var compiler string
 	switch C.get_compiler() {

+ 59 - 0
ml/nn/attention.go

@@ -0,0 +1,59 @@
+package nn
+
+import (
+	"fmt"
+
+	"github.com/ollama/ollama/ml"
+)
+
+// Attention implements scaled dot-product attention for transformer models:
+// Attention(Q, K, V) = softmax(QK^T/√d_k)V
+//
+// Parameters:
+//   - ctx: Context for tensor operations
+//   - query: Query tensor (Q) with shape [d_k, seq_len_q, heads]
+//   - key: Key tensor (K) with shape [d_k, seq_len_k, kv_heads]
+//   - value: Value tensor (V) with shape [seq_len_k, d_v, kv_heads]
+//   - mask: Optional attention mask that is added to the attention score. If
+//     provided, should broadcast to [seq_len_k, seq_len_q, heads]
+//   - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension
+//
+// Returns:
+//
+//	Attention output with shape [d_v, heads, seq_len_q]
+func Attention(ctx ml.Context, query, key, value, mask ml.Tensor, scale float64) ml.Tensor {
+	if query.Dim(0) != key.Dim(0) {
+		panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
+	}
+
+	if mask != nil && query.Dim(1) != mask.Dim(1) {
+		panic(fmt.Errorf("seq_len_q in attention operation does not match between query(%v) and mask(%v)", query.Dim(1), mask.Dim(1)))
+	}
+
+	if key.Dim(1) != value.Dim(0) {
+		panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(0)))
+	}
+
+	if mask != nil && key.Dim(1) != mask.Dim(0) {
+		panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and mask(%v)", key.Dim(1), mask.Dim(0)))
+	}
+
+	if key.Dim(2) != value.Dim(2) {
+		panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
+	}
+
+	if sdpa, ok := query.(ml.ScaledDotProductAttention); ok {
+		return sdpa.ScaledDotProductAttention(ctx, key, value, mask, scale)
+	} else {
+		kq := key.MulmatFullPrec(ctx, query)
+
+		kq = kq.Scale(ctx, scale)
+		if mask != nil {
+			kq = kq.Add(ctx, mask)
+		}
+		kq = kq.Softmax(ctx)
+
+		kqv := value.Mulmat(ctx, kq)
+		return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	}
+}

+ 2 - 7
model/models/llama/model.go

@@ -86,13 +86,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 	k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
 
-	kq := k.MulmatFullPrec(ctx, q)
-	kq = kq.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
-	kq = kq.Add(ctx, mask)
-	kq = kq.Softmax(ctx)
-
-	kqv := v.Mulmat(ctx, kq)
-	kqv = kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	scaleFactor := 1.0 / math.Sqrt(float64(headDim))
+	kqv := nn.Attention(ctx, q, k, v, mask, scaleFactor)
 	kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
 
 	return sa.Output.Forward(ctx, kqv)

+ 6 - 15
model/models/mllama/model_text.go

@@ -38,13 +38,8 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
 	key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
 
-	scores := key.MulmatFullPrec(ctx, query)
-	scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
-	scores = scores.Add(ctx, mask)
-	scores = scores.Softmax(ctx)
-
-	attention := value.Mulmat(ctx, scores)
-	attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	scaleFactor := 1.0 / math.Sqrt(float64(headDim))
+	attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)
 	attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
 
 	return sa.Output.Forward(ctx, attention)
@@ -112,7 +107,7 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
 	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
 	query = ca.QueryNorm.Forward(ctx, query, opts.eps)
 
-	var key, value ml.Tensor
+	var key, value, mask ml.Tensor
 	if crossAttentionStates != nil {
 		numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)
 
@@ -125,19 +120,15 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
 
 		cache.Put(ctx, key, value)
 	} else {
-		key, value, _ = cache.Get(ctx)
+		key, value, mask = cache.Get(ctx)
 	}
 
 	query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
 
-	scores := key.Mulmat(ctx, query)
-	scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
-	scores = scores.Softmax(ctx)
-
-	attention := value.Mulmat(ctx, scores)
-	attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	scaleFactor := 1.0 / math.Sqrt(float64(headDim))
+	attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)
 	attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
 
 	return ca.Output.Forward(ctx, attention)