Browse Source

llama: add qwen2vl support

jmorganca 4 months ago
parent
commit
80d41b579b

+ 147 - 0
llama/ggml-metal-embed.metal

@@ -2081,6 +2081,7 @@ typedef struct {
     float    attn_factor;
     float    beta_fast;
     float    beta_slow;
+    int32_t  sections[4];
 } ggml_metal_kargs_rope;
 
 typedef struct {
@@ -4785,8 +4786,148 @@ kernel void kernel_rope_neox(
     }
 }
 
+
+template<typename T>
+kernel void kernel_rope_multi(
+        constant ggml_metal_kargs_rope & args,
+        device const char * src0,
+        device const char * src1,
+        device const char * src2,
+        device       char * dst,
+        ushort  tiitg[[thread_index_in_threadgroup]],
+        ushort3 tptg [[threads_per_threadgroup]],
+        uint3   tgpig[[threadgroup_position_in_grid]]) {
+    const int i3 = tgpig[2];
+    const int i2 = tgpig[1];
+    const int i1 = tgpig[0];
+
+    float corr_dims[2];
+    rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
+
+    device const int32_t * pos = (device const int32_t *) src1;
+
+    int sect_dims = args.sections[0] + args.sections[1] + args.sections[2] + args.sections[3];
+    int sec_w = args.sections[1] + args.sections[0];
+
+    const float inv_ndims = -1.f/args.n_dims;
+
+    float cos_theta;
+    float sin_theta;
+
+    for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
+        if (i0 < args.n_dims) {
+            const int ic = i0/2;
+            const int sector = ic % sect_dims;
+
+            float theta_base = (float) pos[i2];
+            if (sector >= args.sections[0] && sector < sec_w) {
+                theta_base = (float) pos[i2 + args.ne2];
+            }
+            else if (sector >= sec_w && sector < sec_w + args.sections[2]) {
+                theta_base = (float) pos[i2 + args.ne2 * 2];
+            }
+            else if (sector >= sec_w + args.sections[2]) {
+                theta_base = (float) pos[i2 + args.ne2 * 3];
+            }
+
+            float theta = theta_base*pow(args.freq_base, inv_ndims*i0);
+
+            const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
+
+            rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
+
+            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
+            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + ic*args.nb0);
+
+            const float x0 = src[0];
+            const float x1 = src[args.n_dims/2];
+
+            dst_data[0]             = x0*cos_theta - x1*sin_theta;
+            dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
+        } else {
+            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
+            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + i0*args.nb0);
+
+            dst_data[0] = src[0];
+            dst_data[1] = src[1];
+        }
+    }
+}
+
+template<typename T>
+kernel void kernel_rope_vision(
+        constant ggml_metal_kargs_rope & args,
+        device const char * src0,
+        device const char * src1,
+        device const char * src2,
+        device       char * dst,
+        ushort  tiitg[[thread_index_in_threadgroup]],
+        ushort3 tptg [[threads_per_threadgroup]],
+        uint3   tgpig[[threadgroup_position_in_grid]]) {
+    const int i3 = tgpig[2];
+    const int i2 = tgpig[1];
+    const int i1 = tgpig[0];
+
+    float corr_dims[2];
+    rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
+
+    device const int32_t * pos = (device const int32_t *) src1;
+
+    int sect_dims = args.sections[0] + args.sections[1];
+    int sec_w = args.sections[1] + args.sections[0];
+    int sec_e = args.sections[2] + sec_w;
+
+    const float inv_ndims = -1.f/args.n_dims;
+
+    float cos_theta;
+    float sin_theta;
+
+    for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
+        const int ic = i0/2;
+        const int sector = ic % sect_dims;
+
+        float theta_base = (float) pos[i2];
+        if (sector >= args.sections[0] && sector < sec_w) {
+            theta_base = (float) pos[i2 + args.ne2];
+        }
+        else if (sector >= sec_w && sector < sec_w + args.sections[2]) {
+            theta_base = (float) pos[i2 + args.ne2 * 2];
+        }
+        else if (sector >= sec_w + args.sections[2]) {
+            theta_base = (float) pos[i2 + args.ne2 * 3];
+        }
+
+        int p = sector;
+        if (sector >= sec_w + args.sections[2]) {
+            p = sector - (sec_w + args.sections[2]);
+        } else if (sector >= sec_w) {
+            p = sector - sec_w;
+        } else if (sector >= args.sections[0]) {
+            p = sector - args.sections[0];
+        }
+
+        const float theta = theta_base*pow(args.freq_base, inv_ndims*2.0f*p);
+
+        const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
+
+        rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
+
+        device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
+        device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + ic*args.nb0);
+
+        const float x0 = src[0];
+        const float x1 = src[args.n_dims];
+
+        dst_data[0]             = x0*cos_theta - x1*sin_theta;
+        dst_data[args.n_dims]   = x0*sin_theta + x1*cos_theta;
+    }
+}
+
+
 typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
 typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
