attention.go

package nn

import (
	"fmt"

	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
)

// Attention implements scaled dot-product attention for transformer models:
// Attention(Q, K, V) = softmax(QK^T/√d_k)V
//
// Parameters:
//   - ctx: Context for tensor operations
//   - query: Query tensor (Q) with shape [d_k, heads, seq_len_q]
//   - key: Key tensor (K) with shape [d_k, kv_heads, seq_len_k], can be nil to read from cache only
//   - value: Value tensor (V) with shape [d_v, kv_heads, seq_len_k], can be nil to read from cache only
//   - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension
//   - cache: KV cache to store key/value and get past history, can be nil to only use provided key/value
//
// Returns:
//
//	Attention output with shape [d_v, heads, seq_len_q]
func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
	if key != nil && value != nil {
		if query.Dim(0) != key.Dim(0) {
			panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
		}

		if key.Dim(1) != value.Dim(1) {
			panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1)))
		}

		if key.Dim(2) != value.Dim(2) {
			panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
		}

		if cache != nil {
			cache.Put(ctx, key, value)
		}
	} else if cache == nil {
		panic("key & value tensors must be provided if cache is nil")
	}

	var mask ml.Tensor
	if cache != nil {
		key, value, mask = cache.Get(ctx)
	}

	// Only use the fast SDPA implementation if we have a cache, since that's what
	// will do any expected backend-specific transformations for us
	if sdpa, ok := query.(ml.ScaledDotProductAttention); ok && cache != nil {
		return sdpa.ScaledDotProductAttention(ctx, key, value, mask, scale)
	} else {
		query = query.Permute(ctx, 0, 2, 1, 3)
		key = key.Permute(ctx, 0, 2, 1, 3)
		value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)

		kq := key.MulmatFullPrec(ctx, query)

		kq = kq.Scale(ctx, scale)
		if mask != nil {
			kq = kq.Add(ctx, mask)
		}
		kq = kq.Softmax(ctx)

		kqv := value.Mulmat(ctx, kq)
		return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
	}
}
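
For context, a minimal sketch of how a caller might use Attention once query, key, and value have been projected and reshaped into the layouts documented above. The helper name exampleSelfAttention and the headDim parameter are illustrative assumptions, not part of this file; only Attention itself and the ml/kvcache types come from the package.

// exampleSelfAttention is a hypothetical caller (not part of attention.go).
// It assumes q, k, v are already in the documented layouts:
// q [d_k, heads, seq_len_q], k [d_k, kv_heads, seq_len_k], v [d_v, kv_heads, seq_len_k].
func exampleSelfAttention(ctx ml.Context, q, k, v ml.Tensor, headDim int, cache kvcache.Cache) ml.Tensor {
	// Standard scaling factor 1/√d_k from the doc comment above.
	scale := 1.0 / math.Sqrt(float64(headDim))

	// Attention stores k/v in the cache (when non-nil) and returns the
	// attended output with shape [d_v, heads, seq_len_q].
	return Attention(ctx, q, k, v, scale, cache)
}

This sketch requires adding "math" to the import block; passing nil for key and value with a non-nil cache would instead reuse the cached history, as the function's validation logic allows.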
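To make the softmax(QK^T/√d_k)V formula concrete, here is a self-contained single-head reference on plain float64 slices. It mirrors the math of the fallback (non-SDPA) path but is only an illustration under assumed slice shapes, not code from the repository.

// referenceAttention computes softmax(QK^T/√d_k)V for a single head on plain
// slices: q is [seqQ][d_k], k is [seqK][d_k], v is [seqK][d_v]; the result is
// [seqQ][d_v]. Purely illustrative; the tensor implementation does not use it.
func referenceAttention(q, k, v [][]float64) [][]float64 {
	dk := float64(len(k[0]))
	out := make([][]float64, len(q))
	for i := range q {
		// Scores: q_i · k_j / √d_k for every key position j.
		scores := make([]float64, len(k))
		maxScore := math.Inf(-1)
		for j := range k {
			var dot float64
			for d := range q[i] {
				dot += q[i][d] * k[j][d]
			}
			scores[j] = dot / math.Sqrt(dk)
			if scores[j] > maxScore {
				maxScore = scores[j]
			}
		}

		// Numerically stable softmax over key positions.
		var sum float64
		for j := range scores {
			scores[j] = math.Exp(scores[j] - maxScore)
			sum += scores[j]
		}

		// Output row i is the softmax-weighted sum of value rows.
		out[i] = make([]float64, len(v[0]))
		for j := range v {
			w := scores[j] / sum
			for d := range v[j] {
				out[i][d] += w * v[j][d]
			}
		}
	}
	return out
}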