jmorganca 10 months ago
parent
commit
b767f6554c
7 changed files with 6643 additions and 20 deletions
  1. 2 1
      llama/.gitignore
  2. 45 0
      llama/README.md
  3. 2 2
      llama/ggml-metal-darwin_arm64.m
  4. 6566 0
      llama/ggml-metal.in.metal
  5. 26 2
      llama/llama.go
  6. 1 0
      llama/llama_test.go
  7. 1 15
      llama/sync_llama.sh

+ 2 - 1
llama/.gitignore

@@ -6,4 +6,5 @@
 *.so
 *.o
 *.obj
-ollama_runner*
+ollama_runner*
+ggml-metal.metal

+ 45 - 0
llama/README.md

@@ -1,6 +1,10 @@
 # `llama`
 
+<<<<<<< Updated upstream
 This package integrates llama.cpp as a Go package that's easy to build with tags for different CPU and GPU processors.
+=======
+This package integrates the [llama.cpp](https://github.com/ggerganov/llama.cpp) library as a Go package and makes it easy to build it with tags for different CPU and GPU processors.
+>>>>>>> Stashed changes
 
 Supported:
 
@@ -12,6 +16,15 @@ Supported:
 - [x] Linux CUDA
 - [x] Linux ROCm
 - [x] Llava
+<<<<<<< Updated upstream
+=======
+- [ ] Parallel Requests
+
+Extra build steps are required for CUDA and ROCm on Windows since `nvcc` and `hipcc` both require using msvc as the host compiler. For these small dlls are created:
+
+- `ggml-cuda.dll`
+- `ggml-hipblas.dll`
+>>>>>>> Stashed changes
 
 > Note: it's important that memory is allocated and freed by the same compiler (e.g. entirely by code compiled with msvc or mingw). Issues from this should be rare, but there are some places where pointers are returned by the CUDA or HIP runtimes and freed elsewhere, causing a a crash. In a future change the same runtime should be used in both cases to avoid crashes.
 
@@ -41,7 +54,15 @@ go build -tags=avx,avx2 .
 
 ### CUDA
 
+<<<<<<< Updated upstream
 Install the [CUDA toolkit v11.3.1](https://developer.nvidia.com/cuda-11-3-1-download-archive):
+=======
+Install the [CUDA toolkit v11.3.1](https://developer.nvidia.com/cuda-11-3-1-download-archive) then build `libggml-cuda.so`:
+
+```shell
+./build_cuda.sh
+```
+>>>>>>> Stashed changes
 
 Then build the package with the `cuda` tag:
 
@@ -51,8 +72,11 @@ go build -tags=cuda .
 
 ## Windows
 
+<<<<<<< Updated upstream
 Download [w64devkit](https://github.com/skeeto/w64devkit/releases/latest) for a simple MinGW development environment.
 
+=======
+>>>>>>> Stashed changes
 ### CUDA
 
 Install the [CUDA toolkit v11.3.1](https://developer.nvidia.com/cuda-11-3-1-download-archive) then build the cuda code:
@@ -60,7 +84,11 @@ Install the [CUDA toolkit v11.3.1](https://developer.nvidia.com/cuda-11-3-1-down
 Build `ggml-cuda.dll`:
 
 ```shell
+<<<<<<< Updated upstream
 make ggml_cuda.dll
+=======
+./build_cuda.ps1
+>>>>>>> Stashed changes
 ```
 
 Then build the package with the `cuda` tag:
@@ -71,10 +99,19 @@ go build -tags=cuda .
 
 ### ROCm
 
+<<<<<<< Updated upstream
 Install [ROCm 5.7.1](https://rocm.docs.amd.com/en/docs-5.7.1/).
 
 ```shell
 make ggml_hipblas.dll
+=======
+Install [ROCm 5.7.1](https://rocm.docs.amd.com/en/docs-5.7.1/) and [Strawberry Perl](https://strawberryperl.com/).
+
+Then, build `ggml-hipblas.dll`:
+
+```shell
+./build_hipblas.sh
+>>>>>>> Stashed changes
 ```
 
 Then build the package with the `rocm` tag:
@@ -85,8 +122,16 @@ go build -tags=rocm .
 
 ## Syncing with llama.cpp
 
+<<<<<<< Updated upstream
 To update this package to the latest llama.cpp code, use the `sync_llama.sh` script from the root of this repo:
 
 ```
 ./sync_llama.sh ../../llama.cpp
+=======
+To update this package to the latest llama.cpp code, use the `scripts/sync_llama.sh` script from the root of this repo, providing the location of a llama.cpp checkout:
+
+```
+cd ollama
+./scripts/sync_llama.sh ../llama.cpp
+>>>>>>> Stashed changes
 ```

+ 2 - 2
llama/ggml-metal-darwin_arm64.m

@@ -377,8 +377,8 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
 #if GGML_METAL_EMBED_LIBRARY
             GGML_METAL_LOG_INFO("%s: using embedded metal library\n", __func__);
 
-            extern const char ggml_metallib_start[];
-            extern const char ggml_metallib_end[];
+            extern const char* ggml_metallib_start;
+            extern const char* ggml_metallib_end;
 
             NSString * src = [[NSString alloc] initWithBytes:ggml_metallib_start length:(ggml_metallib_end-ggml_metallib_start) encoding:NSUTF8StringEncoding];
 #else

+ 6566 - 0
llama/ggml-metal.in.metal

@@ -0,0 +1,6566 @@
+/**
+ * llama.cpp - git ee459f40f65810a810151b24eba5b8bd174ceffe
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#define GGML_COMMON_DECL_METAL
+#define GGML_COMMON_IMPL_METAL
+#include "ggml-common.h"
+
+#include <metal_stdlib>
+
+using namespace metal;
+
+#define MAX(x, y) ((x) > (y) ? (x) : (y))
+#define MIN(x, y) ((x) < (y) ? (x) : (y))
+#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
+
+#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
+
+enum ggml_sort_order {
+    GGML_SORT_ORDER_ASC,
+    GGML_SORT_ORDER_DESC,
+};
+
+// general-purpose kernel for addition, multiplication and division of two tensors
+// pros: works for non-contiguous tensors, supports broadcast across all dims
+// cons: not very efficient
+kernel void kernel_add(
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        constant  int64_t & ne00,
+        constant  int64_t & ne01,
+        constant  int64_t & ne02,
+        constant  int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant  int64_t & ne10,
+        constant  int64_t & ne11,
+        constant  int64_t & ne12,
+        constant  int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant uint64_t & nb13,
+        constant  int64_t & ne0,
+        constant  int64_t & ne1,
+        constant  int64_t & ne2,
+        constant  int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        constant  int64_t & offs,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig.z;
+    const int64_t i02 = tgpig.y;
+    const int64_t i01 = tgpig.x;
+
+    const int64_t i13 = i03 % ne13;
+    const int64_t i12 = i02 % ne12;
+    const int64_t i11 = i01 % ne11;
+
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
+    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1  + offs;
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        const int i10 = i0 % ne10;
+        *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10));
+    }
+}
+
+kernel void kernel_mul(
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        constant  int64_t & ne00,
+        constant  int64_t & ne01,
+        constant  int64_t & ne02,
+        constant  int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant  int64_t & ne10,
+        constant  int64_t & ne11,
+        constant  int64_t & ne12,
+        constant  int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant uint64_t & nb13,
+        constant  int64_t & ne0,
+        constant  int64_t & ne1,
+        constant  int64_t & ne2,
+        constant  int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig.z;
+    const int64_t i02 = tgpig.y;
+    const int64_t i01 = tgpig.x;
+
+    const int64_t i13 = i03 % ne13;
+    const int64_t i12 = i02 % ne12;
+    const int64_t i11 = i01 % ne11;
+
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        const int i10 = i0 % ne10;
+        *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10));
+    }
+}
+
+kernel void kernel_div(
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        constant  int64_t & ne00,
+        constant  int64_t & ne01,
+        constant  int64_t & ne02,
+        constant  int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant  int64_t & ne10,
+        constant  int64_t & ne11,
+        constant  int64_t & ne12,
+        constant  int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant uint64_t & nb13,
+        constant  int64_t & ne0,
+        constant  int64_t & ne1,
+        constant  int64_t & ne2,
+        constant  int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig.z;
+    const int64_t i02 = tgpig.y;
+    const int64_t i01 = tgpig.x;
+
+    const int64_t i13 = i03 % ne13;
+    const int64_t i12 = i02 % ne12;
+    const int64_t i11 = i01 % ne11;
+
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        const int i10 = i0 % ne10;
+        *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10));
+    }
+}
+
+template<typename T>
+kernel void kernel_repeat(
+        device const char * src0,
+        device       char * dst,
+        constant  int64_t & ne00,
+        constant  int64_t & ne01,
+        constant  int64_t & ne02,
+        constant  int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant  int64_t & ne0,
+        constant  int64_t & ne1,
+        constant  int64_t & ne2,
+        constant  int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
+
+    const int64_t i03 = i3 % ne03;
+    const int64_t i02 = i2 % ne02;
+    const int64_t i01 = i1 % ne01;
+
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+    device       char * dst_ptr  = dst  +  i3*nb3  +  i2*nb2  +  i1*nb1 ;
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        const int i00 = i0 % ne00;
+        *((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00));
+    }
+}
+
+typedef decltype(kernel_repeat<float>) kernel_repeat_t;
+
+template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
+template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
+template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
+template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
+
+// assumption: src1 is a row
+// broadcast src1 into src0
+kernel void kernel_add_row(
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
+        constant   uint64_t & nb [[buffer(28)]],
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] + src1[tpig % nb];
+}
+
+kernel void kernel_mul_row(
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
+        constant   uint64_t & nb  [[buffer(28)]],
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] * src1[tpig % nb];
+}
+
+kernel void kernel_div_row(
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
+        constant   uint64_t & nb  [[buffer(28)]],
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] / src1[tpig % nb];
+}
+
+kernel void kernel_scale(
+        device const float * src0,
+        device       float * dst,
+        constant     float & scale,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] * scale;
+}
+
+kernel void kernel_scale_4(
+        device const float4 * src0,
+        device       float4 * dst,
+        constant     float  & scale,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] * scale;
+}
+
+kernel void kernel_clamp(
+        device const float * src0,
+        device       float * dst,
+        constant     float & min,
+        constant     float & max,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] < min ? min : (src0[tpig] > max ? max : src0[tpig]);
+}
+
+kernel void kernel_relu(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = max(0.0f, src0[tpig]);
+}
+
+kernel void kernel_sigmoid(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
+}
+
+kernel void kernel_tanh(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+    dst[tpig] = precise::tanh(x);
+}
+
+constant float GELU_COEF_A     = 0.044715f;
+constant float GELU_QUICK_COEF = -1.702f;
+constant float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;
+
+kernel void kernel_gelu(
+    device const float * src0,
+    device       float * dst,
+    uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+
+    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+}
+
+kernel void kernel_gelu_4(
+    device const float4 * src0,
+    device       float4 * dst,
+    uint tpig[[thread_position_in_grid]]) {
+    device const float4 & x = src0[tpig];
+
+    // BEWARE !!!
+    // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
+    // This was observed with Falcon 7B and 40B models
+    //
+    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+}
+
+kernel void kernel_gelu_quick(
+    device const float * src0,
+    device       float * dst,
+    uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+
+    dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
+}
+
+kernel void kernel_gelu_quick_4(
+    device const float4 * src0,
+    device       float4 * dst,
+    uint tpig[[thread_position_in_grid]]) {
+    device const float4 & x = src0[tpig];
+
+    dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
+}
+
+kernel void kernel_silu(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+    dst[tpig] = x / (1.0f + exp(-x));
+}
+
+kernel void kernel_silu_4(
+        device const float4 * src0,
+        device       float4 * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float4 & x = src0[tpig];
+    dst[tpig] = x / (1.0f + exp(-x));
+}
+
+kernel void kernel_sqr(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] * src0[tpig];
+}
+
+kernel void kernel_sum_rows(
+        device const float * src0,
+        device       float * dst,
+        constant  int64_t & ne00,
+        constant  int64_t & ne01,
+        constant  int64_t & ne02,
+        constant  int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant  int64_t & ne10,
+        constant  int64_t & ne11,
+        constant  int64_t & ne12,
+        constant  int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant uint64_t & nb13,
+        constant  int64_t & ne0,
+        constant  int64_t & ne1,
+        constant  int64_t & ne2,
+        constant  int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        uint3 tpig[[thread_position_in_grid]]) {
+    int64_t i3 = tpig.z;
+    int64_t i2 = tpig.y;
+    int64_t i1 = tpig.x;
+
+    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
+        return;
+    }
+
+    device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
+    device       float * dst_row = (device       float *) ((device       char *) dst  + i1*nb1  + i2*nb2  + i3*nb3);
+
+    float row_sum = 0;
+
+    for (int64_t i0 = 0; i0 < ne00; i0++) {
+        row_sum += src_row[i0];
+    }
+
+    dst_row[0] = row_sum;
+}
+
+template<typename T>
+kernel void kernel_soft_max(
+        device const  char * src0,
+        device const  char * src1,
+        device        char * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant     float & scale,
+        constant     float & max_bias,
+        constant     float & m0,
+        constant     float & m1,
+        constant  uint32_t & n_head_log2,
+        threadgroup  float * buf [[threadgroup(0)]],
+        uint  tgpig[[threadgroup_position_in_grid]],
+        uint  tpitg[[thread_position_in_threadgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint    ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = (tgpig) / (ne02*ne01);
+    const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
+    const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
+
+    device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device const     T * pmask = src1 != src0 ? (device const    T *) src1         + i01*ne00 : nullptr;
+    device       float * pdst  = (device       float *) dst  + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+
+    float slope = 1.0f;
+
+    // ALiBi
+    if (max_bias > 0.0f) {
+        const int64_t h = i02;
+
+        const float base = h < n_head_log2 ? m0 : m1;
+        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slope = pow(base, exp);
+    }
+
+    // parallel max
+    float lmax = -INFINITY;
+
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
+    }
+
+    // find the max value in the block
+    float max_val = simd_max(lmax);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = -INFINITY;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = max_val;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        max_val = buf[tiisg];
+        max_val = simd_max(max_val);
+    }
+
+    // parallel sum
+    float lsum = 0.0f;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
+        lsum += exp_psrc0;
+        pdst[i00] = exp_psrc0;
+    }
+
+    // This barrier fixes a failing test
+    // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+    threadgroup_barrier(mem_flags::mem_none);
+
+    float sum = simd_sum(lsum);
+
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = sum;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        sum = buf[tiisg];
+        sum = simd_sum(sum);
+    }
+
+    const float inv_sum = 1.0f/sum;
+
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        pdst[i00] *= inv_sum;
+    }
+}
+
+template<typename T>
+kernel void kernel_soft_max_4(
+        device const  char * src0,
+        device const  char * src1,
+        device        char * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant     float & scale,
+        constant     float & max_bias,
+        constant     float & m0,
+        constant     float & m1,
+        constant  uint32_t & n_head_log2,
+        threadgroup  float * buf [[threadgroup(0)]],
+        uint  tgpig[[threadgroup_position_in_grid]],
+        uint  tpitg[[thread_position_in_threadgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint    ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = (tgpig) / (ne02*ne01);
+    const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
+    const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
+
+    device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
+    device const      T * pmask = src1 != src0 ? (device const     T *) src1         + i01*ne00/4 : nullptr;
+    device       float4 * pdst4 = (device       float4 *) dst  + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
+
+    float slope = 1.0f;
+
+    if (max_bias > 0.0f) {
+        const int64_t h = i02;
+
+        const float base = h < n_head_log2 ? m0 : m1;
+        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slope = pow(base, exp);
+    }
+
+    // parallel max
+    float4 lmax4 = -INFINITY;
+
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+        lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
+    }
+
+    const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
+
+    float max_val = simd_max(lmax);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = -INFINITY;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = max_val;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        max_val = buf[tiisg];
+        max_val = simd_max(max_val);
+    }
+
+    // parallel sum
+    float4 lsum4 = 0.0f;
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+        const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
+        lsum4 += exp_psrc4;
+        pdst4[i00] = exp_psrc4;
+    }
+
+    const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
+
+    // This barrier fixes a failing test
+    // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+    threadgroup_barrier(mem_flags::mem_none);
+
+    float sum = simd_sum(lsum);
+
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = sum;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        sum = buf[tiisg];
+        sum = simd_sum(sum);
+    }
+
+    const float inv_sum = 1.0f/sum;
+
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+        pdst4[i00] *= inv_sum;
+    }
+}
+
+typedef decltype(kernel_soft_max<float>)    kernel_soft_max_t;
+typedef decltype(kernel_soft_max_4<float4>) kernel_soft_max_4_t;
+
+template [[host_name("kernel_soft_max_f16")]]   kernel kernel_soft_max_t   kernel_soft_max<half>;
+template [[host_name("kernel_soft_max_f32")]]   kernel kernel_soft_max_t   kernel_soft_max<float>;
+template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<half4>;
+template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<float4>;
+
+kernel void kernel_diag_mask_inf(
+        device const float * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant       int & n_past,
+        uint3 tpig[[thread_position_in_grid]]) {
+    const int64_t i02 = tpig[2];
+    const int64_t i01 = tpig[1];
+    const int64_t i00 = tpig[0];
+
+    if (i00 > n_past + i01) {
+        dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
+    } else {
+        dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
+    }
+}
+
+kernel void kernel_diag_mask_inf_8(
+        device const float4 * src0,
+        device       float4 * dst,
+        constant    int64_t & ne00,
+        constant    int64_t & ne01,
+        constant        int & n_past,
+        uint3 tpig[[thread_position_in_grid]]) {
+
+    const int64_t i = 2*tpig[0];
+
+    dst[i+0] = src0[i+0];
+    dst[i+1] = src0[i+1];
+    int64_t i4 = 4*i;
+    const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
+    const int64_t i01 = i4/(ne00);      i4 -= i01*ne00;
+    const int64_t i00 = i4;
+    for (int k = 3; k >= 0; --k) {
+        if (i00 + 4 + k <= n_past + i01) {
+            break;
+        }
+        dst[i+1][k] = -INFINITY;
+        if (i00 + k > n_past + i01) {
+            dst[i][k] = -INFINITY;
+        }
+    }
+}
+
+kernel void kernel_norm(
+        device const  void * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant     float & eps,
+        threadgroup float  * sum [[threadgroup(0)]],
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
+    // MEAN
+    // parallel sum
+    sum[tpitg] = 0.0f;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        sum[tpitg] += x[i00];
+    }
+    // reduce
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = ntg/2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            sum[tpitg] += sum[tpitg + i];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+    const float mean  = sum[0] / ne00;
+
+    // recenter and VARIANCE
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    device float * y = dst + tgpig*ne00;
+    sum[tpitg] = 0.0f;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        y[i00] = x[i00] - mean;
+        sum[tpitg] += y[i00] * y[i00];
+    }
+
+    // reduce
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = ntg/2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            sum[tpitg] += sum[tpitg + i];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+    const float variance = sum[0] / ne00;
+
+    const float scale = 1.0f/sqrt(variance + eps);
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        y[i00] = y[i00] * scale;
+    }
+}
+
+kernel void kernel_rms_norm(
+        device const  void * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant     float & eps,
+        threadgroup float  * buf [[threadgroup(0)]],
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
+
+    float4 sumf = 0;
+    float all_sum = 0;
+
+    // parallel sum
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+        sumf += x[i00] * x[i00];
+    }
+    all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
+    all_sum = simd_sum(all_sum);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = all_sum;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        all_sum = buf[tiisg];
+        all_sum = simd_sum(all_sum);
+    }
+
+    const float mean  = all_sum/ne00;
+    const float scale = 1.0f/sqrt(mean + eps);
+
+    device float4 * y = (device float4 *) (dst + tgpig*ne00);
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+        y[i00] = x[i00] * scale;
+    }
+}
+
+kernel void kernel_group_norm(
+        device const float * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int32_t & n_groups,
+        constant     float & eps,
+        threadgroup float  * buf [[threadgroup(0)]],
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    const int64_t ne = ne00*ne01*ne02;
+    const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups);
+
+    int start = tgpig * gs;
+    int end   = start + gs;
+
+    start += tpitg;
+
+    if (end >= ne) {
+        end = ne;
+    }
+
+    float tmp = 0.0f; // partial sum for thread in warp
+
+    for (int j = start; j < end; j += ntg) {
+        tmp += src0[j];
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    tmp = simd_sum(tmp);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = tmp;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        tmp = buf[tiisg];
+        tmp = simd_sum(tmp);
+    }
+
+    const float mean = tmp / gs;
+    tmp = 0.0f;
+
+    for (int j = start; j < end; j += ntg) {
+        float xi = src0[j] - mean;
+        dst[j] = xi;
+        tmp += xi * xi;
+    }
+
+    tmp = simd_sum(tmp);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = tmp;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        tmp = buf[tiisg];
+        tmp = simd_sum(tmp);
+    }
+
+    const float variance = tmp / gs;
+    const float scale = 1.0f/sqrt(variance + eps);
+    for (int j = start; j < end; j += ntg) {
+        dst[j] *= scale;
+    }
+}
+
+// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q4 quants begin (0 or QK4_0/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
+    float d = qb_curr->d;
+
+    float2 acc = 0.f;
+
+    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
+
+    for (int i = 0; i < 8; i+=2) {
+        acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
+                + yl[i + 1] * (qs[i / 2] & 0x0F00);
+        acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
+                + yl[i + 9] * (qs[i / 2] & 0xF000);
+    }
+    return d * (sumy * -8.f + acc[0] + acc[1]);
+}
+
+// function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q4 quants begin (0 or QK4_0/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
+    float d = qb_curr->d;
+    float m = qb_curr->m;
+
+    float2 acc = 0.f;
+
+    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
+
+    for (int i = 0; i < 8; i+=2) {
+        acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
+                + yl[i + 1] * (qs[i / 2] & 0x0F00);
+        acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
+                + yl[i + 9] * (qs[i / 2] & 0xF000);
+    }
+    return d * (acc[0] + acc[1]) + sumy * m;
+}
+
+// function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q5 quants begin (0 or QK5_0/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
+    float d = qb_curr->d;
+
+    float2 acc = 0.f;
+
+    device const uint16_t * qs =  ((device const uint16_t *)qb_curr + 3 + il/2);
+           const uint32_t   qh = *((device const uint32_t *)qb_curr->qh);
+
+    for (int i = 0; i < 8; i+=2) {
+        acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010))
+                + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));
+        acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
+                + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+    }
+    return d * (sumy * -16.f + acc[0] + acc[1]);
+}
+
+// function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q5 quants begin (0 or QK5_1/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) {
+    float d = qb_curr->d;
+    float m = qb_curr->m;
+
+    float2 acc = 0.f;
+
+    device const uint16_t * qs =  ((device const uint16_t *)qb_curr + 4 + il/2);
+           const uint32_t   qh = *((device const uint32_t *)qb_curr->qh);
+
+    for (int i = 0; i < 8; i+=2) {
+        acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010))
+                + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));
+        acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
+                + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+    }
+    return d * (acc[0] + acc[1]) + sumy * m;
+}
+
+// putting them in the kernel cause a significant performance penalty
+#define N_DST 4        // each SIMD group works on 4 rows
+#define N_SIMDGROUP 2  // number of SIMD groups in a thread group
+//Note: This is a template, but strictly speaking it only applies to
+//      quantizations where the block size is 32. It also does not
+//      guard against the number of rows not being divisible by
+//      N_DST, so this is another explicit assumption of the implementation.
+template<typename block_q_type, int nr, int nsg, int nw>
+void mul_vec_q_n_f32_impl(
+        device const void  * src0,
+        device const float * src1,
+        device       float * dst,
+                   int64_t   ne00,
+                   int64_t   ne01,
+                   int64_t   ne02,
+                   int64_t   ne10,
+                   int64_t   ne12,
+                   int64_t   ne0,
+                   int64_t   ne1,
+                   uint      r2,
+                   uint      r3,
+        threadgroup int8_t * shared_values,
+                   uint3 tgpig, uint tiisg, uint sgitg) {
+    const int nb = ne00/QK4_0;
+
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * nsg + sgitg) * nr;
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+    device const block_q_type * x = (device const block_q_type *) src0 + offset0;
+    device const float        * y = (device const float        *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float yl[16]; // src1 vector cache
+    float sumf[nr] = {0.f};
+
+    const int ix = (tiisg/2);
+    const int il = (tiisg%2)*8;
+
+    device const float * yb = y + ix * QK4_0 + il;
+
+    // each thread in a SIMD group deals with half a block.
+    for (int ib = ix; ib < nb; ib += nw/2) {
+        float sumy = 0;
+        for (int i = 0; i < 8; i += 2) {
+            sumy += yb[i] + yb[i+1];
+            yl[i+0] = yb[i+ 0];
+            yl[i+1] = yb[i+ 1]/256.f;
+
+            sumy += yb[i+16] + yb[i+17];
+            yl[i+8] = yb[i+16]/16.f;
+            yl[i+9] = yb[i+17]/4096.f;
+        }
+
+        for (int row = 0; row < nr; row++) {
+            sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il);
+        }
+
+        yb += QK4_0 * 16;
+    }
+
+    for (int row = 0; row < nr; ++row) {
+        const float tot = simd_sum(sumf[row]);
+        if (tiisg == 0 && first_row + row < ne01) {
+            dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot;
+        }
+    }
+}
+
+kernel void kernel_mul_mv_q4_0_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
+}
+
+kernel void kernel_mul_mv_q4_1_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+     mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
+}
+
+kernel void kernel_mul_mv_q5_0_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
+}
+
+kernel void kernel_mul_mv_q5_1_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
+}
+
+
+#define NB_Q8_0 8
+
+void kernel_mul_mv_q8_0_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+                   int64_t   ne00,
+                   int64_t   ne01,
+                   int64_t   ne02,
+                   int64_t   ne10,
+                   int64_t   ne12,
+                   int64_t   ne0,
+                   int64_t   ne1,
+                   uint      r2,
+                   uint      r3,
+        threadgroup int8_t * shared_values,
+                   uint3     tgpig,
+                   uint      tiisg,
+                   uint      sgitg) {
+    const int nr  = N_DST;
+    const int nsg = N_SIMDGROUP;
+    const int nw  = N_SIMDWIDTH;
+
+    const int nb = ne00/QK8_0;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * nsg + sgitg) * nr;
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+    device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
+    device const float      * y = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float yl[NB_Q8_0];
+    float sumf[nr]={0.f};
+
+    const int ix = tiisg/4;
+    const int il = tiisg%4;
+
+    device const float * yb = y + ix * QK8_0 + NB_Q8_0*il;
+
+    // each thread in a SIMD group deals with NB_Q8_0 quants at a time
+    for (int ib = ix; ib < nb; ib += nw/4) {
+        for (int i = 0; i < NB_Q8_0; ++i) {
+            yl[i] = yb[i];
+        }
+
+        for (int row = 0; row < nr; row++) {
+            device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il;
+            float sumq = 0.f;
+            for (int iq = 0; iq < NB_Q8_0; ++iq) {
+                sumq += qs[iq] * yl[iq];
+            }
+            sumf[row] += sumq*x[ib+row*nb].d;
+        }
+
+        yb += NB_Q8_0 * nw;
+    }
+
+    for (int row = 0; row < nr; ++row) {
+        const float tot = simd_sum(sumf[row]);
+        if (tiisg == 0 && first_row + row < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
+        }
+    }
+}
+
+[[host_name("kernel_mul_mv_q8_0_f32")]]
+kernel void kernel_mul_mv_q8_0_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
+}
+
+#define N_F32_F32 4
+
+void kernel_mul_mv_f32_f32_impl(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+                   int64_t   ne00,
+                   int64_t   ne01,
+                   int64_t   ne02,
+                  uint64_t   nb00,
+                  uint64_t   nb01,
+                  uint64_t   nb02,
+                   int64_t   ne10,
+                   int64_t   ne11,
+                   int64_t   ne12,
+                  uint64_t   nb10,
+                  uint64_t   nb11,
+                  uint64_t   nb12,
+                   int64_t   ne0,
+                   int64_t   ne1,
+                     uint    r2,
+                     uint    r3,
+                     uint3   tgpig,
+                     uint    tiisg) {
+
+    const int64_t r0 = tgpig.x;
+    const int64_t rb = tgpig.y*N_F32_F32;
+    const int64_t im = tgpig.z;
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+    device const float * x = (device const float *) (src0 + offset0);
+
+    if (ne00 < 128) {
+        for (int row = 0; row < N_F32_F32; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00; i += 32) {
+                sumf += (float) x[i] * (float) y[i];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    } else {
+        device const float4 * x4 = (device const float4 *)x;
+        for (int row = 0; row < N_F32_F32; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const float  * y  = (device const float  *) (src1 + r1*nb11 + im*nb12);
+            device const float4 * y4 = (device const float4 *) y;
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00/4; i += 32) {
+                for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    }
+}
+
+[[host_name("kernel_mul_mv_f32_f32")]]
+kernel void kernel_mul_mv_f32_f32(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]]) {
+    kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
+}
+
+#define N_F16_F16 4
+
+kernel void kernel_mul_mv_f16_f16(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]]) {
+
+    const int64_t r0 = tgpig.x;
+    const int64_t rb = tgpig.y*N_F16_F16;
+    const int64_t im = tgpig.z;
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+    device const half * x = (device const half *) (src0 + offset0);
+
+    if (ne00 < 128) {
+        for (int row = 0; row < N_F16_F16; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00; i += 32) {
+                sumf += (half) x[i] * (half) y[i];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    } else {
+        device const half4 * x4 = (device const half4 *)x;
+        for (int row = 0; row < N_F16_F16; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const half  * y  = (device const half  *) (src1 + r1*nb11 + im*nb12);
+            device const half4 * y4 = (device const half4 *) y;
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00/4; i += 32) {
+                for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i];
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    }
+}
+
+void kernel_mul_mv_f16_f32_1row_impl(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]]) {
+
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+    const int64_t im = tgpig.z;
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+    device const half  * x = (device const half  *) (src0 + offset0);
+    device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+
+    float sumf = 0;
+    if (ne00 < 128) {
+        for (int i = tiisg; i < ne00; i += 32) {
+            sumf += (float) x[i] * (float) y[i];
+        }
+        float all_sum = simd_sum(sumf);
+        if (tiisg == 0) {
+            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+        }
+    } else {
+        device const half4  * x4 = (device const half4  *) x;
+        device const float4 * y4 = (device const float4 *) y;
+        for (int i = tiisg; i < ne00/4; i += 32) {
+            for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k];
+        }
+        float all_sum = simd_sum(sumf);
+        if (tiisg == 0) {
+            for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
+            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+        }
+    }
+}
+
+[[host_name("kernel_mul_mv_f16_f32_1row")]]
+kernel void kernel_mul_mv_f16_f32_1row(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]]) {
+    kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
+}
+
+#define N_F16_F32 4
+
+void kernel_mul_mv_f16_f32_impl(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+                   int64_t   ne00,
+                   int64_t   ne01,
+                   int64_t   ne02,
+                  uint64_t   nb00,
+                  uint64_t   nb01,
+                  uint64_t   nb02,
+                   int64_t   ne10,
+                   int64_t   ne11,
+                   int64_t   ne12,
+                  uint64_t   nb10,
+                  uint64_t   nb11,
+                  uint64_t   nb12,
+                   int64_t   ne0,
+                   int64_t   ne1,
+                   uint      r2,
+                   uint      r3,
+                   uint3     tgpig,
+                   uint      tiisg) {
+
+    const int64_t r0 = tgpig.x;
+    const int64_t rb = tgpig.y*N_F16_F32;
+    const int64_t im = tgpig.z;
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+    device const half * x = (device const half *) (src0 + offset0);
+
+    if (ne00 < 128) {
+        for (int row = 0; row < N_F16_F32; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00; i += 32) {
+                sumf += (float) x[i] * (float) y[i];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    } else {
+        device const half4 * x4 = (device const half4 *)x;
+        for (int row = 0; row < N_F16_F32; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const float  * y  = (device const float  *) (src1 + r1*nb11 + im*nb12);
+            device const float4 * y4 = (device const float4 *) y;
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00/4; i += 32) {
+                for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    }
+}
+
+[[host_name("kernel_mul_mv_f16_f32")]]
+kernel void kernel_mul_mv_f16_f32(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]]) {
+    kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
+}
+
+// Assumes row size (ne00) is a multiple of 4
+kernel void kernel_mul_mv_f16_f32_l4(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]]) {
+
+    const int nrows = ne11;
+    const int64_t r0 = tgpig.x;
+    const int64_t im = tgpig.z;
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+    device const half4 * x4 = (device const half4 *) (src0 + offset0);
+
+    for (int r1 = 0; r1 < nrows; ++r1) {
+        device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
+
+        float sumf = 0;
+        for (int i = tiisg; i < ne00/4; i += 32) {
+            for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+        }
+
+        float all_sum = simd_sum(sumf);
+        if (tiisg == 0) {
+            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+        }
+    }
+}
+
+static float rope_yarn_ramp(const float low, const float high, const int i0) {
+    const float y = (i0 / 2 - low) / max(0.001f, high - low);
+    return 1.0f - min(1.0f, max(0.0f, y));
+}
+
+// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+static void rope_yarn(
+    float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
+    thread float * cos_theta, thread float * sin_theta) {
+    // Get n-d rotational scaling corrected for extrapolation
+    float theta_interp = freq_scale * theta_extrap;
+    float theta = theta_interp;
+    if (ext_factor != 0.0f) {
+        float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
+        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+        // Get n-d magnitude scaling corrected for interpolation
+        mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
+    }
+    *cos_theta = cos(theta) * mscale;
+    *sin_theta = sin(theta) * mscale;
+}
+
+// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
+// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
+static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
+    return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base));
+}
+
+static void rope_yarn_corr_dims(
+    int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
+) {
+    // start and end correction dims
+    dims[0] = max(0.0f,         floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));
+    dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
+}
+
+template<typename T>
+kernel void kernel_rope_norm(
+        device const    void * src0,
+        device const int32_t * src1,
+        device const   float * src2,
+        device         float * dst,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant     int64_t & ne03,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant    uint64_t & nb03,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant     int64_t & ne2,
+        constant     int64_t & ne3,
+        constant    uint64_t & nb0,
+        constant    uint64_t & nb1,
+        constant    uint64_t & nb2,
+        constant    uint64_t & nb3,
+        constant         int & n_past,
+        constant         int & n_dims,
+        constant         int & n_ctx_orig,
+        constant       float & freq_base,
+        constant       float & freq_scale,
+        constant       float & ext_factor,
+        constant       float & attn_factor,
+        constant       float & beta_fast,
+        constant       float & beta_slow,
+        uint  tiitg[[thread_index_in_threadgroup]],
+        uint3 tptg[[threads_per_threadgroup]],
+        uint3 tgpig[[threadgroup_position_in_grid]]) {
+    const int64_t i3 = tgpig[2];
+    const int64_t i2 = tgpig[1];
+    const int64_t i1 = tgpig[0];
+
+    float corr_dims[2];
+    rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
+
+    device const int32_t * pos = src1;
+
+    const float theta_base = (float) pos[i2];
+    const float inv_ndims = -1.f/n_dims;
+
+    float cos_theta;
+    float sin_theta;
+
+    for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
+        if (i0 < n_dims) {
+            const int64_t ic = i0/2;
+
+            const float theta = theta_base * pow(freq_base, inv_ndims*i0);
+
+            const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
+
+            rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+            const float x0 = src[0];
+            const float x1 = src[1];
+
+            dst_data[0] = x0*cos_theta - x1*sin_theta;
+            dst_data[1] = x0*sin_theta + x1*cos_theta;
+        } else {
+            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+            dst_data[0] = src[0];
+            dst_data[1] = src[1];
+        }
+    }
+}
+
+template<typename T>
+kernel void kernel_rope_neox(
+        device const    void * src0,
+        device const int32_t * src1,
+        device const   float * src2,
+        device         float * dst,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant     int64_t & ne03,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant    uint64_t & nb03,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant     int64_t & ne2,
+        constant     int64_t & ne3,
+        constant    uint64_t & nb0,
+        constant    uint64_t & nb1,
+        constant    uint64_t & nb2,
+        constant    uint64_t & nb3,
+        constant         int & n_past,
+        constant         int & n_dims,
+        constant         int & n_ctx_orig,
+        constant       float & freq_base,
+        constant       float & freq_scale,
+        constant       float & ext_factor,
+        constant       float & attn_factor,
+        constant       float & beta_fast,
+        constant       float & beta_slow,
+        uint  tiitg[[thread_index_in_threadgroup]],
+        uint3 tptg[[threads_per_threadgroup]],
+        uint3 tgpig[[threadgroup_position_in_grid]]) {
+    const int64_t i3 = tgpig[2];
+    const int64_t i2 = tgpig[1];
+    const int64_t i1 = tgpig[0];
+
+    float corr_dims[2];
+    rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
+
+    device const int32_t * pos = src1;
+
+    const float theta_base = (float) pos[i2];
+    const float inv_ndims = -1.f/n_dims;
+
+    float cos_theta;
+    float sin_theta;
+
+    for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
+        if (i0 < n_dims) {
+            const int64_t ic = i0/2;
+
+            const float theta = theta_base * pow(freq_base, inv_ndims*i0);
+
+            const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
+
+            rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+            device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
+
+            const float x0 = src[0];
+            const float x1 = src[n_dims/2];
+
+            dst_data[0]        = x0*cos_theta - x1*sin_theta;
+            dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+        } else {
+            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+            dst_data[0] = src[0];
+            dst_data[1] = src[1];
+        }
+    }
+}
+
+typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
+typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
+
+template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
+template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
+
+template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
+template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
+
+typedef void (im2col_t)(
+        device const float * x,
+        device        char * dst,
+        constant   int32_t & ofs0,
+        constant   int32_t & ofs1,
+        constant   int32_t & IW,
+        constant   int32_t & IH,
+        constant   int32_t & CHW,
+        constant   int32_t & s0,
+        constant   int32_t & s1,
+        constant   int32_t & p0,
+        constant   int32_t & p1,
+        constant   int32_t & d0,
+        constant   int32_t & d1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3  tgpg[[threadgroups_per_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]);
+
+template <typename T>
+kernel void kernel_im2col(
+        device const float * x,
+        device        char * dst,
+        constant   int32_t & ofs0,
+        constant   int32_t & ofs1,
+        constant   int32_t & IW,
+        constant   int32_t & IH,
+        constant   int32_t & CHW,
+        constant   int32_t & s0,
+        constant   int32_t & s1,
+        constant   int32_t & p0,
+        constant   int32_t & p1,
+        constant   int32_t & d0,
+        constant   int32_t & d1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3  tgpg[[threadgroups_per_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int32_t iiw = tgpig[2] * s0 + tpitg[2] * d0 - p0;
+    const int32_t iih = tgpig[1] * s1 + tpitg[1] * d1 - p1;
+
+    const int32_t offset_dst =
+        (tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
+        (tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
+
+    device T * pdst = (device T *) (dst);
+
+    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+        pdst[offset_dst] = 0.0f;
+    } else {
+        const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
+        pdst[offset_dst] = x[offset_src + iih * IW + iiw];
+    }
+}
+
+template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
+template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
+
+kernel void kernel_upscale_f32(
+    device  const char * src0,
+    device        char * dst,
+    constant   int64_t & ne00,
+    constant   int64_t & ne01,
+    constant   int64_t & ne02,
+    constant   int64_t & ne03,
+    constant  uint64_t & nb00,
+    constant  uint64_t & nb01,
+    constant  uint64_t & nb02,
+    constant  uint64_t & nb03,
+    constant   int64_t & ne0,
+    constant   int64_t & ne1,
+    constant   int64_t & ne2,
+    constant   int64_t & ne3,
+    constant  uint64_t & nb0,
+    constant  uint64_t & nb1,
+    constant  uint64_t & nb2,
+    constant  uint64_t & nb3,
+    constant     float & sf0,
+    constant     float & sf1,
+    constant     float & sf2,
+    constant     float & sf3,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3   ntg[[threads_per_threadgroup]]) {
+
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
+
+    const int64_t i03 = i3/sf3;
+    const int64_t i02 = i2/sf2;
+    const int64_t i01 = i1/sf1;
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        const int64_t i00 = i0/sf0;
+
+        device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+        device       float * dst_ptr  = (device       float *) (dst  +  i3*nb3  +  i2*nb2  +  i1*nb1  +  i0*nb0);
+
+        dst_ptr[0] = src0_ptr[0];
+    }
+}
+
+kernel void kernel_pad_f32(
+    device  const char * src0,
+    device        char * dst,
+    constant   int64_t & ne00,
+    constant   int64_t & ne01,
+    constant   int64_t & ne02,
+    constant   int64_t & ne03,
+    constant  uint64_t & nb00,
+    constant  uint64_t & nb01,
+    constant  uint64_t & nb02,
+    constant  uint64_t & nb03,
+    constant   int64_t & ne0,
+    constant   int64_t & ne1,
+    constant   int64_t & ne2,
+    constant   int64_t & ne3,
+    constant  uint64_t & nb0,
+    constant  uint64_t & nb1,
+    constant  uint64_t & nb2,
+    constant  uint64_t & nb3,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3   ntg[[threads_per_threadgroup]]) {
+
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
+
+    const int64_t i03 = i3;
+    const int64_t i02 = i2;
+    const int64_t i01 = i1;
+
+    device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
+    device       float * dst_ptr  = (device       float *) (dst  +  i3*nb3  +  i2*nb2  +  i1*nb1);
+
+    if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
+        for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+            if (i0 < ne00) {
+                dst_ptr[i0] = src0_ptr[i0];
+            } else {
+                dst_ptr[i0] = 0.0f;
+            }
+        }
+
+        return;
+    }
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        dst_ptr[i0] = 0.0f;
+    }
+}
+
+kernel void kernel_arange_f32(
+    device        char * dst,
+    constant   int64_t & ne0,
+    constant   float   & start,
+    constant   float   & step,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3   ntg[[threads_per_threadgroup]]) {
+
+    device float * dst_ptr = (device float *) dst;
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        dst_ptr[i0] = start + step * i0;
+    }
+}
+
+kernel void kernel_timestep_embedding_f32(
+    device  const char * src0,
+    device        char * dst,
+    constant  uint64_t & nb1,
+    constant  int      & dim,
+    constant  int      & max_period,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3   ntg[[threads_per_threadgroup]]) {
+
+    int i = tgpig.x;
+    device float * embed_data = (device float *)(dst +  i*nb1);
+
+    int half_ = dim / 2;
+    for (int j = tpitg.x; j < half_; j += ntg.x) {
+        float timestep = ((device float *)src0)[i];
+        float freq = (float)exp(-log((float)max_period) * j / half_);
+        float arg = timestep * freq;
+        embed_data[j        ] = cos(arg);
+        embed_data[j + half_] = sin(arg);
+    }
+
+    if (dim % 2 != 0 && tpitg.x == 0) {
+        embed_data[dim] = 0.f;
+    }
+}
+
+// bitonic sort implementation following the CUDA kernels as reference
+typedef void (argsort_t)(
+        device const float  * x,
+        device     int32_t  * dst,
+        constant   int64_t  & ncols,
+        constant   int64_t  & ncols_pad,
+        threadgroup int32_t * shared_values [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]]);
+
+template<ggml_sort_order order>
+kernel void kernel_argsort_f32_i32(
+        device const float   * x,
+        device       int32_t * dst,
+        constant     int64_t & ncols,
+        constant     int64_t & ncols_pad,
+        threadgroup int32_t  * shared_values [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]]) {
+    // bitonic sort
+    int col = tpitg[0];
+    int row = tgpig[1];
+
+    if (col >= ncols_pad) return;
+
+    device const float   * x_row   = x + row * ncols;
+    threadgroup int32_t  * dst_row = shared_values;
+
+    // initialize indices
+    dst_row[col] = col;
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    for (int k = 2; k <= ncols_pad; k *= 2) {
+        for (int j = k / 2; j > 0; j /= 2) {
+            int ixj = col ^ j;
+            if (ixj > col) {
+                if ((col & k) == 0) {
+                    if (dst_row[col] >= ncols ||
+                        (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]]))
+                    ) {
+                        SWAP(dst_row[col], dst_row[ixj]);
+                    }
+                } else {
+                    if (dst_row[ixj] >= ncols ||
+                        (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]]))
+                    ) {
+                        SWAP(dst_row[col], dst_row[ixj]);
+                    }
+                }
+            }
+            threadgroup_barrier(mem_flags::mem_threadgroup);
+        }
+    }
+
+    // copy the result to dst without the padding
+    if (col < ncols) {
+        dst[row * ncols + col] = dst_row[col];
+    }
+}
+
+template [[host_name("kernel_argsort_f32_i32_asc")]]  kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;
+template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;
+
+kernel void kernel_leaky_relu_f32(
+        device const float * src0,
+        device       float * dst,
+        constant     float & slope,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
+}
+
+typedef void (flash_attn_ext_f16_t)(
+        device const  char * q,
+        device const  char * k,
+        device const  char * v,
+        device const  char * mask,
+        device       float * dst,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant   int64_t & ne13,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
+        constant  uint64_t & nb21,
+        constant  uint64_t & nb22,
+        constant  uint64_t & nb23,
+        constant  uint64_t & nb31,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant     float & scale,
+        constant     float & max_bias,
+        constant     float & m0,
+        constant     float & m1,
+        constant  uint32_t & n_head_log2,
+        threadgroup   half * shared,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        uint3  tpitg[[thread_position_in_threadgroup]],
+        uint3    ntg[[threads_per_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]);
+
+// ref: https://arxiv.org/pdf/2307.08691.pdf
+template<int64_t D, int64_t Q = 8, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
+kernel void kernel_flash_attn_ext_f16(
+        device const  char * q,
+        device const  char * k,
+        device const  char * v,
+        device const  char * mask,
+        device       float * dst,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant   int64_t & ne13,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
+        constant  uint64_t & nb21,
+        constant  uint64_t & nb22,
+        constant  uint64_t & nb23,
+        constant  uint64_t & nb31,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant     float & scale,
+        constant     float & max_bias,
+        constant     float & m0,
+        constant     float & m1,
+        constant  uint32_t & n_head_log2,
+        threadgroup   half * shared [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        uint3  tpitg[[thread_position_in_threadgroup]],
+        uint3    ntg[[threads_per_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    const short nsg = ntg.y; // number of simdgroups
+
+    const short iq3 = tgpig[2];
+    const short iq2 = tgpig[1];
+    const short iq1 = tgpig[0]*Q;
+
+    const short D4 = D/4;
+    const short D8 = D/8;
+  //const short Q8 = Q/8;
+    const short NW = N_SIMDWIDTH;
+    const short SH = (C + Q); // shared memory per simdgroup in (half)
+
+    const short T  = D + 2*nsg*SH; // shared memory size per query in (half)
+    const short TF = T/2;        // shared memory size per query in (float)
+    const short T4 = T/4;        // shared memory size per query in (half4)
+
+    threadgroup half  * sq  = (threadgroup half  *) (shared +              0*D); // holds the query data
+    threadgroup half4 * sq4 = (threadgroup half4 *) (shared +              0*D); // same as above but in half4
+    threadgroup float * ss  = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
+
+    // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
+    simdgroup_half8x8 lo[D8];
+
+    // load heads from Q to shared memory
+    for (short j = sgitg; j < Q; j += nsg) {
+        device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
+
+        for (short i = tiisg; i < D4; i += NW) {
+            if (iq1 + j < ne01) {
+                sq4[j*T4 + i] = (half4) q4[i];
+            } else {
+                sq4[j*T4 + i] = 0.0h;
+            }
+        }
+    }
+
+    // zero out lo
+    for (short i = 0; i < D8; ++i) {
+        lo[i] = make_filled_simdgroup_matrix<half, 8>(0.0h);
+    }
+
+    // zero out shared memory SH
+    for (short j = 0; j < Q; ++j) {
+        for (short i = tiisg; i < SH; i += NW) {
+            ss[j*TF + i] = 0.0f;
+        }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    {
+        float S[Q] = { [0 ... Q-1] = 0.0h };
+        float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 };
+
+        // assume K and V are same shape
+        const short ne22 = ne12;
+        const short ne23 = ne13;
+
+        // broadcast
+        const short rk2 = ne02/ne12;
+        const short rk3 = ne03/ne13;
+
+        const short rv2 = ne02/ne22;
+        const short rv3 = ne03/ne23;
+
+        // k indices
+        const short ik2 = iq2/rk2;
+        const short ik3 = iq3/rk3;
+
+        // v indices
+        const short iv2 = iq2/rv2;
+        const short iv3 = iq3/rv3;
+
+        // load the queries from shared memory into local memory
+        simdgroup_half8x8 mq[D8];
+
+        for (short i = 0; i < D8; ++i) {
+            simdgroup_load(mq[i], sq + i*8, T);
+        }
+
+        // pointer to the mask
+        device const half * mp = (device const half *) (mask + iq1*nb31);
+
+        float slope = 1.0f;
+
+        // ALiBi
+        if (max_bias > 0.0f) {
+            const uint32_t h = iq2;
+
+            const float base = h < n_head_log2 ? m0 : m1;
+            const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+            slope = pow(base, exph);
+        }
+
+        // loop over the KV cache
+        // each simdgroup handles blocks of Q rows and C columns
+        for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
+            const int ic = ic0 + C*sgitg;
+            if (ic >= ne11) {
+                break;
+            }
+
+            // Q*K^T
+            {
+                for (short cc = 0; cc < C/8; ++cc) {
+                    simdgroup_float8x8 mqk = make_filled_simdgroup_matrix<float, 8>(0.h);
+
+                    device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
+
+                    for (short i = 0; i < D8; ++i) {
+                        simdgroup_half8x8 mk;
+                        simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
+
+                        simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
+                    }
+
+                    simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
+
+                    const short tx = tiisg%4;
+                    const short ty = tiisg/4;
+
+                    if (mask != q) {
+                        // mqk = mqk*scale + mask*slope
+                        ss[8*cc + ty*TF + 2*tx + 0] = scale*ss[8*cc + ty*TF + 2*tx + 0] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
+                        ss[8*cc + ty*TF + 2*tx + 1] = scale*ss[8*cc + ty*TF + 2*tx + 1] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
+                    } else {
+                        // mqk = mqk*scale
+                        ss[8*cc + ty*TF + 2*tx + 0] *= scale;
+                        ss[8*cc + ty*TF + 2*tx + 1] *= scale;
+                    }
+                }
+            }
+
+            // used to detect blocks full of -INF
+            float smax = -INFINITY;
+
+            // online softmax
+            {
+                float ms[Q];
+
+                for (short j = 0; j < Q; ++j) {
+                    const short p = tiisg;
+
+                    const float m = M[j];
+                    const float s = ss[j*TF + p];
+
+                    smax = simd_max(max(smax, s));
+                    M[j] = simd_max(max(M[j], s));
+
+                                ms[j] = exp(m - M[j]);
+                    const float vs    = exp(s - M[j]);
+
+                    S[j] = S[j]*ms[j] + simd_sum(vs);
+
+                    // the P matrix from the paper (Q rows, C columns)
+                    ss[j*TF + p] = vs;
+                }
+
+                // create a QxQ diagonal matrix for rescaling the output
+                if (tiisg < Q) {
+                    ss[tiisg*TF + C + tiisg] = ms[tiisg];
+                }
+            }
+
+            // skip -INF blocks
+            if (smax == -INFINITY) {
+                continue;
+            }
+
+            // O = diag(ms)*O
+            {
+                simdgroup_float8x8 mm;
+                simdgroup_load(mm, ss + C, TF, 0, false);
+
+                for (short i = 0; i < D8; ++i) {
+                    simdgroup_multiply(lo[i], mm, lo[i]);
+                }
+            }
+
+            // O = O + (Q*K^T)*V
+            {
+                for (short cc = 0; cc < C/8; ++cc) {
+                    device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
+
+                    for (short i = 0; i < D8; ++i) {
+                        simdgroup_half8x8 mk;
+                        simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false);
+
+                        simdgroup_float8x8 mv;
+                        simdgroup_load(mv, ss + 8*cc, TF, 0, false);
+
+                        simdgroup_multiply_accumulate(lo[i], mv, mk, lo[i]);
+                    }
+                }
+            }
+        }
+
+        // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
+        for (short j = 0; j < Q; ++j) {
+            if (tiisg == 0) {
+                ss[j*TF + 0] = S[j];
+                ss[j*TF + 1] = M[j];
+            }
+        }
+    }
+
+    // reduce the warps sequentially
+    for (short sg = 1; sg < nsg; ++sg) {
+        float S = { 0.0h };
+        float M = { -FLT_MAX/2 };
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // each simdgroup stores its output to shared memory, reusing sq
+        if (sgitg == sg) {
+            for (short i = 0; i < D8; ++i) {
+                simdgroup_store(lo[i], sq + i*8, T, 0, false);
+            }
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // the first simdgroup accumulates the results from the other simdgroups
+        if (sgitg == 0) {
+            for (short j = 0; j < Q; ++j) {
+                const float S0 = ss[j*TF +         0];
+                const float S1 = ss[j*TF + sg*SH + 0];
+
+                const float M0 = ss[j*TF +         1];
+                const float M1 = ss[j*TF + sg*SH + 1];
+
+                M = max(M0, M1);
+
+                const float ms0 = exp(M0 - M);
+                const float ms1 = exp(M1 - M);
+
+                S = S0*ms0 + S1*ms1;
+
+                if (tiisg == 0) {
+                    ss[j*TF + 0] = S;
+                    ss[j*TF + 1] = M;
+
+                    ss[j*TF + C + j        ] = ms0;
+                    ss[j*TF + C + j + sg*SH] = ms1;
+                }
+            }
+
+            // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
+            {
+                simdgroup_half8x8 t;
+                simdgroup_float8x8 ms0;
+                simdgroup_float8x8 ms1;
+
+                simdgroup_load(ms0, ss + C,         TF, 0, false);
+                simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false);
+
+                for (short i = 0; i < D8; ++i) {
+                    simdgroup_load    (t, sq + i*8, T, 0, false);
+                    simdgroup_multiply(t, ms1, t);
+
+                    simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
+                }
+            }
+        }
+    }
+
+    // store result to shared memory (reuse sq)
+    if (sgitg == 0) {
+        for (short i = 0; i < D8; ++i) {
+            simdgroup_store(lo[i], sq + i*8, T, 0, false);
+        }
+    }
+
+    device float4 * dst4 = (device float4 *) dst;
+
+    // final rescale with 1/S and store to global memory
+    if (sgitg == 0) {
+        for (short j = 0; j < Q && iq1 + j < ne01; ++j) {
+            const float S = ss[j*TF + 0];
+
+            for (short i = tiisg; i < D4; i += NW) {
+                dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S;
+            }
+        }
+    }
+}
+
+template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>;
+template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>;
+template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>;
+template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>;
+template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>;
+//template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
+
+template<int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
+kernel void kernel_flash_attn_ext_vec_f16(
+        device const  char * q,
+        device const  char * k,
+        device const  char * v,
+        device const  char * mask,
+        device       float * dst,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant   int64_t & ne13,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
+        constant  uint64_t & nb21,
+        constant  uint64_t & nb22,
+        constant  uint64_t & nb23,
+        constant  uint64_t & nb31,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant     float & scale,
+        constant     float & max_bias,
+        constant     float & m0,
+        constant     float & m1,
+        constant  uint32_t & n_head_log2,
+        threadgroup   half * shared [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        uint3  tpitg[[thread_position_in_threadgroup]],
+        uint3    ntg[[threads_per_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    const short nsg = ntg.y; // number of simdgroups
+
+    const short iq3 = tgpig[2];
+    const short iq2 = tgpig[1];
+    const short iq1 = tgpig[0];
+
+    const short D4 = D/4;
+    const short NW = N_SIMDWIDTH;
+    const short SH = (C + Q); // shared memory per simdgroup in (half)
+
+    const short T  = D + 2*nsg*SH; // shared memory size per query in (half)
+
+    float slope = 1.0f;
+
+    // ALiBi
+    if (max_bias > 0.0f) {
+        const uint32_t h = iq2;
+
+        const float base = h < n_head_log2 ? m0 : m1;
+        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slope = pow(base, exp);
+    }
+
+  //threadgroup half   * sq  = (threadgroup half   *) (shared +              0*D); // holds the query data
+    threadgroup half4  * sq4 = (threadgroup half4  *) (shared +              0*D); // same as above but in half4
+    threadgroup float  * ss  = (threadgroup float  *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
+    threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4
+    threadgroup half4  * sr4 = (threadgroup half4  *) (shared +   sgitg*D  + 1*T); // scratch buffer for the results
+
+    // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
+    half4 lo[D4/NW];
+
+    // load heads from Q to shared memory
+    device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
+
+    for (short i = tiisg; i < D4; i += NW) {
+        if (iq1 < ne01) {
+            sq4[i] = (half4) q4[i];
+        } else {
+            sq4[i] = 0.0h;
+        }
+    }
+
+    // zero out lo
+    for (short i = tiisg; i < D4; i += NW) {
+        lo[i/NW] = 0.0h;
+    }
+
+    // zero out shared memory SH
+    for (short i = tiisg; i < SH/4; i += NW) {
+        ss4[i] = 0.0h;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    {
+        float S = { 0.0h };
+        float M = { -FLT_MAX/2 };
+
+        // assume K and V are same shape
+        const short ne22 = ne12;
+        const short ne23 = ne13;
+
+        // broadcast
+        const short rk2 = ne02/ne12;
+        const short rk3 = ne03/ne13;
+
+        const short rv2 = ne02/ne22;
+        const short rv3 = ne03/ne23;
+
+        // k indices
+        const short ik2 = iq2 / rk2;
+        const short ik3 = iq3 / rk3;
+
+        // v indices
+        const short iv2 = iq2 / rv2;
+        const short iv3 = iq3 / rv3;
+
+        // load the queries from shared memory into local memory
+        half4 mq[D4];
+
+        for (short ii = 0; ii < D4; ii += NW) {
+            short i = ii + tiisg;
+            mq[i] = sq4[i];
+        }
+
+        // pointer to the mask
+        device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31);
+
+        // loop over the KV cache
+        // each simdgroup handles blocks of Q rows and C columns
+        for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
+            const int ic = ic0 + C*sgitg;
+            if (ic >= ne11) {
+                break;
+            }
+
+            // Q*K^T
+            {
+#pragma unroll
+                for (short cc = 0; cc < C/4; ++cc) {
+                    float4 mqk = { 0.0h };
+
+                    device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13));
+
+#pragma unroll
+                    for (short ii = 0; ii < D4; ii += NW) {
+                        const short i = ii + tiisg;
+
+                        half4x4 mk;
+                        mk[0] = pk4[i + 0*(nb11/8)];
+                        mk[1] = pk4[i + 1*(nb11/8)];
+                        mk[2] = pk4[i + 2*(nb11/8)];
+                        mk[3] = pk4[i + 3*(nb11/8)];
+
+                        mqk += (float4) (mq[i] * mk);
+                    }
+
+                    // reduce the results from the threads in the simdgroup
+                    mqk += simd_shuffle_down(mqk, 16);
+                    mqk += simd_shuffle_down(mqk,  8);
+                    mqk += simd_shuffle_down(mqk,  4);
+                    mqk += simd_shuffle_down(mqk,  2);
+                    mqk += simd_shuffle_down(mqk,  1);
+
+                    // mqk = mqk*scale + mask*slope
+                    if (tiisg == 0) {
+                        mqk = mqk*scale + ((mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f);
+
+                        ss4[cc] = mqk;
+                    }
+                }
+            }
+
+            // online softmax
+            {
+                const short p = tiisg;
+
+                const float m = M;
+                const float s = ss[p];
+
+                M = simd_max(max(M, s));
+
+                const float ms = exp(m - M);
+                const float vs = exp(s - M);
+
+                S = S*ms + simd_sum(vs);
+
+                // the P matrix from the paper (Q rows, C columns)
+                ss[p] = vs;
+
+                // O = diag(ms)*O
+#pragma unroll
+                for (short ii = 0; ii < D4; ii += NW) {
+                    const short i = ii + tiisg;
+                    lo[i/NW] *= ms;
+                }
+            }
+
+            // O = O + (Q*K^T)*V
+            {
+#pragma unroll
+                for (short cc = 0; cc < C/4; ++cc) {
+                    device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + 4*cc)*nb21 + iv2*nb22 + iv3*nb23));
+
+#pragma unroll
+                    for (short ii = 0; ii < D4; ii += NW) {
+                        const short i = ii + tiisg;
+
+                        lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0];
+                        lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1];
+                        lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2];
+                        lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3];
+                    }
+                }
+            }
+
+        }
+
+        // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
+        if (tiisg == 0) {
+            ss[0] = S;
+            ss[1] = M;
+        }
+    }
+
+    // store results to shared memory
+    for (short ii = 0; ii < D4; ii += NW) {
+        short i = ii + tiisg;
+        sr4[i] = lo[ii/NW];
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // parallel reduce
+    for (short r = nsg/2; r > 0; r >>= 1) {
+        if (sgitg < r) {
+            const float S0 = ss[       0];
+            const float S1 = ss[r*SH + 0];
+
+            const float M0 = ss[       1];
+            const float M1 = ss[r*SH + 1];
+
+            const float M = max(M0, M1);
+
+            const float ms0 = exp(M0 - M);
+            const float ms1 = exp(M1 - M);
+
+            const float S = S0*ms0 + S1*ms1;
+
+            if (tiisg == 0) {
+                ss[0] = S;
+                ss[1] = M;
+            }
+
+            // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
+            for (short ii = 0; ii < D4; ii += NW) {
+                short i = ii + tiisg;
+                sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1;
+            }
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    device float4 * dst4 = (device float4 *) dst;
+
+    // final rescale with 1/S and store to global memory
+    if (sgitg == 0) {
+        const float S = ss[0];
+
+        for (short ii = 0; ii < D4; ii += NW) {
+            short i = ii + tiisg;
+            dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S;
+        }
+    }
+}
+
+template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
+//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
+
+kernel void kernel_cpy_f16_f16(
+        device  const half * src0,
+        device        half * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+        device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+        dst_data[i00] = src[0];
+    }
+}
+
+kernel void kernel_cpy_f16_f32(
+        device  const half * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+        device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+        dst_data[i00] = src[0];
+    }
+}
+
+kernel void kernel_cpy_f32_f16(
+        device const float * src0,
+        device        half * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+        dst_data[i00] = src[0];
+    }
+}
+
+kernel void kernel_cpy_f32_f32(
+        device const float * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+        dst_data[i00] = src[0];
+    }
+}
+
+kernel void kernel_cpy_f32_q8_0(
+        device const float * src0,
+        device        void * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK8_0;
+
+    device block_q8_0 * dst_data = (device block_q8_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x*QK8_0; i00 < ne00; i00 += ntg.x*QK8_0) {
+        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+        float amax = 0.0f; // absolute max
+
+        for (int j = 0; j < QK8_0; j++) {
+            const float v = src[j];
+            amax = MAX(amax, fabs(v));
+        }
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        dst_data[i00/QK8_0].d = d;
+
+        for (int j = 0; j < QK8_0; ++j) {
+            const float x0 = src[j]*id;
+
+            dst_data[i00/QK8_0].qs[j] = round(x0);
+        }
+    }
+}
+
+kernel void kernel_cpy_f32_q4_0(
+        device const float * src0,
+        device        void * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_0;
+
+    device block_q4_0 * dst_data = (device block_q4_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x*QK4_0; i00 < ne00; i00 += ntg.x*QK4_0) {
+        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+        float amax = 0.0f; // absolute max
+        float max  = 0.0f;
+
+        for (int j = 0; j < QK4_0; j++) {
+            const float v = src[j];
+            if (amax < fabs(v)) {
+                amax = fabs(v);
+                max  = v;
+            }
+        }
+
+        const float d = max / -8;
+        const float id = d ? 1.0f/d : 0.0f;
+
+        dst_data[i00/QK4_0].d = d;
+
+        for (int j = 0; j < QK4_0/2; ++j) {
+            const float x0 = src[0       + j]*id;
+            const float x1 = src[QK4_0/2 + j]*id;
+
+            const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
+            const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
+
+            dst_data[i00/QK4_0].qs[j]  = xi0;
+            dst_data[i00/QK4_0].qs[j] |= xi1 << 4;
+        }
+    }
+}
+
+kernel void kernel_cpy_f32_q4_1(
+        device const float * src0,
+        device        void * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_1;
+
+    device block_q4_1 * dst_data = (device block_q4_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x*QK4_1; i00 < ne00; i00 += ntg.x*QK4_1) {
+        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+        float min = FLT_MAX;
+        float max = -FLT_MAX;
+
+        for (int j = 0; j < QK4_1; j++) {
+            const float v = src[j];
+            if (min > v) min = v;
+            if (max < v) max = v;
+        }
+
+        const float d = (max - min) / ((1 << 4) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        dst_data[i00/QK4_1].d = d;
+        dst_data[i00/QK4_1].m = min;
+
+        for (int j = 0; j < QK4_1/2; ++j) {
+            const float x0 = (src[0       + j] - min)*id;
+            const float x1 = (src[QK4_1/2 + j] - min)*id;
+
+            const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
+            const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
+
+            dst_data[i00/QK4_1].qs[j]  = xi0;
+            dst_data[i00/QK4_1].qs[j] |= xi1 << 4;
+        }
+    }
+}
+
+kernel void kernel_cpy_f32_q5_0(
+        device const float * src0,
+        device        void * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_0;
+
+    device block_q5_0 * dst_data = (device block_q5_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x*QK5_0; i00 < ne00; i00 += ntg.x*QK5_0) {
+        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+        float amax = 0.0f; // absolute max
+        float max  = 0.0f;
+
+        for (int j = 0; j < QK5_0; j++) {
+            const float v = src[j];
+            if (amax < fabs(v)) {
+                amax = fabs(v);
+                max  = v;
+            }
+        }
+
+        const float d = max / -16;
+        const float id = d ? 1.0f/d : 0.0f;
+
+        dst_data[i00/QK5_0].d = d;
+
+        uint32_t qh = 0;
+        for (int j = 0; j < QK5_0/2; ++j) {
+            const float x0 = src[0       + j]*id;
+            const float x1 = src[QK5_0/2 + j]*id;
+
+            const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
+            const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
+
+            dst_data[i00/QK5_0].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+            qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
+        }
+        thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
+        for (int j = 0; j < 4; ++j) {
+            dst_data[i00/QK5_0].qh[j] = qh8[j];
+        }
+    }
+}
+
+kernel void kernel_cpy_f32_q5_1(
+        device const float * src0,
+        device        void * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_1;
+
+    device block_q5_1 * dst_data = (device block_q5_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x*QK5_1; i00 < ne00; i00 += ntg.x*QK5_1) {
+        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+        float max = src[0];
+        float min = src[0];
+
+        for (int j = 1; j < QK5_1; j++) {
+            const float v = src[j];
+            min = v < min ? v : min;
+            max = v > max ? v : max;
+        }
+
+        const float d = (max - min) / 31;
+        const float id = d ? 1.0f/d : 0.0f;
+
+        dst_data[i00/QK5_1].d = d;
+        dst_data[i00/QK5_1].m = min;
+
+        uint32_t qh = 0;
+        for (int j = 0; j < QK5_1/2; ++j) {
+            const float x0 = (src[0       + j] - min)*id;
+            const float x1 = (src[QK5_1/2 + j] - min)*id;
+
+            const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
+            const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
+
+            dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+            qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
+        }
+        thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
+        for (int j = 0; j < 4; ++j) {
+            dst_data[i00/QK5_1].qh[j] = qh8[j];
+        }
+    }
+}
+
+static inline int best_index_int8(int n, constant float * val, float x) {
+    if (x <= val[0]) return 0;
+    if (x >= val[n-1]) return n-1;
+    int ml = 0, mu = n-1;
+    while (mu-ml > 1) {
+        int mav = (ml+mu)/2;
+        if (x < val[mav]) mu = mav; else ml = mav;
+    }
+    return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
+}
+
+constexpr constant static float kvalues_iq4nl_f[16] = {
+    -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
+};
+
+kernel void kernel_cpy_f32_iq4_nl(
+        device const float * src0,
+        device        void * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_NL;
+
+    device block_iq4_nl * dst_data = (device block_iq4_nl *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x*QK4_NL; i00 < ne00; i00 += ntg.x*QK4_NL) {
+        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+        float amax = 0.0f; // absolute max
+        float max  = 0.0f;
+
+        for (int j = 0; j < QK4_0; j++) {
+            const float v = src[j];
+            if (amax < fabs(v)) {
+                amax = fabs(v);
+                max  = v;
+            }
+        }
+
+        const float d = max / kvalues_iq4nl_f[0];
+        const float id = d ? 1.0f/d : 0.0f;
+
+        float sumqx = 0, sumq2 = 0;
+        for (int j = 0; j < QK4_NL/2; ++j) {
+            const float x0 = src[0        + j]*id;
+            const float x1 = src[QK4_NL/2 + j]*id;
+
+            const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
+            const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
+
+            dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4);
+
+            const float v0 = kvalues_iq4nl_f[xi0];
+            const float v1 = kvalues_iq4nl_f[xi1];
+            const float w0 = src[0        + j]*src[0        + j];
+            const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
+            sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
+            sumq2 += w0*v0*v0 + w1*v1*v1;
+
+        }
+
+        dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d;
+
+    }
+}
+
+kernel void kernel_concat(
+    device  const char * src0,
+    device  const char * src1,
+    device        char * dst,
+    constant   int64_t & ne00,
+    constant   int64_t & ne01,
+    constant   int64_t & ne02,
+    constant   int64_t & ne03,
+    constant  uint64_t & nb00,
+    constant  uint64_t & nb01,
+    constant  uint64_t & nb02,
+    constant  uint64_t & nb03,
+    constant   int64_t & ne10,
+    constant   int64_t & ne11,
+    constant   int64_t & ne12,
+    constant   int64_t & ne13,
+    constant  uint64_t & nb10,
+    constant  uint64_t & nb11,
+    constant  uint64_t & nb12,
+    constant  uint64_t & nb13,
+    constant   int64_t & ne0,
+    constant   int64_t & ne1,
+    constant   int64_t & ne2,
+    constant   int64_t & ne3,
+    constant  uint64_t & nb0,
+    constant  uint64_t & nb1,
+    constant  uint64_t & nb2,
+    constant  uint64_t & nb3,
+    constant   int32_t & dim,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3   ntg[[threads_per_threadgroup]]) {
+
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
+
+    int64_t o[4] = {0, 0, 0, 0};
+    o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
+
+    device const float * x;
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+            x = (device const float *)(src0 + (i3       )*nb03 + (i2       )*nb02 + (i1       )*nb01 + (i0       )*nb00);
+        } else {
+            x = (device const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);
+        }
+
+        device float * y = (device float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+        *y = *x;
+    }
+}
+
+void kernel_mul_mv_q2_K_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+                   int64_t   ne00,
+                   int64_t   ne01,
+                   int64_t   ne02,
+                   int64_t   ne10,
+                   int64_t   ne12,
+                   int64_t   ne0,
+                   int64_t   ne1,
+                   uint      r2,
+                   uint      r3,
+        threadgroup int8_t * shared_values,
+                   uint3     tgpig,
+                   uint      tiisg,
+                   uint      sgitg) {
+
+    const int nb = ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int ib_row = first_row * nb;
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+    device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
+    device const float      * y = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float yl[32];
+    float sumf[N_DST]={0.f}, all_sum;
+
+    const int step = sizeof(block_q2_K) * nb;
+
+    const int ix = tiisg/8;  // 0...3
+    const int it = tiisg%8;  // 0...7
+    const int iq = it/4;     // 0 or 1
+    const int ir = it%4;     // 0...3
+    const int is = (8*ir)/16;// 0 or 1
+
+    device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
+
+    for (int ib = ix; ib < nb; ib += 4) {
+
+        float4 sumy = {0.f, 0.f, 0.f, 0.f};
+        for (int i = 0; i < 8; ++i) {
+            yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
+            yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
+            yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
+            yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
+        }
+
+        device const uint8_t  * sc = (device const uint8_t  *)x[ib].scales + 8*iq + is;
+        device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
+        device const half     * dh = &x[ib].d;
+
+        for (int row = 0; row < N_DST; row++) {
+
+            float4 acc1 = {0.f, 0.f, 0.f, 0.f};
+            float4 acc2 = {0.f, 0.f, 0.f, 0.f};
+            for (int i = 0; i < 8; i += 2) {
+                acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
+                acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
+                acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
+                acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
+                acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
+                acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
+                acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
+                acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
+            }
+            float dall = dh[0];
+            float dmin = dh[1] * 1.f/16.f;
+            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
+                                 (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f +
+                                 (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f +
+                                 (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -
+                         dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));
+
+            qs += step/2;
+            sc += step;
+            dh += step/2;
+        }
+
+        y4 += 4 * QK_K;
+    }
+
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+        }
+    }
+}
+
+[[host_name("kernel_mul_mv_q2_K_f32")]]
+kernel void kernel_mul_mv_q2_K_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+}
+
+void kernel_mul_mv_q3_K_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+                   int64_t   ne00,
+                   int64_t   ne01,
+                   int64_t   ne02,
+                   int64_t   ne10,
+                   int64_t   ne12,
+                   int64_t   ne0,
+                   int64_t   ne1,
+                   uint      r2,
+                   uint      r3,
+        threadgroup int8_t * shared_values,
+                   uint3     tgpig,
+                   uint      tiisg,
+                   uint      sgitg) {
+
+    const int nb = ne00/QK_K;
+
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+    const int64_t im = tgpig.z;
+
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+    device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
+    device const float     * yy = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float yl[32];
+
+    //const uint16_t kmask1 = 0x3030;
+    //const uint16_t kmask2 = 0x0f0f;
+
+    const int tid = tiisg/4;
+    const int ix  = tiisg%4;
+    const int ip  = tid/4;          // 0 or 1
+    const int il  = 2*((tid%4)/2);  // 0 or 2
+    const int ir  = tid%2;
+    const int n   = 8;
+    const int l0  = n*ir;
+
+    // One would think that the Metal compiler would figure out that ip and il can only have
+    // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
+    // with these two tales.
+    //
+    // Possible masks for the high bit
+    const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200},  // ip = 0, il = 0
+                           {0x0004, 0x0400, 0x0008, 0x0800},  // ip = 0, il = 2
+                           {0x0010, 0x1000, 0x0020, 0x2000},  // ip = 1, il = 0
+                           {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2
+
+    // Possible masks for the low 2 bits
+    const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}};
+
+    const ushort4 hm = mm[2*ip + il/2];
+
+    const int shift = 2*il;
+    const float    v1 = il == 0 ? 4.f : 64.f;
+    const float    v2 = 4.f * v1;
+
+    const uint16_t s_shift1 = 4*ip;
+    const uint16_t s_shift2 = s_shift1 + il;
+
+    const int q_offset = 32*ip + l0;
+    const int y_offset = 128*ip + 32*il + l0;
+
+    const int step = sizeof(block_q3_K) * nb / 2;
+
+    device const float * y1 = yy + ix*QK_K + y_offset;
+
+    uint32_t scales32, aux32;
+    thread uint16_t * scales16 = (thread uint16_t *)&scales32;
+    thread const int8_t * scales = (thread const int8_t *)&scales32;
+
+    float sumf1[2] = {0.f};
+    float sumf2[2] = {0.f};
+    for (int i = ix; i < nb; i += 4) {
+
+        for (int l = 0; l < 8; ++l) {
+            yl[l+ 0] = y1[l+ 0];
+            yl[l+ 8] = y1[l+16];
+            yl[l+16] = y1[l+32];
+            yl[l+24] = y1[l+48];
+        }
+
+        device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
+        device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0);
+        device const uint16_t * a = (device const uint16_t *)(x[i].scales);
+        device const half * dh = &x[i].d;
+
+        for (int row = 0; row < 2; ++row) {
+
+            const float d_all = (float)dh[0];
+
+            scales16[0] = a[4];
+            scales16[1] = a[5];
+            aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030;
+            scales16[0] = a[il+0];
+            scales16[1] = a[il+1];
+            scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
+
+            float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
+            for (int l = 0; l < n; l += 2) {
+                const int32_t qs = q[l/2];
+                s1 += yl[l+0] * (qs & qm[il/2][0]);
+                s2 += yl[l+1] * (qs & qm[il/2][1]);
+                s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]);
+                s4 += yl[l+16] * (qs & qm[il/2][2]);
+                s5 += yl[l+17] * (qs & qm[il/2][3]);
+                s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]);
+            }
+            float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+            float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
+            sumf1[row] += d1 * (scales[0] - 32);
+            sumf2[row] += d2 * (scales[2] - 32);
+
+            s1 = s2 = s3 = s4 = s5 = s6 = 0;
+            for (int l = 0; l < n; l += 2) {
+                const int32_t qs = q[l/2+8];
+                s1 += yl[l+8] * (qs & qm[il/2][0]);
+                s2 += yl[l+9] * (qs & qm[il/2][1]);
+                s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]);
+                s4 += yl[l+24] * (qs & qm[il/2][2]);
+                s5 += yl[l+25] * (qs & qm[il/2][3]);
+                s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]);
+            }
+            d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+            d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
+            sumf1[row] += d1 * (scales[1] - 32);
+            sumf2[row] += d2 * (scales[3] - 32);
+
+            q  += step;
+            h  += step;
+            a  += step;
+            dh += step;
+
+        }
+
+        y1 += 4 * QK_K;
+
+    }
+
+    for (int row = 0; row < 2; ++row) {
+        const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
+        sumf1[row] = simd_sum(sumf);
+    }
+    if (tiisg == 0) {
+        for (int row = 0; row < 2; ++row) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = sumf1[row];
+        }
+    }
+}
+
+[[host_name("kernel_mul_mv_q3_K_f32")]]
+kernel void kernel_mul_mv_q3_K_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+}
+
+void kernel_mul_mv_q4_K_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+                   int64_t   ne00,
+                   int64_t   ne01,
+                   int64_t   ne02,
+                   int64_t   ne10,
+                   int64_t   ne12,
+                   int64_t   ne0,
+                   int64_t   ne1,
+                   uint      r2,
+                   uint      r3,
+        threadgroup int8_t * shared_values,
+                   uint3     tgpig,
+                   uint      tiisg,
+                   uint      sgitg) {
+
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
+    const int ix = tiisg/8;  // 0...3
+    const int it = tiisg%8;  // 0...7
+    const int iq = it/4;     // 0 or 1
+    const int ir = it%4;     // 0...3
+
+    const int nb = ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+    //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int first_row = r0 * N_DST;
+    const int ib_row = first_row * nb;
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+    device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
+    device const float      * y = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float yl[16];
+    float yh[16];
+    float sumf[N_DST]={0.f}, all_sum;
+
+    const int step = sizeof(block_q4_K) * nb / 2;
+
+    device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
+
+    uint16_t sc16[4];
+    thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
+
+    for (int ib = ix; ib < nb; ib += 4) {
+
+        float4 sumy = {0.f, 0.f, 0.f, 0.f};
+        for (int i = 0; i < 8; ++i) {
+            yl[i+0] = y4[i+  0]; sumy[0] += yl[i+0];
+            yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
+            yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];
+            yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
+        }
+
+        device const uint16_t * sc = (device const uint16_t *)x[ib].scales + iq;
+        device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
+        device const half     * dh = &x[ib].d;
+
+        for (int row = 0; row < N_DST; row++) {
+
+            sc16[0] = sc[0] & kmask1;
+            sc16[1] = sc[2] & kmask1;
+            sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
+            sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);
+
+            device const uint16_t * q2 = q1 + 32;
+
+            float4 acc1 = {0.f, 0.f, 0.f, 0.f};
+            float4 acc2 = {0.f, 0.f, 0.f, 0.f};
+            for (int i = 0; i < 8; i += 2) {
+                acc1[0] += yl[i+0] * (q1[i/2] & 0x000F);
+                acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00);
+                acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0);
+                acc1[3] += yl[i+9] * (q1[i/2] & 0xF000);
+                acc2[0] += yh[i+0] * (q2[i/2] & 0x000F);
+                acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00);
+                acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0);
+                acc2[3] += yh[i+9] * (q2[i/2] & 0xF000);
+            }
+
+            float dall = dh[0];
+            float dmin = dh[1];
+            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
+                                 (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
+                                 (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
+                                 (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
+                         dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
+
+            q1 += step;
+            sc += step;
+            dh += step;
+        }
+
+        y4 += 4 * QK_K;
+    }
+
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+        }
+    }
+}
+
+[[host_name("kernel_mul_mv_q4_K_f32")]]
+kernel void kernel_mul_mv_q4_K_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+}
+
+void kernel_mul_mv_q5_K_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+                   int64_t   ne00,
+                   int64_t   ne01,
+                   int64_t   ne02,
+                   int64_t   ne10,
+                   int64_t   ne12,
+                   int64_t   ne0,
+                   int64_t   ne1,
+                   uint      r2,
+                   uint      r3,
+        threadgroup int8_t * shared_values,
+                   uint3     tgpig,
+                   uint      tiisg,
+                   uint      sgitg) {
+
+    const int nb = ne00/QK_K;
+
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+    device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
+    device const float     * yy = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float sumf[2]={0.f};
+
+    const int step = sizeof(block_q5_K) * nb;
+
+    float yl[16], yh[16];
+
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
+    const int tid = tiisg/4;
+    const int ix  = tiisg%4;
+    const int iq  = tid/4;
+    const int ir  = tid%4;
+    const int n   = 8;
+
+    const int l0 = n*ir;
+    const int q_offset = 32*iq + l0;
+    const int y_offset = 64*iq + l0;
+
+    const uint8_t hm1 = 1u << (2*iq);
+    const uint8_t hm2 = hm1 << 1;
+    const uint8_t hm3 = hm1 << 4;
+    const uint8_t hm4 = hm2 << 4;
+
+    uint16_t sc16[4];
+    thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
+
+    device const float * y1 = yy + ix*QK_K + y_offset;
+
+    for (int i = ix; i < nb; i += 4) {
+
+        device const uint8_t * q1 = x[i].qs + q_offset;
+        device const uint8_t * qh = x[i].qh + l0;
+        device const half * dh = &x[i].d;
+        device const uint16_t * a = (device const uint16_t *)x[i].scales + iq;
+
+        device const float * y2 = y1 + 128;
+        float4 sumy = {0.f, 0.f, 0.f, 0.f};
+        for (int l = 0; l < 8; ++l) {
+            yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
+            yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
+            yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
+            yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
+        }
+
+        for (int row = 0; row < 2; ++row) {
+
+            device const uint8_t * q2 = q1 + 64;
+
+            sc16[0] = a[0] & kmask1;
+            sc16[1] = a[2] & kmask1;
+            sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
+            sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
+
+            float4 acc1 = {0.f};
+            float4 acc2 = {0.f};
+            for (int l = 0; l < n; ++l) {
+                uint8_t h = qh[l];
+                acc1[0] += yl[l+0] * (q1[l] & 0x0F);
+                acc1[1] += yl[l+8] * (q1[l] & 0xF0);
+                acc1[2] += yh[l+0] * (q2[l] & 0x0F);
+                acc1[3] += yh[l+8] * (q2[l] & 0xF0);
+                acc2[0] += h & hm1 ? yl[l+0] : 0.f;
+                acc2[1] += h & hm2 ? yl[l+8] : 0.f;
+                acc2[2] += h & hm3 ? yh[l+0] : 0.f;
+                acc2[3] += h & hm4 ? yh[l+8] : 0.f;
+            }
+            const float dall = dh[0];
+            const float dmin = dh[1];
+            sumf[row] += dall * (sc8[0] * (acc1[0] +  16.f*acc2[0]) +
+                                 sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
+                                 sc8[4] * (acc1[2] +  16.f*acc2[2]) +
+                                 sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
+                         dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
+
+            q1 += step;
+            qh += step;
+            dh += step/2;
+            a  += step/2;
+
+        }
+
+        y1 += 4 * QK_K;
+
+    }
+
+    for (int row = 0; row < 2; ++row) {
+        const float tot = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
+        }
+    }
+}
+
+[[host_name("kernel_mul_mv_q5_K_f32")]]
+kernel void kernel_mul_mv_q5_K_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+}
+
+void kernel_mul_mv_q6_K_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+                   int64_t   ne00,
+                   int64_t   ne01,
+                   int64_t   ne02,
+                   int64_t   ne10,
+                   int64_t   ne12,
+                   int64_t   ne0,
+                   int64_t   ne1,
+                   uint      r2,
+                   uint      r3,
+        threadgroup int8_t * shared_values,
+                   uint3     tgpig,
+                   uint      tiisg,
+                   uint      sgitg) {
+
+    const uint8_t kmask1 = 0x03;
+    const uint8_t kmask2 = 0x0C;
+    const uint8_t kmask3 = 0x30;
+    const uint8_t kmask4 = 0xC0;
+
+    const int nb = ne00/QK_K;
+
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+    const int     im = tgpig.z;
+
+    const int row = 2 * r0 + sgitg;
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+    device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
+    device const float     * yy = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float sumf = 0;
+
+    const int tid  = tiisg/2;
+    const int ix   = tiisg%2;
+    const int ip   = tid/8;         // 0 or 1
+    const int il   = tid%8;
+    const int n    = 4;
+    const int l0   = n*il;
+    const int is   = 8*ip + l0/16;
+
+    const int y_offset = 128*ip + l0;
+    const int q_offset_l = 64*ip + l0;
+    const int q_offset_h = 32*ip + l0;
+
+    for (int i = ix; i < nb; i += 2) {
+
+        device const uint8_t * q1 = x[i].ql + q_offset_l;
+        device const uint8_t * q2 = q1 + 32;
+        device const uint8_t * qh = x[i].qh + q_offset_h;
+        device const int8_t  * sc = x[i].scales + is;
+
+        device const float * y = yy + i * QK_K + y_offset;
+
+        const float dall = x[i].d;
+
+        float4 sums = {0.f, 0.f, 0.f, 0.f};
+        for (int l = 0; l < n; ++l) {
+            sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
+            sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
+            sums[2] += y[l+64] * ((int8_t)((q1[l]  >> 4) | ((qh[l] & kmask3) << 0)) - 32);
+            sums[3] += y[l+96] * ((int8_t)((q2[l]  >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
+        }
+
+        sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
+
+    }
+
+    const float tot = simd_sum(sumf);
+    if (tiisg == 0) {
+        dst[r1*ne0 + im*ne0*ne1 + row] = tot;
+    }
+}
+
+[[host_name("kernel_mul_mv_q6_K_f32")]]
+kernel void kernel_mul_mv_q6_K_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+}
+
+// ======================= "True" 2-bit
+
+void kernel_mul_mv_iq2_xxs_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+                   int64_t   ne00,
+                   int64_t   ne01,
+                   int64_t   ne02,
+                   int64_t   ne10,
+                   int64_t   ne12,
+                   int64_t   ne0,
+                   int64_t   ne1,
+                   uint      r2,
+                   uint      r3,
+        threadgroup int8_t * shared_values,
+                   uint3     tgpig,
+                   uint      tiisg,
+                   uint      sgitg) {
+
+    const int nb = ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int ib_row = first_row * nb;
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+    device const block_iq2_xxs * x = (device const block_iq2_xxs *) src0 + ib_row + offset0;
+    device const float         * y = (device const float         *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float yl[32];
+    float sumf[N_DST]={0.f}, all_sum;
+
+    const int nb32 = nb * (QK_K / 32);
+
+    threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
+    threadgroup uint8_t  * shared_signs = (threadgroup uint8_t *)(values + 256);
+    {
+        int nval = 4;
+        int pos  = (32*sgitg + tiisg)*nval;
+        for (int i = 0; i < nval; ++i) values[pos + i] = iq2xxs_grid[pos + i];
+        nval = 2;
+        pos  = (32*sgitg + tiisg)*nval;
+        for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    const int ix = tiisg;
+
+    device const float * y4 = y + 32 * ix;
+
+    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+
+        for (int i = 0; i < 32; ++i) {
+            yl[i] = y4[i];
+        }
+
+        const int ibl = ib32 / (QK_K / 32);
+        const int ib  = ib32 % (QK_K / 32);
+
+        device const block_iq2_xxs * xr = x + ibl;
+        device const uint16_t * q2 = xr->qs + 4 * ib;
+        device const half * dh = &xr->d;
+
+        for (int row = 0; row < N_DST; row++) {
+
+            const float db = dh[0];
+            device const uint8_t * aux8 = (device const uint8_t *)q2;
+            const uint32_t aux32 = q2[2] | (q2[3] << 16);
+            const float d = db * (0.5f + (aux32 >> 28));
+
+            float sum = 0;
+            for (int l = 0; l < 4; ++l) {
+                const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + aux8[l]);
+                const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127];
+                for (int j = 0; j < 8; ++j) {
+                    sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+                }
+            }
+            sumf[row] += d * sum;
+
+            dh += nb*sizeof(block_iq2_xxs)/2;
+            q2 += nb*sizeof(block_iq2_xxs)/2;
+        }
+
+        y4 += 32 * 32;
+    }
+
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
+        }
+    }
+}
+
+[[host_name("kernel_mul_mv_iq2_xxs_f32")]]
+kernel void kernel_mul_mv_iq2_xxs_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        threadgroup int8_t * shared_values [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+}
+
+void kernel_mul_mv_iq2_xs_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+                   int64_t   ne00,
+                   int64_t   ne01,
+                   int64_t   ne02,
+                   int64_t   ne10,
+                   int64_t   ne12,
+                   int64_t   ne0,
+                   int64_t   ne1,
+                   uint      r2,
+                   uint      r3,
+        threadgroup int8_t * shared_values,
+                   uint3     tgpig,
+                   uint      tiisg,
+                   uint      sgitg) {
+
+    const int nb = ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int ib_row = first_row * nb;
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+    device const block_iq2_xs * x = (device const block_iq2_xs *) src0 + ib_row + offset0;
+    device const float        * y = (device const float        *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float yl[32];
+    float sumf[N_DST]={0.f}, all_sum;
+
+    const int nb32 = nb * (QK_K / 32);
+
+    threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
+    threadgroup uint8_t  * shared_signs = (threadgroup uint8_t *)(values + 512);
+    {
+        int nval = 8;
+        int pos  = (32*sgitg + tiisg)*nval;
+        for (int i = 0; i < nval; ++i) values[pos + i] = iq2xs_grid[pos + i];
+        nval = 2;
+        pos  = (32*sgitg + tiisg)*nval;
+        for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    const int ix = tiisg;
+
+    device const float * y4 = y + 32 * ix;
+
+    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+
+        for (int i = 0; i < 32; ++i) {
+            yl[i] = y4[i];
+        }
+
+        const int ibl = ib32 / (QK_K / 32);
+        const int ib  = ib32 % (QK_K / 32);
+
+        device const block_iq2_xs * xr = x + ibl;
+        device const uint16_t * q2 = xr->qs + 4 * ib;
+        device const uint8_t  * sc = xr->scales + ib;
+        device const half * dh = &xr->d;
+
+        for (int row = 0; row < N_DST; row++) {
+
+            const float db = dh[0];
+            const uint8_t ls1 = sc[0] & 0xf;
+            const uint8_t ls2 = sc[0] >>  4;
+            const float d1 = db * (0.5f + ls1);
+            const float d2 = db * (0.5f + ls2);
+
+            float sum1 = 0, sum2 = 0;
+            for (int l = 0; l < 2; ++l) {
+                const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511));
+                const uint8_t signs = shared_signs[(q2[l] >> 9)];
+                for (int j = 0; j < 8; ++j) {
+                    sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+                }
+            }
+            for (int l = 2; l < 4; ++l) {
+                const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511));
+                const uint8_t signs = shared_signs[(q2[l] >> 9)];
+                for (int j = 0; j < 8; ++j) {
+                    sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+                }
+            }
+            sumf[row] += d1 * sum1 + d2 * sum2;
+
+            dh += nb*sizeof(block_iq2_xs)/2;
+            q2 += nb*sizeof(block_iq2_xs)/2;
+            sc += nb*sizeof(block_iq2_xs);
+        }
+
+        y4 += 32 * 32;
+    }
+
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
+        }
+    }
+}
+
+[[host_name("kernel_mul_mv_iq2_xs_f32")]]
+kernel void kernel_mul_mv_iq2_xs_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        threadgroup int8_t * shared_values [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+}
+
+void kernel_mul_mv_iq3_xxs_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+                   int64_t   ne00,
+                   int64_t   ne01,
+                   int64_t   ne02,
+                   int64_t   ne10,
+                   int64_t   ne12,
+                   int64_t   ne0,
+                   int64_t   ne1,
+                   uint      r2,
+                   uint      r3,
+        threadgroup int8_t * shared_values,
+                   uint3     tgpig,
+                   uint      tiisg,
+                   uint      sgitg) {
+
+    const int nb = ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int ib_row = first_row * nb;
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+    device const block_iq3_xxs * x = (device const block_iq3_xxs *) src0 + ib_row + offset0;
+    device const float         * y = (device const float         *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float yl[32];
+    float sumf[N_DST]={0.f}, all_sum;
+
+    const int nb32 = nb * (QK_K / 32);
+
+    threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values;
+    threadgroup uint8_t  * shared_signs = (threadgroup uint8_t *)(values + 256);
+    {
+        int nval = 4;
+        int pos  = (32*sgitg + tiisg)*nval;
+        for (int i = 0; i < nval; ++i) values[pos + i] = iq3xxs_grid[pos + i];
+        nval = 2;
+        pos  = (32*sgitg + tiisg)*nval;
+        for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    const int ix = tiisg;
+
+    device const float * y4 = y + 32 * ix;
+
+    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+
+        for (int i = 0; i < 32; ++i) {
+            yl[i] = y4[i];
+        }
+
+        const int ibl = ib32 / (QK_K / 32);
+        const int ib  = ib32 % (QK_K / 32);
+
+        device const block_iq3_xxs * xr = x + ibl;
+        device const uint8_t  * q3 = xr->qs + 8 * ib;
+        device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib;
+        device const half * dh = &xr->d;
+
+        for (int row = 0; row < N_DST; row++) {
+
+            const float db = dh[0];
+            const uint32_t aux32 = gas[0] | (gas[1] << 16);
+            const float d = db * (0.5f + (aux32 >> 28));
+
+            float2 sum = {0};
+            for (int l = 0; l < 4; ++l) {
+                const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + q3[2*l+0]);
+                const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + q3[2*l+1]);
+                const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127];
+                for (int j = 0; j < 4; ++j) {
+                    sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
+                    sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
+                }
+            }
+            sumf[row] += d * (sum[0] + sum[1]);
+
+            dh  += nb*sizeof(block_iq3_xxs)/2;
+            q3  += nb*sizeof(block_iq3_xxs);
+            gas += nb*sizeof(block_iq3_xxs)/2;
+        }
+
+        y4 += 32 * 32;
+    }
+
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.5f;
+        }
+    }
+}
+
+[[host_name("kernel_mul_mv_iq3_xxs_f32")]]
+kernel void kernel_mul_mv_iq3_xxs_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        threadgroup int8_t * shared_values [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+}
+
+void kernel_mul_mv_iq3_s_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+                   int64_t   ne00,
+                   int64_t   ne01,
+                   int64_t   ne02,
+                   int64_t   ne10,
+                   int64_t   ne12,
+                   int64_t   ne0,
+                   int64_t   ne1,
+                   uint      r2,
+                   uint      r3,
+        threadgroup int8_t * shared_values,
+                   uint3     tgpig,
+                   uint      tiisg,
+                   uint      sgitg) {
+
+    const int nb = ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int ib_row = first_row * nb;
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+    device const block_iq3_s * x = (device const block_iq3_s *) src0 + ib_row + offset0;
+    device const float       * y = (device const float       *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float yl[32];
+    float sumf[N_DST]={0.f}, all_sum;
+
+    const int nb32 = nb * (QK_K / 32);
+
+    threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values;
+    {
+        int nval = 8;
+        int pos  = (32*sgitg + tiisg)*nval;
+        for (int i = 0; i < nval; ++i) values[pos + i] = iq3s_grid[pos + i];
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    const int ix = tiisg;
+
+    device const float * y4 = y + 32 * ix;
+
+    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+
+        for (int i = 0; i < 32; ++i) {
+            yl[i] = y4[i];
+        }
+
+        const int ibl = ib32 / (QK_K / 32);
+        const int ib  = ib32 % (QK_K / 32);
+
+        device const block_iq3_s * xr = x + ibl;
+        device const uint8_t * qs = xr->qs + 8 * ib;
+        device const uint8_t * qh = xr->qh + ib;
+        device const uint8_t * sc = xr->scales + (ib/2);
+        device const uint8_t * signs = xr->signs + 4 * ib;
+        device const half * dh = &xr->d;
+
+        for (int row = 0; row < N_DST; row++) {
+
+            const float db = dh[0];
+            const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf));
+
+            float2 sum = {0};
+            for (int l = 0; l < 4; ++l) {
+                const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? values + 256 : values;
+                const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? values + 256 : values;
+                const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]);
+                const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]);
+                for (int j = 0; j < 4; ++j) {
+                    sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
+                    sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
+                }
+            }
+            sumf[row] += d * (sum[0] + sum[1]);
+
+            dh  += nb*sizeof(block_iq3_s)/2;
+            qs  += nb*sizeof(block_iq3_s);
+            qh  += nb*sizeof(block_iq3_s);
+            sc  += nb*sizeof(block_iq3_s);
+            signs += nb*sizeof(block_iq3_s);
+        }
+
+        y4 += 32 * 32;
+    }
+
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+        }
+    }
+}
+
+[[host_name("kernel_mul_mv_iq3_s_f32")]]
+kernel void kernel_mul_mv_iq3_s_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        threadgroup int8_t * shared_values [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+}
+
+void kernel_mul_mv_iq2_s_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+                   int64_t   ne00,
+                   int64_t   ne01,
+                   int64_t   ne02,
+                   int64_t   ne10,
+                   int64_t   ne12,
+                   int64_t   ne0,
+                   int64_t   ne1,
+                   uint      r2,
+                   uint      r3,
+        threadgroup int8_t * shared_values,
+                   uint3     tgpig,
+                   uint      tiisg,
+                   uint      sgitg) {
+
+    const int nb = ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int ib_row = first_row * nb;
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+    device const block_iq2_s * x = (device const block_iq2_s *) src0 + ib_row + offset0;
+    device const float       * y = (device const float       *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float yl[32];
+    float sumf[N_DST]={0.f}, all_sum;
+
+    const int nb32 = nb * (QK_K / 32);
+
+    //threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
+    //{
+    //    int nval = 32;
+    //    int pos  = (32*sgitg + tiisg)*nval;
+    //    for (int i = 0; i < nval; ++i) values[pos + i] = iq2s_grid[pos + i];
+    //    threadgroup_barrier(mem_flags::mem_threadgroup);
+    //}
+
+    const int ix = tiisg;
+
+    device const float * y4 = y + 32 * ix;
+
+    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+
+        for (int i = 0; i < 32; ++i) {
+            yl[i] = y4[i];
+        }
+
+        const int ibl = ib32 / (QK_K / 32);
+        const int ib  = ib32 % (QK_K / 32);
+
+        device const block_iq2_s * xr = x + ibl;
+        device const uint8_t * qs = xr->qs + 4 * ib;
+        device const uint8_t * qh = xr->qh + ib;
+        device const uint8_t * sc = xr->scales + ib;
+        device const uint8_t * signs = qs + QK_K/8;
+        device const half * dh = &xr->d;
+
+        for (int row = 0; row < N_DST; row++) {
+
+            const float db = dh[0];
+            const float d1 = db * (0.5f + (sc[0] & 0xf));
+            const float d2 = db * (0.5f + (sc[0] >>  4));
+
+            float2 sum = {0};
+            for (int l = 0; l < 2; ++l) {
+                //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
+                //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
+                constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
+                constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
+                for (int j = 0; j < 8; ++j) {
+                    sum[0] += yl[8*l + j +  0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]);
+                    sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]);
+                }
+            }
+            sumf[row] += d1 * sum[0] + d2 * sum[1];
+
+            dh  += nb*sizeof(block_iq2_s)/2;
+            qs  += nb*sizeof(block_iq2_s);
+            qh  += nb*sizeof(block_iq2_s);
+            sc  += nb*sizeof(block_iq2_s);
+            signs += nb*sizeof(block_iq2_s);
+        }
+
+        y4 += 32 * 32;
+    }
+
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
+        }
+    }
+}
+
+[[host_name("kernel_mul_mv_iq2_s_f32")]]
+kernel void kernel_mul_mv_iq2_s_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        threadgroup int8_t * shared_values [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+}
+
+void kernel_mul_mv_iq1_s_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+                   int64_t   ne00,
+                   int64_t   ne01,
+                   int64_t   ne02,
+                   int64_t   ne10,
+                   int64_t   ne12,
+                   int64_t   ne0,
+                   int64_t   ne1,
+                   uint      r2,
+                   uint      r3,
+        threadgroup int8_t * shared_value,
+                   uint3     tgpig,
+                   uint      tiisg,
+                   uint      sgitg) {
+
+    const int nb = ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int ib_row = first_row * nb;
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0;
+    device const float       * y = (device const float       *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float yl[32];
+    float sumf[N_DST]={0.f}, all_sum;
+
+    const int nb32 = nb * (QK_K / 32);
+
+    const int ix = tiisg;
+
+    device const float * y4 = y + 32 * ix;
+
+    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+
+        float sumy = 0;
+        for (int i = 0; i < 32; ++i) {
+            yl[i] = y4[i];
+            sumy += yl[i];
+        }
+
+        const int ibl = ib32 / (QK_K / 32);
+        const int ib  = ib32 % (QK_K / 32);
+
+        device const block_iq1_s * xr = x + ibl;
+        device const uint8_t  * qs = xr->qs + 4 * ib;
+        device const uint16_t * qh = xr->qh + ib;
+        device const half     * dh = &xr->d;
+
+        for (int row = 0; row < N_DST; row++) {
+
+            constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
+            constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 5) & 0x700)));
+            constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700)));
+            constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700)));
+
+            float sum = 0;
+            for (int j = 0; j < 4; ++j) {
+                sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
+                     + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4)
+                     + yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
+                     + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4);
+            }
+            sumf[row] += (float)dh[0] * (sum + sumy * (qh[0] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA)) * (2*((qh[0] >> 12) & 7) + 1);
+
+            dh += nb*sizeof(block_iq1_s)/2;
+            qs += nb*sizeof(block_iq1_s);
+            qh += nb*sizeof(block_iq1_s)/2;
+        }
+
+        y4 += 32 * 32;
+    }
+
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+        }
+    }
+}
+
+void kernel_mul_mv_iq1_m_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+                   int64_t   ne00,
+                   int64_t   ne01,
+                   int64_t   ne02,
+                   int64_t   ne10,
+                   int64_t   ne12,
+                   int64_t   ne0,
+                   int64_t   ne1,
+                   uint      r2,
+                   uint      r3,
+        threadgroup int8_t * shared_value,
+                   uint3     tgpig,
+                   uint      tiisg,
+                   uint      sgitg) {
+
+    const int nb = ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int ib_row = first_row * nb;
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    device const block_iq1_m * x = (device const block_iq1_m *) src0 + ib_row + offset0;
+    device const float       * y = (device const float       *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float yl[32];
+    float sumf[N_DST]={0.f}, all_sum;
+
+    const int nb32 = nb * (QK_K / 32);
+
+    const int ix = tiisg;
+
+    device const float * y4 = y + 32 * ix;
+
+    iq1m_scale_t scale;
+
+    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+
+        float4 sumy = {0.f};
+        for (int i = 0; i < 8; ++i) {
+            yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
+            yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8];
+            yl[i+16] = y4[i+16]; sumy[2] += yl[i+16];
+            yl[i+24] = y4[i+24]; sumy[3] += yl[i+24];
+        }
+
+        const int ibl = ib32 / (QK_K / 32);
+        const int ib  = ib32 % (QK_K / 32);
+
+        device const block_iq1_m * xr = x + ibl;
+        device const uint8_t  * qs = xr->qs + 4 * ib;
+        device const uint8_t  * qh = xr->qh + 2 * ib;
+        device const uint16_t * sc = (device const uint16_t *)xr->scales;
+
+        for (int row = 0; row < N_DST; row++) {
+            scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+
+            constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
+            constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
+            constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[1] << 8) & 0x700)));
+            constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700)));
+
+            float2 sum = {0.f};
+            for (int j = 0; j < 4; ++j) {
+                sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
+                        + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4);
+                sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
+                        + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4);
+            }
+            const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
+            const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
+
+            sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
+                                             (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
+
+            sc += nb*sizeof(block_iq1_m)/2;
+            qs += nb*sizeof(block_iq1_m);
+            qh += nb*sizeof(block_iq1_m);
+        }
+
+        y4 += 32 * 32;
+    }
+
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+        }
+    }
+}
+
+void kernel_mul_mv_iq4_nl_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+                   int64_t   ne00,
+                   int64_t   ne01,
+                   int64_t   ne02,
+                   int64_t   ne10,
+                   int64_t   ne12,
+                   int64_t   ne0,
+                   int64_t   ne1,
+                   uint      r2,
+                   uint      r3,
+        threadgroup int8_t * shared_values_i8,
+                   uint3     tgpig,
+                   uint      tiisg,
+                   uint      sgitg) {
+
+    threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
+    const int nb = ne00/QK4_NL;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+    const int first_row = (r0 * 2 + sgitg) * 2;
+    const int ib_row = first_row * nb;
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    device const block_iq4_nl * x = (device const block_iq4_nl *) src0 + ib_row + offset0;
+    device const float        * y = (device const float        *) src1 + r1*ne10 + im*ne00*ne1;
+
+    const int ix = tiisg/2;  // 0...15
+    const int it = tiisg%2;  // 0 or 1
+
+    shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16];
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    float4 yl[4];
+    float sumf[2]={0.f}, all_sum;
+
+    device const float * yb = y + ix * QK4_NL + it * 8;
+
+    uint32_t aux32[2];
+    thread const uint8_t * q8 = (thread const uint8_t *)aux32;
+
+    float4 qf1, qf2;
+
+    for (int ib = ix; ib < nb; ib += 16) {
+
+        device const float4 * y4 = (device const float4 *)yb;
+        yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
+
+        for (int row = 0; row < 2; ++row) {
+
+            device const block_iq4_nl & xb = x[row*nb + ib];
+            device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);
+
+            float4 acc1 = {0.f}, acc2 = {0.f};
+
+            aux32[0] = q4[0] | (q4[1] << 16);
+            aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
+            aux32[0] &= 0x0f0f0f0f;
+            qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
+            qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
+            acc1 += yl[0] * qf1;
+            acc2 += yl[1] * qf2;
+
+            aux32[0] = q4[2] | (q4[3] << 16);
+            aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
+            aux32[0] &= 0x0f0f0f0f;
+            qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
+            qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
+            acc1 += yl[2] * qf1;
+            acc2 += yl[3] * qf2;
+
+            acc1 += acc2;
+
+            sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
+
+        }
+
+        yb += 16 * QK4_NL;
+    }
+
+    for (int row = 0; row < 2; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+        }
+    }
+}
+
+void kernel_mul_mv_iq4_xs_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+                   int64_t   ne00,
+                   int64_t   ne01,
+                   int64_t   ne02,
+                   int64_t   ne10,
+                   int64_t   ne12,
+                   int64_t   ne0,
+                   int64_t   ne1,
+                   uint      r2,
+                   uint      r3,
+        threadgroup int8_t * shared_values_i8,
+                   uint3     tgpig,
+                   uint      tiisg,
+                   uint      sgitg) {
+
+    threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
+    const int nb = ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+    const int first_row = (r0 * 2 + sgitg) * 2;
+    const int ib_row = first_row * nb;
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    device const block_iq4_xs * x = (device const block_iq4_xs *) src0 + ib_row + offset0;
+    device const float        * y = (device const float        *) src1 + r1*ne10 + im*ne00*ne1;
+
+    const int ix = tiisg/16;  // 0 or 1
+    const int it = tiisg%16;  // 0...15
+    const int ib = it/2;
+    const int il = it%2;
+
+    shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16];
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    float4 yl[4];
+    float sumf[2]={0.f}, all_sum;
+
+    device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
+
+    uint32_t aux32[2];
+    thread const uint8_t * q8 = (thread const uint8_t *)aux32;
+
+    float4 qf1, qf2;
+
+    for (int ibl = ix; ibl < nb; ibl += 2) {
+
+        device const float4 * y4 = (device const float4 *)yb;
+        yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
+
+        for (int row = 0; row < 2; ++row) {
+
+            device const block_iq4_xs & xb = x[row*nb + ibl];
+            device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);
+
+            float4 acc1 = {0.f}, acc2 = {0.f};
+
+            aux32[0] = q4[0] & 0x0f0f0f0f;
+            aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f;
+            qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
+            qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
+            acc1 += yl[0] * qf1;
+            acc2 += yl[1] * qf2;
+
+            aux32[0] = q4[1] & 0x0f0f0f0f;
+            aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f;
+            qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
+            qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
+            acc1 += yl[2] * qf1;
+            acc2 += yl[3] * qf2;
+
+            acc1 += acc2;
+
+            const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32;
+            sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
+
+        }
+
+        yb += 2 * QK_K;
+    }
+
+    for (int row = 0; row < 2; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+        }
+    }
+}
+
+[[host_name("kernel_mul_mv_iq1_s_f32")]]
+kernel void kernel_mul_mv_iq1_s_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+}
+
+[[host_name("kernel_mul_mv_iq1_m_f32")]]
+kernel void kernel_mul_mv_iq1_m_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+}
+
+[[host_name("kernel_mul_mv_iq4_nl_f32")]]
+kernel void kernel_mul_mv_iq4_nl_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        threadgroup int8_t * shared_values [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+}
+
+[[host_name("kernel_mul_mv_iq4_xs_f32")]]
+kernel void kernel_mul_mv_iq4_xs_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        threadgroup int8_t * shared_values [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+}
+
+//============================= templates and their specializations =============================
+
+// NOTE: this is not dequantizing - we are simply fitting the template
+template <typename type4x4>
+void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
+    float4x4 temp = *(((device float4x4 *)src));
+    for (int i = 0; i < 16; i++){
+        reg[i/4][i%4] = temp[i/4][i%4];
+    }
+}
+
+template <typename type4x4>
+void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
+    half4x4 temp = *(((device half4x4 *)src));
+    for (int i = 0; i < 16; i++){
+        reg[i/4][i%4] = temp[i/4][i%4];
+    }
+}
+
+template <typename type4x4>
+void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 1);
+    const float d1 = il ? (xb->d / 16.h) : xb->d;
+    const float d2 = d1 / 256.f;
+    const float md = -8.h * xb->d;
+    const ushort mask0 = il ? 0x00F0 : 0x000F;
+    const ushort mask1 = mask0 << 8;
+
+    for (int i=0;i<8;i++) {
+        reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
+        reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
+    }
+}
+
+template <typename type4x4>
+void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 2);
+    const float d1 = il ? (xb->d / 16.h) : xb->d;
+    const float d2 = d1 / 256.f;
+    const float  m = xb->m;
+    const ushort mask0 = il ? 0x00F0 : 0x000F;
+    const ushort mask1 = mask0 << 8;
+
+    for (int i=0;i<8;i++) {
+        reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
+        reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
+    }
+}
+
+template <typename type4x4>
+void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 3);
+    const float d = xb->d;
+    const float md = -16.h * xb->d;
+    const ushort mask = il ? 0x00F0 : 0x000F;
+
+    const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+    const int x_mv = il ? 4 : 0;
+
+    const int gh_mv = il ? 12 : 0;
+    const int gh_bk = il ?  0 : 4;
+
+    for (int i = 0; i < 8; i++) {
+        // extract the 5-th bits for x0 and x1
+        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
+        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+        // combine the 4-bits from qs with the 5th bit
+        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
+        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+        reg[i/2][2*(i%2)+0] = d * x0 + md;
+        reg[i/2][2*(i%2)+1] = d * x1 + md;
+    }
+}
+
+template <typename type4x4>
+void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 4);
+    const float d = xb->d;
+    const float m = xb->m;
+    const ushort mask = il ? 0x00F0 : 0x000F;
+
+    const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+    const int x_mv = il ? 4 : 0;
+
+    const int gh_mv = il ? 12 : 0;
+    const int gh_bk = il ?  0 : 4;
+
+    for (int i = 0; i < 8; i++) {
+        // extract the 5-th bits for x0 and x1
+        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
+        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+        // combine the 4-bits from qs with the 5th bit
+        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
+        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+        reg[i/2][2*(i%2)+0] = d * x0 + m;
+        reg[i/2][2*(i%2)+1] = d * x1 + m;
+    }
+}
+
+template <typename type4x4>
+void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
+    device const int8_t * qs = ((device const int8_t *)xb->qs);
+    const half d = xb->d;
+
+    for (int i = 0; i < 16; i++) {
+        reg[i/4][i%4] = (qs[i + 16*il] * d);
+    }
+}
+
+template <typename type4x4>
+void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
+    const float d = xb->d;
+    const float min = xb->dmin;
+    device const uint8_t * q = (device const uint8_t *)xb->qs;
+    float dl, ml;
+    uint8_t sc = xb->scales[il];
+
+    q = q + 32*(il/8) + 16*(il&1);
+    il = (il/2)%4;
+
+    half  coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
+    uchar mask = il>1 ? (il>2 ? 192    : 48)     : (il>0 ? 12    : 3);
+    dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
+    for (int i = 0; i < 16; ++i) {
+        reg[i/4][i%4] = dl * (q[i] & mask) - ml;
+    }
+}
+
+template <typename type4x4>
+void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
+    const half d_all = xb->d;
+    device const uint8_t * q = (device const uint8_t *)xb->qs;
+    device const uint8_t * h = (device const uint8_t *)xb->hmask;
+    device const int8_t * scales = (device const int8_t *)xb->scales;
+
+    q = q + 32 * (il/8) + 16 * (il&1);
+    h = h + 16 * (il&1);
+    uint8_t m = 1 << (il/2);
+    uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \
+                                 ((il/4)>0 ? 12  : 3);
+    uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
+    uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
+    int16_t  dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
+                               : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
+    float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
+    const float ml = 4.f * dl;
+
+    il = (il/2) & 3;
+    const half    coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
+    const uint8_t mask = il>1 ? (il>2 ? 192    : 48)     : (il>0 ? 12    : 3);
+    dl *= coef;
+
+    for (int i = 0; i < 16; ++i) {
+        reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
+    }
+}
+
+static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
+    return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
+                 : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
+}
+
+template <typename type4x4>
+void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
+    device const uchar * q = xb->qs;
+
+    short is = (il/4) * 2;
+    q = q + (il/4) * 32 + 16 * (il&1);
+    il = il & 3;
+    const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+    const float d   = il < 2 ? xb->d : xb->d / 16.h;
+    const float min = xb->dmin;
+    const float dl = d * sc[0];
+    const float ml = min * sc[1];
+
+    const ushort mask = il<2 ? 0x0F : 0xF0;
+    for (int i = 0; i < 16; ++i) {
+        reg[i/4][i%4] = dl * (q[i] & mask) - ml;
+    }
+}
+
+template <typename type4x4>
+void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) {
+    device const uint8_t * q  = xb->qs;
+    device const uint8_t * qh = xb->qh;
+
+    short is = (il/4) * 2;
+    q  = q + 32 * (il/4) + 16 * (il&1);
+    qh = qh + 16 * (il&1);
+    uint8_t ul = 1 << (il/2);
+    il = il & 3;
+    const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+    const float d = il < 2 ? xb->d : xb->d / 16.f;
+    const float min = xb->dmin;
+    const float dl = d * sc[0];
+    const float ml = min * sc[1];
+
+    const ushort mask  = il<2 ? 0x0F : 0xF0;
+    const float qh_val = il<2 ? 16.f : 256.f;
+    for (int i = 0; i < 16; ++i) {
+        reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
+    }
+}
+
+template <typename type4x4>
+void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
+    const half d_all = xb->d;
+    device const uint8_t * ql = (device const uint8_t *)xb->ql;
+    device const uint8_t * qh = (device const uint8_t *)xb->qh;
+    device const int8_t * scales = (device const int8_t *)xb->scales;
+
+    ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
+    qh = qh + 32*(il/8) + 16*(il&1);
+    float sc = scales[(il%2) + 2 * ((il/2))];
+    il = (il/2) & 3;
+
+    const uint16_t  kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+    const uint16_t  kmask2 = il>1 ? 0xF0              : 0x0F;
+    const float       coef = il>1 ? 1.f/16.f          : 1.f;
+    const float ml = d_all * sc * 32.f;
+    const float dl = d_all * sc * coef;
+    for (int i = 0; i < 16; ++i) {
+        const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
+                            : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
+        reg[i/4][i%4] = dl * q - ml;
+    }
+}
+
+template <typename type4x4>
+void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
+    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+    const float d = xb->d;
+    const int ib32 = il/2;
+    il = il%2;
+    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+    // each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's.
+    device const uint16_t * q2 = xb->qs + 4*ib32;
+    const uint32_t aux32_g = q2[0] | (q2[1] << 16);
+    const uint32_t aux32_s = q2[2] | (q2[3] << 16);
+    thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;
+    const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;
+    constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
+    uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127];
+    for (int i = 0; i < 8; ++i) {
+        reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
+    }
+    grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
+    signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127];
+    for (int i = 0; i < 8; ++i) {
+        reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
+    }
+}
+
+template <typename type4x4>
+void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) {
+    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+    const float d = xb->d;
+    const int ib32 = il/2;
+    il = il%2;
+    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+    device const uint16_t * q2 = xb->qs + 4*ib32;
+    const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
+    constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511));
+    uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9];
+    for (int i = 0; i < 8; ++i) {
+        reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
+    }
+    grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511));
+    signs = ksigns_iq2xs[q2[2*il+1] >> 9];
+    for (int i = 0; i < 8; ++i) {
+        reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
+    }
+}
+
+template <typename type4x4>
+void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) {
+    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+    const float d = xb->d;
+    const int ib32 = il/2;
+    il = il%2;
+    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+    device const uint8_t * q3 = xb->qs + 8*ib32;
+    device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32;
+    const uint32_t aux32 = gas[0] | (gas[1] << 16);
+    const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f;
+    constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+0]);
+    constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+1]);
+    uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127];
+    for (int i = 0; i < 4; ++i) {
+        reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
+        reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
+    }
+    grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+2]);
+    grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+3]);
+    signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127];
+    for (int i = 0; i < 4; ++i) {
+        reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
+        reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
+    }
+}
+
+template <typename type4x4>
+void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) {
+    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+    const float d = xb->d;
+    const int ib32 = il/2;
+    il = il%2;
+    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+    device const uint8_t * qs = xb->qs + 8*ib32;
+    device const uint8_t * signs = xb->signs + 4*ib32 + 2*il;
+    const uint8_t qh = xb->qh[ib32] >> 4*il;
+    const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf));
+    constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256)));
+    constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256)));
+    for (int i = 0; i < 4; ++i) {
+        reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
+        reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
+    }
+    grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256)));
+    grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256)));
+    for (int i = 0; i < 4; ++i) {
+        reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
+        reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
+    }
+}
+
+template <typename type4x4>
+void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & reg) {
+    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+    const float d = xb->d;
+    const int ib32 = il/2;
+    il = il%2;
+    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+    device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
+    device const uint8_t * signs = qs + QK_K/8;
+    const uint8_t qh = xb->qh[ib32] >> 4*il;
+    const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
+    constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[0] | ((qh << 8) & 0x300)));
+    constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[1] | ((qh << 6) & 0x300)));
+    for (int i = 0; i < 8; ++i) {
+        reg[i/4+0][i%4] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i]);
+        reg[i/4+2][i%4] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i]);
+    }
+}
+
+template <typename type4x4>
+void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
+    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+    const int ib32 = il/2;
+    il = il%2;
+    const float d = xb->d;
+    device const uint8_t  * qs = xb->qs + 4*ib32 + 2*il;
+    device const uint16_t * qh = xb->qh;
+    const float dl = d * (2*((qh[ib32] >> 12) & 7) + 1);
+    const float ml = dl * (qh[ib32] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA);
+    const uint16_t h = qh[ib32] >> 6*il;
+    constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((h << 8) & 0x700)));
+    constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((h << 5) & 0x700)));
+    for (int i = 0; i < 4; ++i) {
+        reg[0][i] = dl * (grid1[i] & 0xf) + ml;
+        reg[1][i] = dl * (grid1[i] >>  4) + ml;
+        reg[2][i] = dl * (grid2[i] & 0xf) + ml;
+        reg[3][i] = dl * (grid2[i] >>  4) + ml;
+    }
+}
+
+template <typename type4x4>
+void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) {
+    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+    const int ib32 = il/2;
+    il = il%2;
+    device const uint16_t * sc = (device const uint16_t *)xb->scales;
+
+    iq1m_scale_t scale;
+    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+    const float d = scale.f16;
+
+    device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
+    device const uint8_t * qh = xb->qh + 2*ib32 + il;
+
+    const float dl  = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
+    const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
+    const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
+    constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
+    constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
+    for (int i = 0; i < 4; ++i) {
+        reg[0][i] = dl * (grid1[i] & 0xf) + ml1;
+        reg[1][i] = dl * (grid1[i] >>  4) + ml1;
+        reg[2][i] = dl * (grid2[i] & 0xf) + ml2;
+        reg[3][i] = dl * (grid2[i] >>  4) + ml2;
+    }
+}
+
+template <typename type4x4>
+void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
+    device const uint16_t * q4 = (device const uint16_t *)xb->qs;
+    const float d = xb->d;
+    uint32_t aux32;
+    thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
+    for (int i = 0; i < 4; ++i) {
+        aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f;
+        reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
+        reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
+        reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
+        reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
+    }
+}
+
+template <typename type4x4>
+void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
+    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+    const int ib32 = il/2;
+    il = il%2;
+    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+    device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32;
+    const int ls = ((xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((xb->scales_h >> 2*ib32) & 3) << 4);
+    const float d = (float)xb->d * (ls - 32);
+    uint32_t aux32;
+    thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
+    for (int i = 0; i < 4; ++i) {
+        aux32 = (q4[i] >> 4*il) & 0x0f0f0f0f;
+        reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
+        reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
+        reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
+        reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
+    }
+}
+
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
+kernel void kernel_get_rows(
+        device const  void * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        uint3                tgpig[[threadgroup_position_in_grid]],
+        uint                 tiitg[[thread_index_in_threadgroup]],
+        uint3                tptg [[threads_per_threadgroup]]) {
+    //const int64_t i = tgpig;
+    //const int64_t r = ((device int32_t *) src1)[i];
+
+    const int64_t i10 = tgpig.x;
+    const int64_t i11 = tgpig.y;
+
+    const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+    const int64_t i02 = i11;
+
+    for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
+        float4x4 temp;
+        dequantize_func(
+            ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
+        *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
+    }
+}
+
+kernel void kernel_get_rows_f32(
+        device const  void * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        uint3                tgpig[[threadgroup_position_in_grid]],
+        uint                 tiitg[[thread_index_in_threadgroup]],
+        uint3                tptg [[threads_per_threadgroup]]) {
+    const int64_t i10 = tgpig.x;
+    const int64_t i11 = tgpig.y;
+
+    const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+    const int64_t i02 = i11;
+
+    for (int ind = tiitg; ind < ne00; ind += tptg.x) {
+        ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
+            ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
+    }
+}
+
+kernel void kernel_get_rows_f16(
+        device const  void * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        uint3                tgpig[[threadgroup_position_in_grid]],
+        uint                 tiitg[[thread_index_in_threadgroup]],
+        uint3                tptg [[threads_per_threadgroup]]) {
+    const int64_t i10 = tgpig.x;
+    const int64_t i11 = tgpig.y;
+
+    const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+    const int64_t i02 = i11;
+
+    for (int ind = tiitg; ind < ne00; ind += tptg.x) {
+        ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
+            ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
+    }
+}
+
+kernel void kernel_get_rows_i32(
+        device const  void * src0,
+        device const  char * src1,
+        device     int32_t * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        uint3                tgpig[[threadgroup_position_in_grid]],
+        uint                 tiitg[[thread_index_in_threadgroup]],
+        uint3                tptg [[threads_per_threadgroup]]) {
+    const int64_t i10 = tgpig.x;
+    const int64_t i11 = tgpig.y;
+
+    const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+    const int64_t i02 = i11;
+
+    for (int ind = tiitg; ind < ne00; ind += tptg.x) {
+        ((device int32_t *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
+            ((device int32_t *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
+    }
+}
+
+
+#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
+#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
+#define BLOCK_SIZE_K 32
+#define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
+#define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
+#define THREAD_PER_BLOCK 128
+#define THREAD_PER_ROW 2 // 2 thread for each row in matrix A to load numbers
+#define THREAD_PER_COL 4 // 4 thread for each row in matrix B to load numbers
+#define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8
+#define SG_MAT_ROW 8
+
+// each block_q contains 16*nl weights
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
+void kernel_mul_mm_impl(device const  uchar * src0,
+                        device const  uchar * src1,
+                        device        float * dst,
+                        constant    int64_t & ne00,
+                        constant    int64_t & ne02,
+                        constant   uint64_t & nb01,
+                        constant   uint64_t & nb02,
+                        constant    int64_t & ne12,
+                        constant   uint64_t & nb10,
+                        constant   uint64_t & nb11,
+                        constant   uint64_t & nb12,
+                        constant    int64_t & ne0,
+                        constant    int64_t & ne1,
+                        constant       uint & r2,
+                        constant       uint & r3,
+                        threadgroup   uchar * shared_memory [[threadgroup(0)]],
+                        uint3                 tgpig[[threadgroup_position_in_grid]],
+                        uint                  tiitg[[thread_index_in_threadgroup]],
+                        uint                  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    threadgroup half  * sa = (threadgroup half  *)(shared_memory);
+    threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
+
+    const uint r0 = tgpig.y;
+    const uint r1 = tgpig.x;
+    const uint im = tgpig.z;
+
+    // if this block is of 64x32 shape or smaller
+    short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
+    short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
+
+    // a thread shouldn't load data outside of the matrix
+    short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
+    short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
+
+    simdgroup_half8x8  ma[4];
+    simdgroup_float8x8 mb[2];
+    simdgroup_float8x8 c_res[8];
+    for (int i = 0; i < 8; i++){
+        c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
+    }
+
+    short il = (tiitg % THREAD_PER_ROW);
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    uint   offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
+    ushort offset1 = il/nl;
+
+    device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
+    device const float   * y = (device const float   *)(src1
+        + nb12 * im
+        + nb11 * (r1 * BLOCK_SIZE_N + thread_col)
+        + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
+
+    for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
+        // load data and store to threadgroup memory
+        half4x4 temp_a;
+        dequantize_func(x, il, temp_a);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        #pragma unroll(16)
+        for (int i = 0; i < 16; i++) {
+            *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
+            +                     (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
+            +                     (tiitg / THREAD_PER_ROW) % 8  + (i & 7) * 8) = temp_a[i/4][i%4];
+        }
+
+        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
+
+        il = (il + 2 < nl) ? il + 2 : il % 2;
+        x  = (il < 2) ? x + (2+nl-1)/nl : x;
+        y += BLOCK_SIZE_K;
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // load matrices from threadgroup memory and conduct outer products
+        threadgroup half  * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
+        threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
+
+        #pragma unroll(4)
+        for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
+            #pragma unroll(4)
+            for (int i = 0; i < 4; i++) {
+                simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
+            }
+            simdgroup_barrier(mem_flags::mem_none);
+            #pragma unroll(2)
+            for (int i = 0; i < 2; i++) {
+                simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
+            }
+
+            lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
+            lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
+
+            #pragma unroll(8)
+            for (int i = 0; i < 8; i++){
+                simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
+            }
+        }
+    }
+
+    if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
+        device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg &  1)) \
+                               + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
+        for (int i = 0; i < 8; i++) {
+            simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
+        }
+    } else {
+        // block is smaller than 64x32, we should avoid writing data outside of the matrix
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
+                                      + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
+        for (int i = 0; i < 8; i++) {
+            simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
+        if (sgitg == 0) {
+            for (int i = 0; i < n_rows; i++) {
+                for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
+                    *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
+                }
+            }
+        }
+    }
+}
+
+// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
+void kernel_mul_mm_id_impl(
+        device const  uchar * src0,
+        device const  uchar * src1,
+        threadgroup ushort2 * rowids,
+        device        float * dst,
+        constant    int64_t & ne00,
+        constant    int64_t & ne02,
+        constant   uint64_t & nb01,
+        constant   uint64_t & nb02,
+        constant    int64_t & ne11,
+        constant    int64_t & ne12,
+        constant   uint64_t & nb10,
+        constant   uint64_t & nb11,
+        constant   uint64_t & nb12,
+        constant    int64_t & ne0,
+                    int64_t   ne1,
+                    int64_t   ne0ne1,
+        threadgroup   uchar * shared_memory,
+        uint3                 tgpig[[threadgroup_position_in_grid]],
+        uint                  tiitg[[thread_index_in_threadgroup]],
+        uint                  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    threadgroup half  * sa = (threadgroup half  *)(shared_memory);
+    threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
+
+    const uint r0 = tgpig.y;
+    const uint r1 = tgpig.x;
+
+    if (r1 * BLOCK_SIZE_N >= ne1) return;
+
+    // if this block is of 64x32 shape or smaller
+    short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
+    short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
+
+    // a thread shouldn't load data outside of the matrix
+    short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
+    short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
+
+    simdgroup_half8x8  ma[4];
+    simdgroup_float8x8 mb[2];
+    simdgroup_float8x8 c_res[8];
+    for (int i = 0; i < 8; i++){
+        c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
+    }
+    short il = (tiitg % THREAD_PER_ROW);
+
+    ushort offset1 = il/nl;
+
+    threadgroup const auto & id = rowids[r1 * BLOCK_SIZE_N + thread_col];
+
+    device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01) + offset1;
+    device const float   * y = (device const float   *)(src1
+        + nb12 * id[1]
+        + nb11 * (id[0] % ne11)
+        + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
+
+    for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
+        // load data and store to threadgroup memory
+        half4x4 temp_a;
+        dequantize_func(x, il, temp_a);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        for (int i = 0; i < 16; i++) {
+            *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
+            +                     (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
+            +                     (tiitg / THREAD_PER_ROW) % 8  + (i & 7) * 8) = temp_a[i/4][i%4];
+        }
+
+        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
+
+        il = (il + 2 < nl) ? il + 2 : il % 2;
+        x  = (il < 2) ? x + (2+nl-1)/nl : x;
+        y += BLOCK_SIZE_K;
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // load matrices from threadgroup memory and conduct outer products
+        threadgroup half  * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
+        threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
+
+        for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
+            for (int i = 0; i < 4; i++) {
+                simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
+            }
+            simdgroup_barrier(mem_flags::mem_none);
+            for (int i = 0; i < 2; i++) {
+                simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
+            }
+
+            lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
+            lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
+
+            for (int i = 0; i < 8; i++){
+                simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
+            }
+        }
+    }
+
+    {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
+                                      + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
+        for (int i = 0; i < 8; i++) {
+            simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        device float * C = dst + (BLOCK_SIZE_M * r0);
+        if (sgitg == 0) {
+            for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
+                threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j];
+                int joff =  jid[0] * ne0 + jid[1] * ne0ne1;
+                for (int i = 0; i < n_rows; i++) {
+                    *(C + i + joff) = *(temp_str + i + j * BLOCK_SIZE_M);
+                }
+            }
+        }
+    }
+}
+
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
+kernel void kernel_mul_mm(device const  uchar * src0,
+                          device const  uchar * src1,
+                          device        float * dst,
+                          constant    int64_t & ne00,
+                          constant    int64_t & ne02,
+                          constant   uint64_t & nb01,
+                          constant   uint64_t & nb02,
+                          constant    int64_t & ne12,
+                          constant   uint64_t & nb10,
+                          constant   uint64_t & nb11,
+                          constant   uint64_t & nb12,
+                          constant    int64_t & ne0,
+                          constant    int64_t & ne1,
+                          constant       uint & r2,
+                          constant       uint & r3,
+                          threadgroup   uchar * shared_memory [[threadgroup(0)]],
+                          uint3                 tgpig[[threadgroup_position_in_grid]],
+                          uint                  tiitg[[thread_index_in_threadgroup]],
+                          uint                  sgitg[[simdgroup_index_in_threadgroup]]) {
+    kernel_mul_mm_impl<block_q, nl, dequantize_func>(
+        src0,
+        src1,
+        dst,
+        ne00,
+        ne02,
+        nb01,
+        nb02,
+        ne12,
+        nb10,
+        nb11,
+        nb12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        shared_memory,
+        tgpig,
+        tiitg,
+        sgitg);
+}
+
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
+kernel void kernel_mul_mm_id(
+        device const   uchar * src0s,
+        device const   uchar * src1,
+        device         float * dst,
+        device const   uchar * ids,
+        constant     int64_t & nei0,
+        constant     int64_t & nei1,
+        constant    uint64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant    uint64_t & nb1,
+        threadgroup    uchar * shared_memory [[threadgroup(0)]],
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    const int32_t i02 = tgpig.z;
+    tgpig.z = 0;
+
+    device const uchar * src0 = src0s + i02*nb02;
+
+    // row indices
+    threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
+
+    // TODO: parallelize this loop
+    int64_t _ne1 = 0;
+    for (ushort ii1 = 0; ii1 < nei1; ii1++) {
+        for (ushort ii0 = 0; ii0 < nei0; ii0++) {
+            int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
+            if (id == i02) {
+                //if (tiitg == 0) {
+                    rowids[_ne1] = ushort2(ii0, ii1);
+                //}
+                _ne1++;
+            }
+        }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
+        src0,
+        src1,
+        rowids,
+        dst,
+        ne00,
+        ne02,
+        nb01,
+        nb02,
+        ne11,
+        ne12,
+        nb10,
+        nb11,
+        nb12,
+        ne0,
+        _ne1,
+        ne0*ne1,
+        shared_memory,
+        tgpig,
+        tiitg,
+        sgitg);
+}
+
+#define QK_NL 16
+
+//
+// get rows
+//
+
+typedef void (get_rows_t)(
+        device const void * src0,
+        device const char * src1,
+        device      float * dst,
+        constant  int64_t & ne00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant  int64_t & ne10,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        uint3, uint, uint3);
+
+//template [[host_name("kernel_get_rows_f32")]]  kernel get_rows_t kernel_get_rows<float4x4,   1, dequantize_f32>;
+//template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_t kernel_get_rows<half4x4,    1, dequantize_f16>;
+template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
+template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
+template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows<block_q5_1, 2, dequantize_q5_1>;
+template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
+template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
+template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
+template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
+template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
+template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
+template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
+template [[host_name("kernel_get_rows_iq2_xs")]]  kernel get_rows_t kernel_get_rows<block_iq2_xs,  QK_NL, dequantize_iq2_xs>;
+template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
+template [[host_name("kernel_get_rows_iq3_s")]]   kernel get_rows_t kernel_get_rows<block_iq3_s,   QK_NL, dequantize_iq3_s>;
+template [[host_name("kernel_get_rows_iq2_s")]]   kernel get_rows_t kernel_get_rows<block_iq2_s,   QK_NL, dequantize_iq2_s>;
+template [[host_name("kernel_get_rows_iq1_s")]]   kernel get_rows_t kernel_get_rows<block_iq1_s,   QK_NL, dequantize_iq1_s>;
+template [[host_name("kernel_get_rows_iq1_m")]]   kernel get_rows_t kernel_get_rows<block_iq1_m,   QK_NL, dequantize_iq1_m>;
+template [[host_name("kernel_get_rows_iq4_nl")]]  kernel get_rows_t kernel_get_rows<block_iq4_nl,  2,     dequantize_iq4_nl>;
+template [[host_name("kernel_get_rows_iq4_xs")]]  kernel get_rows_t kernel_get_rows<block_iq4_xs,  QK_NL, dequantize_iq4_xs>;
+
+//
+// matrix-matrix multiplication
+//
+
+typedef decltype(kernel_mul_mm<float4x4, 1, dequantize_f32>) mat_mm_t;
+
+template [[host_name("kernel_mul_mm_f32_f32")]]     kernel mat_mm_t kernel_mul_mm<float4x4,      1,     dequantize_f32>;
+template [[host_name("kernel_mul_mm_f16_f32")]]     kernel mat_mm_t kernel_mul_mm<half4x4,       1,     dequantize_f16>;
+template [[host_name("kernel_mul_mm_q4_0_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q4_0,    2,     dequantize_q4_0>;
+template [[host_name("kernel_mul_mm_q4_1_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q4_1,    2,     dequantize_q4_1>;
+template [[host_name("kernel_mul_mm_q5_0_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q5_0,    2,     dequantize_q5_0>;
+template [[host_name("kernel_mul_mm_q5_1_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q5_1,    2,     dequantize_q5_1>;
+template [[host_name("kernel_mul_mm_q8_0_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q8_0,    2,     dequantize_q8_0>;
+template [[host_name("kernel_mul_mm_q2_K_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q2_K,    QK_NL, dequantize_q2_K>;
+template [[host_name("kernel_mul_mm_q3_K_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q3_K,    QK_NL, dequantize_q3_K>;
+template [[host_name("kernel_mul_mm_q4_K_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q4_K,    QK_NL, dequantize_q4_K>;
+template [[host_name("kernel_mul_mm_q5_K_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q5_K,    QK_NL, dequantize_q5_K>;
+template [[host_name("kernel_mul_mm_q6_K_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q6_K,    QK_NL, dequantize_q6_K>;
+template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
+template [[host_name("kernel_mul_mm_iq2_xs_f32")]]  kernel mat_mm_t kernel_mul_mm<block_iq2_xs,  QK_NL, dequantize_iq2_xs>;
+template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
+template [[host_name("kernel_mul_mm_iq3_s_f32")]]   kernel mat_mm_t kernel_mul_mm<block_iq3_s,   QK_NL, dequantize_iq3_s>;
+template [[host_name("kernel_mul_mm_iq2_s_f32")]]   kernel mat_mm_t kernel_mul_mm<block_iq2_s,   QK_NL, dequantize_iq2_s>;
+template [[host_name("kernel_mul_mm_iq1_s_f32")]]   kernel mat_mm_t kernel_mul_mm<block_iq1_s,   QK_NL, dequantize_iq1_s>;
+template [[host_name("kernel_mul_mm_iq1_m_f32")]]   kernel mat_mm_t kernel_mul_mm<block_iq1_m,   QK_NL, dequantize_iq1_m>;
+template [[host_name("kernel_mul_mm_iq4_nl_f32")]]  kernel mat_mm_t kernel_mul_mm<block_iq4_nl,  2,     dequantize_iq4_nl>;
+template [[host_name("kernel_mul_mm_iq4_xs_f32")]]  kernel mat_mm_t kernel_mul_mm<block_iq4_xs,  QK_NL, dequantize_iq4_xs>;
+
+//
+// indirect matrix-matrix multiplication
+//
+
+typedef decltype(kernel_mul_mm_id<float4x4, 1, dequantize_f32>) mat_mm_id_t;
+
+template [[host_name("kernel_mul_mm_id_f32_f32")]]     kernel mat_mm_id_t kernel_mul_mm_id<float4x4,      1,     dequantize_f32>;
+template [[host_name("kernel_mul_mm_id_f16_f32")]]     kernel mat_mm_id_t kernel_mul_mm_id<half4x4,       1,     dequantize_f16>;
+template [[host_name("kernel_mul_mm_id_q4_0_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0,    2,     dequantize_q4_0>;
+template [[host_name("kernel_mul_mm_id_q4_1_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1,    2,     dequantize_q4_1>;
+template [[host_name("kernel_mul_mm_id_q5_0_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0,    2,     dequantize_q5_0>;
+template [[host_name("kernel_mul_mm_id_q5_1_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1,    2,     dequantize_q5_1>;
+template [[host_name("kernel_mul_mm_id_q8_0_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0,    2,     dequantize_q8_0>;
+template [[host_name("kernel_mul_mm_id_q2_K_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K,    QK_NL, dequantize_q2_K>;
+template [[host_name("kernel_mul_mm_id_q3_K_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K,    QK_NL, dequantize_q3_K>;
+template [[host_name("kernel_mul_mm_id_q4_K_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K,    QK_NL, dequantize_q4_K>;
+template [[host_name("kernel_mul_mm_id_q5_K_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K,    QK_NL, dequantize_q5_K>;
+template [[host_name("kernel_mul_mm_id_q6_K_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K,    QK_NL, dequantize_q6_K>;
+template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
+template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs,  QK_NL, dequantize_iq2_xs>;
+template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
+template [[host_name("kernel_mul_mm_id_iq3_s_f32")]]   kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_s,   QK_NL, dequantize_iq3_s>;
+template [[host_name("kernel_mul_mm_id_iq2_s_f32")]]   kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s,   QK_NL, dequantize_iq2_s>;
+template [[host_name("kernel_mul_mm_id_iq1_s_f32")]]   kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s,   QK_NL, dequantize_iq1_s>;
+template [[host_name("kernel_mul_mm_id_iq1_m_f32")]]   kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_m,   QK_NL, dequantize_iq1_m>;
+template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl,  2,     dequantize_iq4_nl>;
+template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs,  QK_NL, dequantize_iq4_xs>;
+
+//
+// matrix-vector multiplication
+//
+
+typedef void (kernel_mul_mv_impl_t)(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+                   int64_t   ne00,
+                   int64_t   ne01,
+                   int64_t   ne02,
+                  uint64_t   nb00,
+                  uint64_t   nb01,
+                  uint64_t   nb02,
+                   int64_t   ne10,
+                   int64_t   ne11,
+                   int64_t   ne12,
+                  uint64_t   nb10,
+                  uint64_t   nb11,
+                  uint64_t   nb12,
+                   int64_t   ne0,
+                   int64_t   ne1,
+                   uint      r2,
+                   uint      r3,
+                   uint3     tgpig,
+                   uint      tiisg);
+
+typedef void (kernel_mul_mv2_impl_t)(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+                   int64_t   ne00,
+                   int64_t   ne01,
+                   int64_t   ne02,
+                   int64_t   ne10,
+                   int64_t   ne12,
+                   int64_t   ne0,
+                   int64_t   ne1,
+                   uint      r2,
+                   uint      r3,
+        threadgroup int8_t * shared_values,
+                   uint3     tgpig,
+                   uint      tiisg,
+                   uint      sgitg);
+
+template<kernel_mul_mv_impl_t impl_fn>
+void mmv_fn(
+        device const    char * src0,
+        device const    char * src1,
+        device         float * dst,
+                     int64_t   ne00,
+                     int64_t   ne01,
+                     int64_t   ne02,
+                    uint64_t   nb00,
+                    uint64_t   nb01,
+                    uint64_t   nb02,
+                     int64_t   ne10,
+                     int64_t   ne11,
+                     int64_t   ne12,
+                     int64_t   ne13,
+                    uint64_t   nb10,
+                    uint64_t   nb11,
+                    uint64_t   nb12,
+                     int64_t   ne0,
+                     int64_t   ne1,
+                    uint64_t   nb1,
+                        uint   r2,
+                        uint   r3,
+        threadgroup int8_t   * shared_values,
+        uint3                  tgpig,
+        uint                   tiitg,
+        uint                   tiisg,
+        uint                   sgitg) {
+    impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,ne10,ne11,ne12,nb10,nb11,nb12,ne0,ne1,r2,r3,tgpig,tiisg);
+}
+
+template<kernel_mul_mv2_impl_t impl_fn>
+void mmv_fn(
+        device const    char * src0,
+        device const    char * src1,
+        device         float * dst,
+                     int64_t   ne00,
+                     int64_t   ne01,
+                     int64_t   ne02,
+                    uint64_t   nb00,
+                    uint64_t   nb01,
+                    uint64_t   nb02,
+                     int64_t   ne10,
+                     int64_t   ne11,
+                     int64_t   ne12,
+                     int64_t   ne13,
+                    uint64_t   nb10,
+                    uint64_t   nb11,
+                    uint64_t   nb12,
+                     int64_t   ne0,
+                     int64_t   ne1,
+                    uint64_t   nb1,
+                        uint   r2,
+                        uint   r3,
+        threadgroup int8_t   * shared_values,
+        uint3                  tgpig,
+        uint                   tiitg,
+        uint                   tiisg,
+        uint                   sgitg) {
+    impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg);
+}
+
+typedef decltype(mmv_fn<kernel_mul_mv_f32_f32_impl>) mul_mv_impl_fn_t;
+
+template<mul_mv_impl_fn_t impl_fn>
+kernel void kernel_mul_mv_id(
+        device const    char * src0s,
+        device const    char * src1,
+        device         float * dst,
+        device const    char * ids,
+        constant     int64_t & nei0,
+        constant     int64_t & nei1,
+        constant    uint64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant    uint64_t & nb1,
+        threadgroup int8_t   * shared_values [[threadgroup(0)]],
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    const int iid1 = tgpig.z/nei0;
+    const int idx = tgpig.z%nei0;
+
+    tgpig.z = 0;
+
+    const int32_t i02 = ((device const int32_t *) (ids + iid1*nbi1))[idx];
+
+    const int64_t i11 = idx % ne11;
+    const int64_t i12 = iid1;
+
+    const int64_t i1 = idx;
+    const int64_t i2 = i12;
+
+    device const char * src0_cur = src0s + i02*nb02;
+    device const char * src1_cur = src1 + i11*nb11 + i12*nb12;
+    device      float * dst_cur  = dst + i1*ne0 + i2*ne1*ne0;
+
+    impl_fn(
+        /* src0 */ src0_cur,
+        /* src1 */ src1_cur,
+        /* dst  */ dst_cur,
+        /* ne00 */ ne00,
+        /* ne01 */ ne01,
+        /* ne02 */ 1,//ne02,
+        /* nb00 */ nb00,
+        /* nb01 */ nb01,
+        /* nb02 */ nb02,
+        /* ne10 */ ne10,
+        /* ne11 */ 1,//ne11,
+        /* ne12 */ 1,//ne12,
+        /* ne13 */ 1,//ne13,
+        /* nb10 */ nb10,
+        /* nb11 */ nb11,
+        /* nb12 */ nb12,
+        /* ne0  */ ne0,
+        /* ne1  */ 1,//ne1,
+        /* nb1  */ nb1,
+        /* r2   */ 1,
+        /* r3   */ 1,
+        shared_values,
+        tgpig,
+        tiitg,
+        tiisg,
+        sgitg);
+}
+
+typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>) kernel_mul_mv_id_t;
+
+template [[host_name("kernel_mul_mv_id_f32_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_f16_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f16_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq1_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq1_m_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq3_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq2_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
+

+ 26 - 2
llama/llama.go

@@ -4,10 +4,10 @@ package llama
 // #cgo CXXFLAGS: -std=c++11 -DNDEBUG -DLOG_DISABLE_LOGS
 // #cgo darwin,arm64 CFLAGS: -DGGML_USE_METAL -DGGML_USE_ACCELERATE -DGGML_METAL_EMBED_LIBRARY -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
 // #cgo darwin,arm64 CXXFLAGS: -DGGML_USE_METAL -DGGML_USE_ACCELERATE -DGGML_METAL_EMBED_LIBRARY -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
-// #cgo darwin,arm64 LDFLAGS: ${SRCDIR}/ggml-metal.o -framework Foundation -framework Metal -framework MetalKit -framework Accelerate
+// #cgo darwin,arm64 LDFLAGS: -framework Foundation -framework Metal -framework MetalKit -framework Accelerate
 // #cgo darwin,amd64 CFLAGS: -Wno-incompatible-pointer-types-discards-qualifiers
 // #cgo darwin,amd64 CXXFLAGS: -Wno-incompatible-pointer-types-discards-qualifiers
-// #cgo darwin,amd64 LDFLAGS: ${SRCDIR}/ggml-metal.o -framework Foundation
+// #cgo darwin,amd64 LDFLAGS: -framework Foundation
 // #cgo darwin,amd64,avx2 CFLAGS: -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
 // #cgo darwin,amd64,avx2 CXXFLAGS: -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
 // #cgo darwin,amd64,avx2 LDFLAGS: -framework Accelerate
@@ -35,8 +35,11 @@ package llama
 // #include "sampling_ext.h"
 //
 // bool llamaProgressCallback(float progress, void *user_data);
+// const char* ggml_metallib_start;
+// const char* ggml_metallib_end;
 import "C"
 import (
+	_ "embed"
 	"errors"
 	"fmt"
 	"runtime"
@@ -45,6 +48,21 @@ import (
 	"unsafe"
 )
 
+//go:embed ggml-common.h
+var ggmlCommon string
+
+//go:embed ggml-metal.metal
+var ggmlMetal string
+
+// TODO: write me somewhere else
+func init() {
+	metal := strings.ReplaceAll(ggmlMetal, `#include "ggml-common.h"`, ggmlCommon)
+	fmt.Println(metal)
+	cMetal := C.CString(metal)
+	C.ggml_metallib_start = cMetal
+	C.ggml_metallib_end = (*C.char)(unsafe.Pointer(uintptr(unsafe.Pointer(cMetal)) + uintptr(len(metal))))
+}
+
 func BackendInit() {
 	C.llama_backend_init()
 }
@@ -81,6 +99,12 @@ func llamaProgressCallback(progress C.float, userData unsafe.Pointer) C.bool {
 }
 
 func NewModelParams(numGpuLayers int, mainGpu int, callback func(float32)) ModelParams {
+	fmt.Println("Contents of ggml-common.h:")
+	fmt.Println(ggmlCommon)
+
+	fmt.Println("\nContents of ggml-metal.in.metal:")
+	fmt.Println(ggmlMetal)
+
 	params := C.llama_model_default_params()
 	params.n_gpu_layers = C.int(numGpuLayers)
 	params.main_gpu = C.int32_t(mainGpu)

+ 1 - 0
llama/llama_test.go

@@ -0,0 +1 @@
+package llama

+ 1 - 15
llama/sync_llama.sh

@@ -28,7 +28,7 @@ cp $src_dir/ggml.c $dst_dir/ggml.c
 cp $src_dir/ggml.h $dst_dir/ggml.h
 cp $src_dir/ggml-quants.c $dst_dir/ggml-quants.c
 cp $src_dir/ggml-quants.h $dst_dir/ggml-quants.h
-cp $src_dir/ggml-metal.metal $dst_dir/ggml-metal.metal
+cp $src_dir/ggml-metal.metal $dst_dir/ggml-metal.in.metal
 cp $src_dir/ggml-metal.h $dst_dir/ggml-metal.h
 cp $src_dir/ggml-metal.m $dst_dir/ggml-metal-darwin_arm64.m
 cp $src_dir/ggml-impl.h $dst_dir/ggml-impl.h
@@ -105,17 +105,3 @@ for IN in $dst_dir/*.{c,h,cpp,m,metal,cu}; do
     cat $TEMP_LICENSE $IN >$TMP
     mv $TMP $IN
 done
-
-# ggml-metal
-# TODO: just embed the files
-sed -e '/#include "ggml-common.h"/r ggml-common.h' -e '/#include "ggml-common.h"/d' < $dst_dir/ggml-metal.metal > temp.metal
-TEMP_ASSEMBLY=$(mktemp)
-echo ".section __DATA, __ggml_metallib"   >  $TEMP_ASSEMBLY
-echo ".globl _ggml_metallib_start"        >> $TEMP_ASSEMBLY
-echo "_ggml_metallib_start:"              >> $TEMP_ASSEMBLY
-echo ".incbin \"temp.metal\"" >> $TEMP_ASSEMBLY
-echo ".globl _ggml_metallib_end"          >> $TEMP_ASSEMBLY
-echo "_ggml_metallib_end:"                >> $TEMP_ASSEMBLY
-as $TEMP_ASSEMBLY -o $dst_dir/ggml-metal.o
-rm -f $TEMP_ASSEMBLY
-rm -rf temp.metal