+typedef decltype(kernel_rope_multi<float>) kernel_rope_multi_t;
+typedef decltype(kernel_rope_vision<float>) kernel_rope_vision_t;
 
 template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
 template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
@@ -4794,6 +4935,12 @@ template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_
 template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
 template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
 
+template [[host_name("kernel_rope_multi_f32")]] kernel kernel_rope_multi_t kernel_rope_multi<float>;
+template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kernel_rope_multi<half>;
+
+template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision<float>;
+template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision<half>;
+
 typedef void (im2col_t)(
         device const float * x,
         device        char * dst,

+ 1 - 0
llama/ggml-metal-impl.h

@@ -169,6 +169,7 @@ typedef struct {
     float    attn_factor;
     float    beta_fast;
     float    beta_slow;
+    int32_t  sections[4];
 } ggml_metal_kargs_rope;
 
 typedef struct {

+ 146 - 0
llama/ggml-metal.metal

@@ -2594,8 +2594,148 @@ kernel void kernel_rope_neox(
     }
 }
 
+
+template<typename T>
+kernel void kernel_rope_multi(
+        constant ggml_metal_kargs_rope & args,
+        device const char * src0,
+        device const char * src1,
+        device const char * src2,
+        device       char * dst,
+        ushort  tiitg[[thread_index_in_threadgroup]],
+        ushort3 tptg [[threads_per_threadgroup]],
+        uint3   tgpig[[threadgroup_position_in_grid]]) {
+    const int i3 = tgpig[2];
+    const int i2 = tgpig[1];
+    const int i1 = tgpig[0];
+
+    float corr_dims[2];
+    rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
+
+    device const int32_t * pos = (device const int32_t *) src1;
+
+    int sect_dims = args.sections[0] + args.sections[1] + args.sections[2] + args.sections[3];
+    int sec_w = args.sections[1] + args.sections[0];
+
+    const float inv_ndims = -1.f/args.n_dims;
+
+    float cos_theta;
+    float sin_theta;
+
+    for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
+        if (i0 < args.n_dims) {
+            const int ic = i0/2;
+            const int sector = ic % sect_dims;
+
+            float theta_base = (float) pos[i2];
+            if (sector >= args.sections[0] && sector < sec_w) {
+                theta_base = (float) pos[i2 + args.ne2];
+            }
+            else if (sector >= sec_w && sector < sec_w + args.sections[2]) {
+                theta_base = (float) pos[i2 + args.ne2 * 2];
+            }
+            else if (sector >= sec_w + args.sections[2]) {
+                theta_base = (float) pos[i2 + args.ne2 * 3];
+            }
+
+            float theta = theta_base*pow(args.freq_base, inv_ndims*i0);
+
+            const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
+
+            rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
+
+            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
+            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + ic*args.nb0);
+
+            const float x0 = src[0];
+            const float x1 = src[args.n_dims/2];
+
+            dst_data[0]             = x0*cos_theta - x1*sin_theta;
+            dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
+        } else {
+            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
+            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + i0*args.nb0);
+
+            dst_data[0] = src[0];
+            dst_data[1] = src[1];
+        }
+    }
+}
+
+template<typename T>
+kernel void kernel_rope_vision(
+        constant ggml_metal_kargs_rope & args,
+        device const char * src0,
+        device const char * src1,
+        device const char * src2,
+        device       char * dst,
+        ushort  tiitg[[thread_index_in_threadgroup]],
+        ushort3 tptg [[threads_per_threadgroup]],
+        uint3   tgpig[[threadgroup_position_in_grid]]) {
+    const int i3 = tgpig[2];
+    const int i2 = tgpig[1];
+    const int i1 = tgpig[0];
+
+    float corr_dims[2];
+    rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
+
+    device const int32_t * pos = (device const int32_t *) src1;
+
+    int sect_dims = args.sections[0] + args.sections[1];
+    int sec_w = args.sections[1] + args.sections[0];
+    int sec_e = args.sections[2] + sec_w;
+
+    const float inv_ndims = -1.f/args.n_dims;
+
+    float cos_theta;
+    float sin_theta;
+
+    for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
+        const int ic = i0/2;
+        const int sector = ic % sect_dims;
+
+        float theta_base = (float) pos[i2];
+        if (sector >= args.sections[0] && sector < sec_w) {
+            theta_base = (float) pos[i2 + args.ne2];
+        }
+        else if (sector >= sec_w && sector < sec_w + args.sections[2]) {
+            theta_base = (float) pos[i2 + args.ne2 * 2];
+        }
+        else if (sector >= sec_w + args.sections[2]) {
+            theta_base = (float) pos[i2 + args.ne2 * 3];
+        }
+
+        int p = sector;
+        if (sector >= sec_w + args.sections[2]) {
+            p = sector - (sec_w + args.sections[2]);
+        } else if (sector >= sec_w) {
+            p = sector - sec_w;
+        } else if (sector >= args.sections[0]) {
+            p = sector - args.sections[0];
+        }
+
+        const float theta = theta_base*pow(args.freq_base, inv_ndims*2.0f*p);
+
+        const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
+
+        rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
+
+        device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
+        device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + ic*args.nb0);
+
+        const float x0 = src[0];
+        const float x1 = src[args.n_dims];
+
+        dst_data[0]             = x0*cos_theta - x1*sin_theta;
+        dst_data[args.n_dims]   = x0*sin_theta + x1*cos_theta;
+    }
+}
+
+
 typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
 typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
