|
@@ -79,6 +79,8 @@ var devices = sync.OnceValue(func() []device {
|
|
|
})
|
|
|
|
|
|
type Backend struct {
|
|
|
+ flashAttention bool
|
|
|
+
|
|
|
meta *fs.GGML
|
|
|
cpus, gpus []Context
|
|
|
tensors map[string]*Context
|
|
@@ -192,9 +194,10 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
|
|
|
}
|
|
|
|
|
|
return &Backend{
|
|
|
- meta: meta,
|
|
|
- cpus: cpus,
|
|
|
- gpus: gpus,
|
|
|
+ flashAttention: params.FlashAttention,
|
|
|
+ meta: meta,
|
|
|
+ cpus: cpus,
|
|
|
+ gpus: gpus,
|
|
|
sched: C.ggml_backend_sched_new(
|
|
|
(*C.ggml_backend_t)(unsafe.Pointer(&backends[0])),
|
|
|
(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&bufts[0])),
|
|
@@ -248,7 +251,11 @@ func (b *Backend) NewContext() ml.Context {
|
|
|
}
|
|
|
|
|
|
// CacheConfig returns the cache settings this backend asks for.
//
// With b.flashAttention set, it returns CachePadding 256, an F16 mask dtype
// (ml.DTypeF16) and MaskBatchPadding taken from C.GGML_KQ_MASK_PAD.
// Otherwise it returns CachePadding 32 with PermutedV enabled, which is the
// value that used to be returned unconditionally.
//
// NOTE(review): the flash-attention branch does not set PermutedV, so it
// keeps its zero value — presumably because that attention path permutes
// value itself; confirm against ScaledDotProductAttention.
func (b *Backend) CacheConfig() ml.CacheConfig {
|
|
|
- return ml.CacheConfig{CachePadding: 32, PermutedV: true}
|
|
|
+ if b.flashAttention {
|
|
|
+ return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16, MaskBatchPadding: C.GGML_KQ_MASK_PAD}
|
|
|
+ } else {
|
|
|
+ return ml.CacheConfig{CachePadding: 32, PermutedV: true}
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
type Context struct {
|
|
@@ -705,14 +712,22 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.T
|
|
|
query := t.Permute(ctx, 0, 2, 1, 3)
|
|
|
key = key.Permute(ctx, 0, 2, 1, 3)
|
|
|
|
|
|
- kq := key.MulmatFullPrec(ctx, query)
|
|
|
- kq = &Tensor{
|
|
|
- b: t.b,
|
|
|
- t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
|
|
|
- }
|
|
|
+ if t.b.flashAttention {
|
|
|
+ value = value.Permute(ctx, 0, 2, 1, 3)
|
|
|
|
|
|
- kqv := value.Mulmat(ctx, kq)
|
|
|
- return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
|
|
+ kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0)
|
|
|
+ C.ggml_flash_attn_ext_set_prec(kqv, C.GGML_PREC_F32)
|
|
|
+ return &Tensor{b: t.b, t: kqv}
|
|
|
+ } else {
|
|
|
+ kq := key.MulmatFullPrec(ctx, query)
|
|
|
+ kq = &Tensor{
|
|
|
+ b: t.b,
|
|
|
+ t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
|
|
|
+ }
|
|
|
+
|
|
|
+ kqv := value.Mulmat(ctx, kq)
|
|
|
+ return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
func (b *Backend) SystemInfo() string {
|