12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091 |
- From 469c9addef75893e6be12edda852d12e840bf064 Mon Sep 17 00:00:00 2001
- From: Georgi Gerganov <ggerganov@gmail.com>
- Date: Tue, 24 Oct 2023 09:46:50 +0300
- Subject: [PATCH 1/2] metal : handle ggml_scale for n%4 != 0 (close #3754)
- ggml-ci
- ---
- ggml-metal.m | 18 +++++++++++++-----
- ggml-metal.metal | 10 +++++++++-
- 2 files changed, 22 insertions(+), 6 deletions(-)
- diff --git a/ggml-metal.m b/ggml-metal.m
- index c908106..c1901dc 100644
- --- a/ggml-metal.m
- +++ b/ggml-metal.m
- @@ -62,6 +62,7 @@
- GGML_METAL_DECL_KERNEL(mul);
- GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
- GGML_METAL_DECL_KERNEL(scale);
- + GGML_METAL_DECL_KERNEL(scale_4);
- GGML_METAL_DECL_KERNEL(silu);
- GGML_METAL_DECL_KERNEL(relu);
- GGML_METAL_DECL_KERNEL(gelu);
- @@ -249,6 +250,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
- GGML_METAL_ADD_KERNEL(mul);
- GGML_METAL_ADD_KERNEL(mul_row);
- GGML_METAL_ADD_KERNEL(scale);
- + GGML_METAL_ADD_KERNEL(scale_4);
- GGML_METAL_ADD_KERNEL(silu);
- GGML_METAL_ADD_KERNEL(relu);
- GGML_METAL_ADD_KERNEL(gelu);
- @@ -347,6 +349,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
- GGML_METAL_DEL_KERNEL(mul);
- GGML_METAL_DEL_KERNEL(mul_row);
- GGML_METAL_DEL_KERNEL(scale);
- + GGML_METAL_DEL_KERNEL(scale_4);
- GGML_METAL_DEL_KERNEL(silu);
- GGML_METAL_DEL_KERNEL(relu);
- GGML_METAL_DEL_KERNEL(gelu);
- @@ -923,15 +926,20 @@ void ggml_metal_graph_compute(
-
- const float scale = *(const float *) src1->data;
-
- - [encoder setComputePipelineState:ctx->pipeline_scale];
- + int64_t n = ggml_nelements(dst);
- +
- + if (n % 4 == 0) {
- + n /= 4;
- + [encoder setComputePipelineState:ctx->pipeline_scale_4];
- + } else {
- + [encoder setComputePipelineState:ctx->pipeline_scale];
- + }
- +
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
-
- - const int64_t n = ggml_nelements(dst);
- - GGML_ASSERT(n % 4 == 0);
- -
- - [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_OP_UNARY:
- switch (ggml_get_unary_op(gf->nodes[i])) {
- diff --git a/ggml-metal.metal b/ggml-metal.metal
- index 69fc713..f4b4605 100644
- --- a/ggml-metal.metal
- +++ b/ggml-metal.metal
- @@ -125,9 +125,17 @@ kernel void kernel_mul_row(
- }
-
- kernel void kernel_scale(
- + device const float * src0,
- + device float * dst,
- + constant float & scale,
- + uint tpig[[thread_position_in_grid]]) {
- + dst[tpig] = src0[tpig] * scale;
- +}
- +
- +kernel void kernel_scale_4(
- device const float4 * src0,
- device float4 * dst,
- - constant float & scale,
- + constant float & scale,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = src0[tpig] * scale;
- }
- --
- 2.39.3 (Apple Git-145)
|