+typedef decltype(kernel_rope_multi<float>) kernel_rope_multi_t;
+typedef decltype(kernel_rope_vision<float>) kernel_rope_vision_t;
 
 template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
 template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
@@ -2603,6 +2743,12 @@ template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_
 template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
 template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
 
+template [[host_name("kernel_rope_multi_f32")]] kernel kernel_rope_multi_t kernel_rope_multi<float>;
+template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kernel_rope_multi<half>;
+
+template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision<float>;
+template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision<half>;
+
 typedef void (im2col_t)(
         device const float * x,
         device        char * dst,

+ 39 - 15
llama/ggml-metal_darwin_arm64.m

@@ -328,6 +328,10 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
     GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
+    GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32,
+    GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16,
+    GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32,
+    GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16,
     GGML_METAL_KERNEL_TYPE_IM2COL_F16,
     GGML_METAL_KERNEL_TYPE_IM2COL_F32,
     GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,
@@ -928,6 +932,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,                 rope_norm_f16,                  true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,                 rope_neox_f32,                  true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,                 rope_neox_f16,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32,                rope_multi_f32,                 true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16,                rope_multi_f16,                 true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32,               rope_vision_f32,                true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16,               rope_vision_f16,                true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16,                    im2col_f16,                     true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32,                    im2col_f32,                     true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,                im2col_ext_f16,                 true);
@@ -1155,16 +1163,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_NORM:
             return true;
         case GGML_OP_ROPE:
-            {
-                const int mode = ((const int32_t *) op->op_params)[2];
-                if (mode & GGML_ROPE_TYPE_MROPE) {
-                    return false;
-                }
-                if (mode & GGML_ROPE_TYPE_VISION) {
-                    return false;
-                }
-                return true;
-            }
+            return true;
         case GGML_OP_IM2COL:
             return op->src[0]->type == GGML_TYPE_F16;
         case GGML_OP_POOL_1D:
