0001-metal-handle-ggml_scale-for-n-4-0-close-3754.patch 3.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. From 469c9addef75893e6be12edda852d12e840bf064 Mon Sep 17 00:00:00 2001
  2. From: Georgi Gerganov <ggerganov@gmail.com>
  3. Date: Tue, 24 Oct 2023 09:46:50 +0300
  4. Subject: [PATCH 1/2] metal : handle ggml_scale for n%4 != 0 (close #3754)
  5. ggml-ci
  6. ---
  7. ggml-metal.m | 18 +++++++++++++-----
  8. ggml-metal.metal | 10 +++++++++-
  9. 2 files changed, 22 insertions(+), 6 deletions(-)
  10. diff --git a/ggml-metal.m b/ggml-metal.m
  11. index c908106..c1901dc 100644
  12. --- a/ggml-metal.m
  13. +++ b/ggml-metal.m
  14. @@ -62,6 +62,7 @@
  15. GGML_METAL_DECL_KERNEL(mul);
  16. GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
  17. GGML_METAL_DECL_KERNEL(scale);
  18. + GGML_METAL_DECL_KERNEL(scale_4);
  19. GGML_METAL_DECL_KERNEL(silu);
  20. GGML_METAL_DECL_KERNEL(relu);
  21. GGML_METAL_DECL_KERNEL(gelu);
  22. @@ -249,6 +250,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
  23. GGML_METAL_ADD_KERNEL(mul);
  24. GGML_METAL_ADD_KERNEL(mul_row);
  25. GGML_METAL_ADD_KERNEL(scale);
  26. + GGML_METAL_ADD_KERNEL(scale_4);
  27. GGML_METAL_ADD_KERNEL(silu);
  28. GGML_METAL_ADD_KERNEL(relu);
  29. GGML_METAL_ADD_KERNEL(gelu);
  30. @@ -347,6 +349,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
  31. GGML_METAL_DEL_KERNEL(mul);
  32. GGML_METAL_DEL_KERNEL(mul_row);
  33. GGML_METAL_DEL_KERNEL(scale);
  34. + GGML_METAL_DEL_KERNEL(scale_4);
  35. GGML_METAL_DEL_KERNEL(silu);
  36. GGML_METAL_DEL_KERNEL(relu);
  37. GGML_METAL_DEL_KERNEL(gelu);
  38. @@ -923,15 +926,20 @@ void ggml_metal_graph_compute(
  39. const float scale = *(const float *) src1->data;
  40. - [encoder setComputePipelineState:ctx->pipeline_scale];
  41. + int64_t n = ggml_nelements(dst);
  42. +
  43. + if (n % 4 == 0) {
  44. + n /= 4;
  45. + [encoder setComputePipelineState:ctx->pipeline_scale_4];
  46. + } else {
  47. + [encoder setComputePipelineState:ctx->pipeline_scale];
  48. + }
  49. +
  50. [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  51. [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
  52. [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
  53. - const int64_t n = ggml_nelements(dst);
  54. - GGML_ASSERT(n % 4 == 0);
  55. -
  56. - [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  57. + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  58. } break;
  59. case GGML_OP_UNARY:
  60. switch (ggml_get_unary_op(gf->nodes[i])) {
  61. diff --git a/ggml-metal.metal b/ggml-metal.metal
  62. index 69fc713..f4b4605 100644
  63. --- a/ggml-metal.metal
  64. +++ b/ggml-metal.metal
  65. @@ -125,9 +125,17 @@ kernel void kernel_mul_row(
  66. }
  67. kernel void kernel_scale(
  68. + device const float * src0,
  69. + device float * dst,
  70. + constant float & scale,
  71. + uint tpig[[thread_position_in_grid]]) {
  72. + dst[tpig] = src0[tpig] * scale;
  73. +}
  74. +
  75. +kernel void kernel_scale_4(
  76. device const float4 * src0,
  77. device float4 * dst,
  78. - constant float & scale,
  79. + constant float & scale,
  80. uint tpig[[thread_position_in_grid]]) {
  81. dst[tpig] = src0[tpig] * scale;
  82. }
  83. --
  84. 2.39.3 (Apple Git-145)