attention.go

package nn

import (
	"fmt"

	"github.com/ollama/ollama/ml"
)

// Attention implements scaled dot-product attention for transformer models:
//
//	Attention(Q, K, V) = softmax(QK^T/√d_k)V
//
// Parameters:
//   - ctx: Context for tensor operations
//   - query: Query tensor (Q) with shape [d_k, seq_len_q, heads]
//   - key: Key tensor (K) with shape [d_k, seq_len_k, kv_heads]
//   - value: Value tensor (V) with shape [seq_len_k, d_v, kv_heads]
//   - mask: Optional attention mask that is added to the attention scores. If
//     provided, it should broadcast to [seq_len_k, seq_len_q, heads]
//   - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension
//
// Returns:
//
//	Attention output with shape [d_v, heads, seq_len_q]
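//
// A minimal usage sketch (hypothetical identifiers; assumes ctx, q, k, v, and
// mask have already been built with the layouts described above, and headDim
// is d_k):
//
//	out := nn.Attention(ctx, q, k, v, mask, 1.0/math.Sqrt(float64(headDim)))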
func Attention(ctx ml.Context, query, key, value, mask ml.Tensor, scale float64) ml.Tensor {
	if query.Dim(0) != key.Dim(0) {
		panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
	}

	if mask != nil && query.Dim(1) != mask.Dim(1) {
		panic(fmt.Errorf("seq_len_q in attention operation does not match between query(%v) and mask(%v)", query.Dim(1), mask.Dim(1)))
	}

	if key.Dim(1) != value.Dim(0) {
		panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(0)))
	}

	if mask != nil && key.Dim(1) != mask.Dim(0) {
		panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and mask(%v)", key.Dim(1), mask.Dim(0)))
	}

	if key.Dim(2) != value.Dim(2) {
		panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
	}
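
	// Use the backend's fused scaled-dot-product-attention kernel when the query
	// tensor implements it; otherwise fall back to the explicit computation below.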
	if sdpa, ok := query.(ml.ScaledDotProductAttention); ok {
		return sdpa.ScaledDotProductAttention(ctx, key, value, mask, scale)
	} else {
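		// Materialize the attention weights softmax(QK^T·scale + mask) explicitly,
		// then apply them to V.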
		kq := key.MulmatFullPrec(ctx, query)
		kq = kq.Scale(ctx, scale)
		if mask != nil {
			kq = kq.Add(ctx, mask)
		}
		kq = kq.Softmax(ctx)

		kqv := value.Mulmat(ctx, kq)
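		// Reorder the result to the documented output shape [d_v, heads, seq_len_q]
		// and make it contiguous.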
		return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
	}
}