@@ -3083,6 +3082,7 @@ static void ggml_metal_encode_node(
                 float attn_factor;
                 float beta_fast;
                 float beta_slow;
+                int32_t sections[4];
 
                 memcpy(&freq_base,   (const int32_t *) dst->op_params +  5, sizeof(float));
                 memcpy(&freq_scale,  (const int32_t *) dst->op_params +  6, sizeof(float));
@@ -3090,21 +3090,44 @@ static void ggml_metal_encode_node(
                 memcpy(&attn_factor, (const int32_t *) dst->op_params +  8, sizeof(float));
                 memcpy(&beta_fast,   (const int32_t *) dst->op_params +  9, sizeof(float));
                 memcpy(&beta_slow,   (const int32_t *) dst->op_params + 10, sizeof(float));
+                memcpy(&sections,    (const int32_t *) dst->op_params + 11, sizeof(int32_t)*4);
 
                 const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
+                const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
+                const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
+
+                if (is_mrope) {
+                    GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
+                }
+
+                if (is_vision) {
+                    GGML_ASSERT(n_dims == ne00/2);
+                }
 
                 id<MTLComputePipelineState> pipeline = nil;
 
-                if (!is_neox) {
+                if (is_neox) {
                     switch (src0->type) {
-                        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
-                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
+                        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
+                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
+                        default: GGML_ABORT("fatal error");
+                    };
+                } else if (is_mrope && !is_vision) {
+                    switch (src0->type) {
+                        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32].pipeline; break;
+                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16].pipeline; break;
+                        default: GGML_ABORT("fatal error");
+                    };
+                } else if (is_vision) {
+                    switch (src0->type) {
+                        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32].pipeline; break;
+                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16].pipeline; break;
                         default: GGML_ABORT("fatal error");
                     };
                 } else {
                     switch (src0->type) {
-                        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
-                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
+                        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
+                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
                         default: GGML_ABORT("fatal error");
                     };
                 }
@@ -3135,6 +3158,7 @@ static void ggml_metal_encode_node(
                     /*.attn_factor =*/ attn_factor,
                     /*.beta_fast   =*/ beta_fast,
                     /*.beta_slow   =*/ beta_slow,
+                    /*.sections    =*/ {sections[0], sections[1], sections[2], sections[3]}
                 };
 
                 [encoder setComputePipelineState:pipeline];

+ 299 - 0
llama/patches/0014-qwen2vl-support.patch

