0003-metal-fix-synchronization-in-new-matrix-multiplicati.patch 1.2 KB

123456789101112131415161718192021222324252627282930
  1. From dadbed99e65252d79f81101a392d0d6497b86caa Mon Sep 17 00:00:00 2001
  2. From: Shouzheng Liu <lshzh.hi@gmail.com>
  3. Date: Mon, 21 Aug 2023 06:59:29 -0400
  4. Subject: [PATCH] metal : fix synchronization in new matrix multiplication
  5. kernel (#2686)
  6. ---
  7. ggml-metal.metal | 3 ++-
  8. 1 file changed, 2 insertions(+), 1 deletion(-)
  9. diff --git a/ggml-metal.metal b/ggml-metal.metal
  10. index 3f31252..88d48f6 100644
  11. --- a/ggml-metal.metal
  12. +++ b/ggml-metal.metal
  13. @@ -1898,10 +1898,11 @@ kernel void kernel_mul_mm(device const uchar * src0,
  14. threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
  15. + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
  16. for (int i = 0; i < 8; i++) {
  17. + threadgroup_barrier(mem_flags::mem_device);
  18. simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
  19. }
  20. - threadgroup_barrier(mem_flags::mem_threadgroup);
  21. + threadgroup_barrier(mem_flags::mem_device);
  22. device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
  23. if (sgitg==0) {
  24. for (int i = 0; i < n_rows; i++) {
  25. --
  26. 2.41.0