@@ -0,0 +1,299 @@
+From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
+From: jmorganca <jmorganca@gmail.com>
+Date: Sun, 15 Dec 2024 23:56:24 -0800
+Subject: [PATCH] qwen2vl support
+
+---
+ ggml/src/ggml-metal/ggml-metal-impl.h |   1 +
+ ggml/src/ggml-metal/ggml-metal.m      |  54 +++++++---
+ ggml/src/ggml-metal/ggml-metal.metal  | 146 ++++++++++++++++++++++++++
+ 3 files changed, 186 insertions(+), 15 deletions(-)
+
+diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
+index e3dc25f1..766a4999 100644
+--- a/ggml/src/ggml-metal/ggml-metal-impl.h
++++ b/ggml/src/ggml-metal/ggml-metal-impl.h
+@@ -143,6 +143,7 @@ typedef struct {
+     float    attn_factor;
+     float    beta_fast;
+     float    beta_slow;
++    int32_t  sections[4];
+ } ggml_metal_kargs_rope;
+ 
+ typedef struct {
+diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
+index 787fc713..806c9fd3 100644
+--- a/ggml/src/ggml-metal/ggml-metal.m
++++ b/ggml/src/ggml-metal/ggml-metal.m
+@@ -302,6 +302,10 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
+     GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
+     GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
+     GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
++    GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32,
++    GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16,
++    GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32,
++    GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16,
+     GGML_METAL_KERNEL_TYPE_IM2COL_F16,
+     GGML_METAL_KERNEL_TYPE_IM2COL_F32,
+     GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,
+@@ -902,6 +906,10 @@ @implementation GGMLMetalClass
+         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,                 rope_norm_f16,                  true);
+         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,                 rope_neox_f32,                  true);
+         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,                 rope_neox_f16,                  true);
++        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32,                rope_multi_f32,                 true);
++        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16,                rope_multi_f16,                 true);
++        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32,               rope_vision_f32,                true);
++        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16,               rope_vision_f16,                true);
+         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16,                    im2col_f16,                     true);
+         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32,                    im2col_f32,                     true);
+         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,                im2col_ext_f16,                 true);
+@@ -1129,16 +1137,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
+         case GGML_OP_NORM:
+             return true;
+         case GGML_OP_ROPE:
+-            {
+-                const int mode = ((const int32_t *) op->op_params)[2];
+-                if (mode & GGML_ROPE_TYPE_MROPE) {
+-                    return false;
+-                }
+-                if (mode & GGML_ROPE_TYPE_VISION) {
+-                    return false;
+-                }
+-                return true;
+-            }
++            return true;
+         case GGML_OP_IM2COL:
+             return op->src[0]->type == GGML_TYPE_F16;
+         case GGML_OP_POOL_1D:
+@@ -3057,6 +3056,7 @@ static void ggml_metal_encode_node(
+                 float attn_factor;
+                 float beta_fast;
+                 float beta_slow;
++                int32_t sections[4];
+ 
+                 memcpy(&freq_base,   (const int32_t *) dst->op_params +  5, sizeof(float));
+                 memcpy(&freq_scale,  (const int32_t *) dst->op_params +  6, sizeof(float));
+@@ -3064,21 +3064,44 @@ static void ggml_metal_encode_node(
+                 memcpy(&attn_factor, (const int32_t *) dst->op_params +  8, sizeof(float));
+                 memcpy(&beta_fast,   (const int32_t *) dst->op_params +  9, sizeof(float));
+                 memcpy(&beta_slow,   (const int32_t *) dst->op_params + 10, sizeof(float));
++                memcpy(&sections,    (const int32_t *) dst->op_params + 11, sizeof(int32_t)*4);
+ 
+                 const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
++                const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
++                const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
++
++                if (is_mrope) {
++                    GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
++                }
++
++                if (is_vision) {
++                    GGML_ASSERT(n_dims == ne00/2);
++                }
+ 
+                 id<MTLComputePipelineState> pipeline = nil;
+ 
+-                if (!is_neox) {
++                if (is_neox) {
+                     switch (src0->type) {
+-                        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
+-                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
++                        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
++                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
++                        default: GGML_ABORT("fatal error");
++                    };
++                } else if (is_mrope && !is_vision) {
++                    switch (src0->type) {
++                        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32].pipeline; break;
++                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16].pipeline; break;
++                        default: GGML_ABORT("fatal error");
++                    };
++                } else if (is_vision) {
++                    switch (src0->type) {
++                        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32].pipeline; break;
++                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16].pipeline; break;
+                         default: GGML_ABORT("fatal error");
+                     };
+                 } else {
+                     switch (src0->type) {
+-                        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
+-                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
++                        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
++                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
+                         default: GGML_ABORT("fatal error");
+                     };
+                 }
+@@ -3109,6 +3132,7 @@ static void ggml_metal_encode_node(
+                     /*.attn_factor =*/ attn_factor,
+                     /*.beta_fast   =*/ beta_fast,
+                     /*.beta_slow   =*/ beta_slow,
++                    /*.sections    =*/ {sections[0], sections[1], sections[2], sections[3]}
+                 };
+ 
+                 [encoder setComputePipelineState:pipeline];
+diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
+index 204c93e6..67b3240f 100644
+--- a/ggml/src/ggml-metal/ggml-metal.metal
++++ b/ggml/src/ggml-metal/ggml-metal.metal
+@@ -2568,8 +2568,148 @@ kernel void kernel_rope_neox(
+     }
+ }
+ 
++
++template<typename T>
++kernel void kernel_rope_multi(
++        constant ggml_metal_kargs_rope & args,
++        device const char * src0,
++        device const char * src1,
++        device const char * src2,
++        device       char * dst,
++        ushort  tiitg[[thread_index_in_threadgroup]],
++        ushort3 tptg [[threads_per_threadgroup]],
++        uint3   tgpig[[threadgroup_position_in_grid]]) {
++    const int i3 = tgpig[2];
++    const int i2 = tgpig[1];
++    const int i1 = tgpig[0];
++
++    float corr_dims[2];
++    rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
++
++    device const int32_t * pos = (device const int32_t *) src1;
++
++    int sect_dims = args.sections[0] + args.sections[1] + args.sections[2] + args.sections[3];
++    int sec_w = args.sections[1] + args.sections[0];
++
++    const float inv_ndims = -1.f/args.n_dims;
++
++    float cos_theta;
++    float sin_theta;
++
++    for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
++        if (i0 < args.n_dims) {
++            const int ic = i0/2;
++            const int sector = ic % sect_dims;
++
++            float theta_base = (float) pos[i2];
++            if (sector >= args.sections[0] && sector < sec_w) {
++                theta_base = (float) pos[i2 + args.ne2];
++            }
++            else if (sector >= sec_w && sector < sec_w + args.sections[2]) {
++                theta_base = (float) pos[i2 + args.ne2 * 2];
++            }
++            else if (sector >= sec_w + args.sections[2]) {
++                theta_base = (float) pos[i2 + args.ne2 * 3];
++            }
++
++            float theta = theta_base*pow(args.freq_base, inv_ndims*i0);
++
++            const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
++
++            rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
++
++            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
++            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + ic*args.nb0);
++
++            const float x0 = src[0];
++            const float x1 = src[args.n_dims/2];
++
++            dst_data[0]             = x0*cos_theta - x1*sin_theta;
++            dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
++        } else {
++            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
++            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + i0*args.nb0);
++
++            dst_data[0] = src[0];
++            dst_data[1] = src[1];
++        }
++    }
++}
++
++template<typename T>
++kernel void kernel_rope_vision(
++        constant ggml_metal_kargs_rope & args,
++        device const char * src0,
++        device const char * src1,
++        device const char * src2,
++        device       char * dst,
++        ushort  tiitg[[thread_index_in_threadgroup]],
++        ushort3 tptg [[threads_per_threadgroup]],
++        uint3   tgpig[[threadgroup_position_in_grid]]) {
++    const int i3 = tgpig[2];
++    const int i2 = tgpig[1];
++    const int i1 = tgpig[0];
++
++    float corr_dims[2];
++    rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
++
++    device const int32_t * pos = (device const int32_t *) src1;
++
++    int sect_dims = args.sections[0] + args.sections[1];
++    int sec_w = args.sections[1] + args.sections[0];
++    int sec_e = args.sections[2] + sec_w;
++
++    const float inv_ndims = -1.f/args.n_dims;
++
++    float cos_theta;
++    float sin_theta;
++
++    for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
++        const int ic = i0/2;
++        const int sector = ic % sect_dims;
++
++        float theta_base = (float) pos[i2];
++        if (sector >= args.sections[0] && sector < sec_w) {
++            theta_base = (float) pos[i2 + args.ne2];
++        }
++        else if (sector >= sec_w && sector < sec_w + args.sections[2]) {
++            theta_base = (float) pos[i2 + args.ne2 * 2];
++        }
++        else if (sector >= sec_w + args.sections[2]) {
++            theta_base = (float) pos[i2 + args.ne2 * 3];
++        }
++
++        int p = sector;
++        if (sector >= sec_w + args.sections[2]) {
++            p = sector - (sec_w + args.sections[2]);
++        } else if (sector >= sec_w) {
++            p = sector - sec_w;
++        } else if (sector >= args.sections[0]) {
++            p = sector - args.sections[0];
++        }
++
++        const float theta = theta_base*pow(args.freq_base, inv_ndims*2.0f*p);
++
++        const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
++
++        rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
++
++        device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
++        device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + ic*args.nb0);
++
++        const float x0 = src[0];
++        const float x1 = src[args.n_dims];
++
++        dst_data[0]             = x0*cos_theta - x1*sin_theta;
++        dst_data[args.n_dims]   = x0*sin_theta + x1*cos_theta;
++    }
++}
++
++
+ typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
+ typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
++typedef decltype(kernel_rope_multi<float>) kernel_rope_multi_t;
++typedef decltype(kernel_rope_vision<float>) kernel_rope_vision_t;
+ 
+ template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
+ template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
+@@ -2577,6 +2717,12 @@ template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_
+ template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
+ template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
+ 
++template [[host_name("kernel_rope_multi_f32")]] kernel kernel_rope_multi_t kernel_rope_multi<float>;
++template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kernel_rope_multi<half>;
++
++template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision<float>;
++template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision<half>;
++
+ typedef void (im2col_t)(
+         device const float * x,
+         device        char * dst,