
llama: update vendored code to commit 40c6d79f (#7875)

Jeffrey Morgan, 4 months ago
parent commit 527cc97899
100 files changed, 32,879 additions and 5,958 deletions
  1.   +18    -0     .github/workflows/test.yaml
  2.   +0     -2     api/types.go
  3.   +0     -1     docs/api.md
  4.   +222   -0     llama/amx.cpp
  5.   +46    -0     llama/amx.h
  6.   +1     -1     llama/build-info.cpp
  7.   +22    -5     llama/clip.cpp
  8.   +1     -1     llama/clip.h
  9.   +218   -313   llama/common.cpp
  10.  +179   -112   llama/common.h
  11.  +47    -3127  llama/ggml-aarch64.c
  12.  +1     -21    llama/ggml-aarch64.h
  13.  +24    -28    llama/ggml-alloc.c
  14.  +2     -2     llama/ggml-alloc.h
  15.  +186   -84    llama/ggml-backend-impl.h
  16.  +545   -0     llama/ggml-backend-reg.cpp
  17.  +155   -607   llama/ggml-backend.cpp
  18.  +182   -72    llama/ggml-backend.h
  19.  +221   -72    llama/ggml-blas.cpp
  20.  +6     -4     llama/ggml-blas.h
  21.  +7     -1     llama/ggml-common.h
  22.  +64    -0     llama/ggml-cpp.h
  23.  +3849  -0     llama/ggml-cpu-aarch64.c
  24.  +58    -0     llama/ggml-cpu-aarch64.h
  25.  +15    -243   llama/ggml-cpu-impl.h
  26.  +10861 -0     llama/ggml-cpu-quants.c
  27.  +89    -0     llama/ggml-cpu-quants.h
  28.  +14054 -0     llama/ggml-cpu.c
  29.  +741   -0     llama/ggml-cpu.cpp
  30.  +178   -0     llama/ggml-cpu.h
  31.  +17    -19    llama/ggml-cuda.h
  32.  +1     -1     llama/ggml-cuda/acc.cu
  33.  +1     -1     llama/ggml-cuda/acc.cuh
  34.  +1     -1     llama/ggml-cuda/arange.cu
  35.  +1     -1     llama/ggml-cuda/arange.cuh
  36.  +117   -0     llama/ggml-cuda/argmax.cu
  37.  +29    -0     llama/ggml-cuda/argmax.cuh
  38.  +1     -1     llama/ggml-cuda/argsort.cu
  39.  +1     -1     llama/ggml-cuda/argsort.cuh
  40.  +1     -1     llama/ggml-cuda/binbcast.cu
  41.  +1     -1     llama/ggml-cuda/binbcast.cuh
  42.  +1     -1     llama/ggml-cuda/clamp.cu
  43.  +1     -1     llama/ggml-cuda/clamp.cuh
  44.  +63    -40    llama/ggml-cuda/common.cuh
  45.  +1     -1     llama/ggml-cuda/concat.cu
  46.  +1     -1     llama/ggml-cuda/concat.cuh
  47.  +1     -1     llama/ggml-cuda/conv-transpose-1d.cu
  48.  +1     -1     llama/ggml-cuda/conv-transpose-1d.cuh
  49.  +1     -1     llama/ggml-cuda/convert.cu
  50.  +1     -1     llama/ggml-cuda/convert.cuh
  51.  +90    -0     llama/ggml-cuda/count-equal.cu
  52.  +31    -0     llama/ggml-cuda/count-equal.cuh
  53.  +1     -1     llama/ggml-cuda/cpy.cu
  54.  +2     -2     llama/ggml-cuda/cpy.cuh
  55.  +1     -1     llama/ggml-cuda/cross-entropy-loss.cu
  56.  +1     -1     llama/ggml-cuda/cross-entropy-loss.cuh
  57.  +1     -1     llama/ggml-cuda/dequantize.cuh
  58.  +1     -1     llama/ggml-cuda/diagmask.cu
  59.  +1     -1     llama/ggml-cuda/diagmask.cuh
  60.  +0     -709   llama/ggml-cuda/dmmv.cu
  61.  +3     -3     llama/ggml-cuda/fattn-common.cuh
  62.  +4     -4     llama/ggml-cuda/fattn-tile-f16.cu
  63.  +1     -1     llama/ggml-cuda/fattn-tile-f16.cuh
  64.  +3     -3     llama/ggml-cuda/fattn-tile-f32.cu
  65.  +1     -1     llama/ggml-cuda/fattn-tile-f32.cuh
  66.  +6     -7     llama/ggml-cuda/fattn-vec-f16.cuh
  67.  +3     -4     llama/ggml-cuda/fattn-vec-f32.cuh
  68.  +3     -3     llama/ggml-cuda/fattn-wmma-f16.cuh
  69.  +6     -6     llama/ggml-cuda/fattn.cu
  70.  +1     -1     llama/ggml-cuda/fattn.cuh
  71.  +1     -1     llama/ggml-cuda/getrows.cu
  72.  +1     -1     llama/ggml-cuda/getrows.cuh
  73.  +144   -333   llama/ggml-cuda/ggml-cuda.cu
  74.  +4     -4     llama/ggml-cuda/im2col.cu
  75.  +1     -1     llama/ggml-cuda/im2col.cuh
  76.  +1     -1     llama/ggml-cuda/mma.cuh
  77.  +3     -5     llama/ggml-cuda/mmq.cu
  78.  +14    -14    llama/ggml-cuda/mmq.cuh
  79.  +249   -0     llama/ggml-cuda/mmv.cu
  80.  +5     -13    llama/ggml-cuda/mmv.cuh
  81.  +6     -6     llama/ggml-cuda/mmvq.cu
  82.  +1     -1     llama/ggml-cuda/mmvq.cuh
  83.  +1     -1     llama/ggml-cuda/norm.cu
  84.  +1     -1     llama/ggml-cuda/norm.cuh
  85.  +34    -36    llama/ggml-cuda/opt-step-adamw.cu
  86.  +1     -1     llama/ggml-cuda/opt-step-adamw.cuh
  87.  +1     -1     llama/ggml-cuda/out-prod.cu
  88.  +1     -1     llama/ggml-cuda/out-prod.cuh
  89.  +1     -1     llama/ggml-cuda/pad.cu
  90.  +1     -1     llama/ggml-cuda/pad.cuh
  91.  +1     -1     llama/ggml-cuda/pool2d.cu
  92.  +1     -1     llama/ggml-cuda/pool2d.cuh
  93.  +5     -5     llama/ggml-cuda/quantize.cu
  94.  +1     -1     llama/ggml-cuda/quantize.cuh
  95.  +1     -1     llama/ggml-cuda/rope.cu
  96.  +1     -1     llama/ggml-cuda/rope.cuh
  97.  +1     -1     llama/ggml-cuda/scale.cu
  98.  +1     -1     llama/ggml-cuda/scale.cuh
  99.  +1     -1     llama/ggml-cuda/softmax.cu
  100. +1     -1     llama/ggml-cuda/softmax.cuh

+ 18 - 0
.github/workflows/test.yaml

@@ -269,6 +269,15 @@ jobs:
       - uses: actions/checkout@v4
         with:
           submodules: recursive
+      - name: Add msys paths
+        if: ${{ startsWith(matrix.os, 'windows-') }}
+        run: |
+          echo "c:\msys64\usr\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
+          echo "C:\msys64\clang64\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
+      - name: Install msys2 tools
+        if: ${{ startsWith(matrix.os, 'windows-') }}
+        run: |
+          Start-Process "c:\msys64\usr\bin\pacman.exe" -ArgumentList @("-S", "--noconfirm", "mingw-w64-clang-x86_64-gcc-compat", "mingw-w64-clang-x86_64-clang") -NoNewWindow -Wait
       - uses: actions/setup-go@v5
         with:
           go-version-file: go.mod
@@ -300,6 +309,15 @@ jobs:
       - uses: actions/checkout@v4
         with:
           submodules: recursive
+      - name: Add msys paths
+        if: ${{ startsWith(matrix.os, 'windows-') }}
+        run: |
+          echo "c:\msys64\usr\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
+          echo "C:\msys64\clang64\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
+      - name: Install msys2 tools
+        if: ${{ startsWith(matrix.os, 'windows-') }}
+        run: |
+          Start-Process "c:\msys64\usr\bin\pacman.exe" -ArgumentList @("-S", "--noconfirm", "mingw-w64-clang-x86_64-gcc-compat", "mingw-w64-clang-x86_64-clang") -NoNewWindow -Wait
       - uses: actions/setup-go@v5
         with:
           go-version-file: go.mod

+ 0 - 2
api/types.go

@@ -216,7 +216,6 @@ type Options struct {
 	TopK             int      `json:"top_k,omitempty"`
 	TopP             float32  `json:"top_p,omitempty"`
 	MinP             float32  `json:"min_p,omitempty"`
-	TFSZ             float32  `json:"tfs_z,omitempty"`
 	TypicalP         float32  `json:"typical_p,omitempty"`
 	RepeatLastN      int      `json:"repeat_last_n,omitempty"`
 	Temperature      float32  `json:"temperature,omitempty"`
@@ -595,7 +594,6 @@ func DefaultOptions() Options {
 		Temperature:      0.8,
 		TopK:             40,
 		TopP:             0.9,
-		TFSZ:             1.0,
 		TypicalP:         1.0,
 		RepeatLastN:      64,
 		RepeatPenalty:    1.1,
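
For reference, the dropped TFSZ field mapped to llama.cpp's tail-free sampling (tfs_z) parameter, which was removed upstream in the vendored commit range; the other samplers are untouched. Below is a minimal sketch of a sampler chain against the vendored llama.cpp C API, assuming a live llama_context and reusing the DefaultOptions values shown above; it is illustrative, not code from this diff.

#include "llama.h"

// Hedged sketch, not part of this diff: build a sampler chain without the
// removed tail-free step. Values mirror DefaultOptions() above.
static llama_token sample_next(llama_context * ctx) {
    llama_sampler * smpl = llama_sampler_chain_init(llama_sampler_chain_default_params());
    llama_sampler_chain_add(smpl, llama_sampler_init_top_k(40));          // TopK
    llama_sampler_chain_add(smpl, llama_sampler_init_top_p(0.9f, 1));     // TopP
    llama_sampler_chain_add(smpl, llama_sampler_init_typical(1.0f, 1));   // TypicalP
    llama_sampler_chain_add(smpl, llama_sampler_init_temp(0.8f));         // Temperature
    llama_sampler_chain_add(smpl, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
    const llama_token token = llama_sampler_sample(smpl, ctx, -1);
    llama_sampler_free(smpl);
    return token;
}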

+ 0 - 1
docs/api.md

@@ -387,7 +387,6 @@ curl http://localhost:11434/api/generate -d '{
     "top_k": 20,
     "top_p": 0.9,
     "min_p": 0.0,
-    "tfs_z": 0.5,
     "typical_p": 0.7,
     "repeat_last_n": 33,
     "temperature": 0.8,

+ 222 - 0
llama/amx.cpp

@@ -0,0 +1,222 @@
+/**
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "amx.h"
+#include "common.h"
+#include "mmq.h"
+#include "ggml-backend-impl.h"
+#include "ggml-backend.h"
+#include "ggml-impl.h"
+#include "ggml-cpu.h"
+
+#if defined(__gnu_linux__)
+#include <sys/syscall.h>
+#include <unistd.h>
+#endif
+
+#include <cstdlib>
+#include <cstring>
+#include <memory>
+
+#if defined(__AMX_INT8__) && defined(__AVX512VNNI__)
+
+// AMX buffer interface
+static void ggml_backend_amx_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+    free(buffer->context);
+}
+
+static void * ggml_backend_amx_buffer_get_base(ggml_backend_buffer_t buffer) {
+    return (void *)(buffer->context);
+}
+
+static void ggml_backend_amx_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+    memset((char *)tensor->data + offset, value, size);
+
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_amx_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    if (qtype_has_amx_kernels(tensor->type)) {
+        ggml_backend_amx_convert_weight(tensor, data, offset, size);
+    } else {
+        memcpy((char *)tensor->data + offset, data, size);
+    }
+
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_amx_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    GGML_ASSERT(!qtype_has_amx_kernels(tensor->type));
+    memcpy(data, (const char *)tensor->data + offset, size);
+
+    GGML_UNUSED(buffer);
+}
+
+static bool ggml_backend_amx_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
+    if (ggml_backend_buffer_is_host(src->buffer)) {
+        if (qtype_has_amx_kernels(src->type)) {
+            ggml_backend_amx_convert_weight(dst, src->data, 0, ggml_nbytes(dst));
+        } else {
+            memcpy(dst->data, src->data, ggml_nbytes(src));
+        }
+        return true;
+    }
+    return false;
+
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_amx_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+    memset(buffer->context, value, buffer->size);
+}
+
+static ggml_backend_buffer_i ggml_backend_amx_buffer_interface = {
+    /* .free_buffer     = */ ggml_backend_amx_buffer_free_buffer,
+    /* .get_base        = */ ggml_backend_amx_buffer_get_base,
+    /* .init_tensor     = */ NULL, // no initialization required
+    /* .memset_tensor   = */ ggml_backend_amx_buffer_memset_tensor,
+    /* .set_tensor      = */ ggml_backend_amx_buffer_set_tensor,
+    /* .get_tensor      = */ ggml_backend_amx_buffer_get_tensor,
+    /* .cpy_tensor      = */ ggml_backend_amx_buffer_cpy_tensor,
+    /* .clear           = */ ggml_backend_amx_buffer_clear,
+    /* .reset           = */ NULL,
+};
+
+static const char * ggml_backend_amx_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
+    return "AMX";
+
+    GGML_UNUSED(buft);
+}
+
+static ggml_backend_buffer_t ggml_backend_amx_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+    void * data = aligned_alloc(TENSOR_ALIGNMENT, size);
+    if (data == NULL) {
+        fprintf(stderr, "%s: failed to allocate buffer of size %zu\n", __func__, size);
+        return NULL;
+    }
+
+    return ggml_backend_buffer_init(buft, ggml_backend_amx_buffer_interface, data, size);
+}
+
+static size_t ggml_backend_amx_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+    return TENSOR_ALIGNMENT;
+
+    GGML_UNUSED(buft);
+}
+
+static size_t ggml_backend_amx_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor* tensor) {
+    return ggml_backend_amx_get_alloc_size(tensor);
+
+    GGML_UNUSED(buft);
+}
+
+static bool ggml_backend_amx_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
+    return false;
+
+    GGML_UNUSED(buft);
+}
+
+#define ARCH_GET_XCOMP_PERM     0x1022
+#define ARCH_REQ_XCOMP_PERM     0x1023
+#define XFEATURE_XTILECFG       17
+#define XFEATURE_XTILEDATA      18
+
+static bool ggml_amx_init() {
+#if defined(__gnu_linux__)
+    if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)) {
+        fprintf(stderr, "AMX is not ready to be used!\n");
+        return false;
+    }
+    return true;
+#elif defined(_WIN32)
+    return true;
+#endif
+}
+ggml_backend_buffer_type_t ggml_backend_amx_buffer_type() {
+    static struct ggml_backend_buffer_type ggml_backend_buffer_type_amx = {
+        /* .iface = */ {
+            /* .get_name         = */ ggml_backend_amx_buffer_type_get_name,
+            /* .alloc_buffer     = */ ggml_backend_amx_buffer_type_alloc_buffer,
+            /* .get_alignment    = */ ggml_backend_amx_buffer_type_get_alignment,
+            /* .get_max_size     = */ NULL, // defaults to SIZE_MAX
+            /* .get_alloc_size   = */ ggml_backend_amx_buffer_type_get_alloc_size,
+            /* .is_host          = */ ggml_backend_amx_buffer_type_is_host,
+        },
+        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
+        /* .context = */ NULL,
+    };
+
+    if (!ggml_amx_init()) {
+        return NULL;
+    }
+
+    return &ggml_backend_buffer_type_amx;
+}
+
+bool ggml_backend_amx_buft_is_amx(ggml_backend_buffer_type_t buft) {
+    return buft->iface.get_name == ggml_backend_amx_buffer_type_get_name;
+}
+
+bool ggml_backend_amx_device_supports_op(const struct ggml_tensor * op) {
+    // handle only 2d gemm for now
+    auto is_contiguous_2d = [](const struct ggml_tensor * t) {
+        return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1;
+    };
+
+    switch (op->op) {
+        case GGML_OP_NONE:
+        case GGML_OP_RESHAPE:
+        case GGML_OP_VIEW:
+        case GGML_OP_PERMUTE:
+        case GGML_OP_TRANSPOSE:
+            return true;
+
+        case GGML_OP_MUL_MAT: {
+            const struct ggml_tensor * src0 = op->src[0];
+            const struct ggml_tensor * src1 = op->src[1];
+
+            const enum ggml_type type = src0->type;
+            const int64_t ne0 = op->ne[0];
+
+            // amx kernels enables for Q4_0, Q4_1, Q8_0, F16
+            // Q4_K, Q5_K, Q6_K, IQ4_XS enabled for QK_K = 256
+            bool has_amx_kernels = qtype_has_amx_kernels(type) || (type == GGML_TYPE_F16);
+
+            bool can_use_amx =
+                is_contiguous_2d(src0) &&       // src0 must be contiguous
+                is_contiguous_2d(src1) &&       // src1 must be contiguous
+                src1->type == GGML_TYPE_F32 &&  // src1 must be float32
+                has_amx_kernels &&              // with amx kernel impls
+                ne0 % (TILE_N * 2) == 0;        // out_features is 32x
+
+            return can_use_amx;
+        }
+        default:
+            return false;
+    }
+}
+
+#endif // defined(__AMX_INT8__) && defined(__AVX512VNNI__)
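
As a usage note: ggml_backend_amx_buffer_type() above returns NULL when ggml_amx_init() cannot acquire AMX tile-data permission, so callers should probe it before allocating. A hedged caller-side sketch follows (illustrative only, not part of the vendored code), assuming the standard ggml-backend allocation API.

#include "amx.h"
#include "ggml-backend.h"

// Illustrative: prefer the AMX buffer type when available, otherwise fall
// back to the regular CPU buffer type.
static ggml_backend_buffer_t alloc_weight_buffer(size_t nbytes) {
    ggml_backend_buffer_type_t buft = NULL;
#if defined(__AMX_INT8__) && defined(__AVX512VNNI__)
    buft = ggml_backend_amx_buffer_type(); // NULL if AMX init failed
#endif
    if (buft == NULL) {
        buft = ggml_backend_cpu_buffer_type();
    }
    return ggml_backend_buft_alloc_buffer(buft, nbytes);
}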

+ 46 - 0
llama/amx.h

@@ -0,0 +1,46 @@
+/**
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "ggml-backend.h"
+#include "ggml-cpu-impl.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#if defined(__AMX_INT8__) && defined(__AVX512VNNI__)
+
+ggml_backend_buffer_type_t ggml_backend_amx_buffer_type(void);
+bool ggml_backend_amx_buft_is_amx(ggml_backend_buffer_type_t buft);
+bool ggml_backend_amx_device_supports_op(const struct ggml_tensor * op);
+void ggml_backend_amx_mul_mat(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst);
+
+#endif
+
+#ifdef __cplusplus
+}
+#endif
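
Since these declarations only exist when the translation unit is compiled with AMX-INT8 and AVX512-VNNI enabled, call sites guard on the same macros. A hypothetical caller is sketched below; it is an assumption about usage, not code taken from the diff.

#include "amx.h"

// Hypothetical: route a matmul node through the AMX kernels only when the
// build enables them and the op shape qualifies.
static bool try_amx_mul_mat(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
#if defined(__AMX_INT8__) && defined(__AVX512VNNI__)
    if (ggml_backend_amx_device_supports_op(dst)) {
        ggml_backend_amx_mul_mat(params, dst);
        return true;
    }
#endif
    (void) params;
    (void) dst;
    return false;
}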

+ 1 - 1
llama/build-info.cpp

@@ -1,4 +1,4 @@
 int LLAMA_BUILD_NUMBER = 0;
-char const *LLAMA_COMMIT = "3f1ae2e32cde00c39b96be6d01c2997c29bae555";
+char const *LLAMA_COMMIT = "40c6d79fb52f995f47507fedfeaae2ac05d9b35c";
 char const *LLAMA_COMPILER = "";
 char const *LLAMA_BUILD_TARGET = "";

+ 22 - 5
llama/clip.cpp

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *
@@ -30,6 +30,7 @@
 // Note: Even when using identical normalized image inputs (see normalize_image_u8_to_f32()) we have a significant difference in resulting embeddings compared to pytorch
 #include "clip.h"
 #include "ggml.h"
+#include "ggml-cpu.h"
 #include "ggml-alloc.h"
 #include "ggml-backend.h"
 
@@ -37,6 +38,10 @@
 #include "ggml-cuda.h"
 #endif
 
+#ifdef GGML_USE_SYCL
+#include "ggml-sycl.h"
+#endif
+
 #ifdef GGML_USE_METAL
 #include "ggml-metal.h"
 #endif
@@ -65,10 +70,17 @@
 #include <cinttypes>
 #include <limits>
 
-#define LOG_INF(...) do { fprintf(stdout, __VA_ARGS__); } while (0)
-#define LOG_WRN(...) do { fprintf(stderr, __VA_ARGS__); } while (0)
-#define LOG_ERR(...) do { fprintf(stderr, __VA_ARGS__); } while (0)
-#define LOG_DBG(...) do { fprintf(stderr, __VA_ARGS__); } while (0)
+#if defined(LLAVA_LOG_OFF)
+#   define LOG_INF(...)
+#   define LOG_WRN(...)
+#   define LOG_ERR(...)
+#   define LOG_DBG(...)
+#else // defined(LLAVA_LOG_OFF)
+#   define LOG_INF(...) do { fprintf(stdout, __VA_ARGS__); } while (0)
+#   define LOG_WRN(...) do { fprintf(stderr, __VA_ARGS__); } while (0)
+#   define LOG_ERR(...) do { fprintf(stderr, __VA_ARGS__); } while (0)
+#   define LOG_DBG(...) do { fprintf(stdout, __VA_ARGS__); } while (0)
+#endif // defined(LLAVA_LOG_OFF)
 
 #if defined(_WIN32)
 #define WIN32_LEAN_AND_MEAN
@@ -1200,6 +1212,11 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
     LOG_INF("%s: CLIP using Vulkan backend\n", __func__);
 #endif
 
+#ifdef GGML_USE_SYCL
+    new_clip->backend = ggml_backend_sycl_init(0);
+    LOG_INF("%s: CLIP using SYCL backend\n", __func__);
+#endif
+
     if (!new_clip->backend) {
         new_clip->backend = ggml_backend_cpu_init();
         LOG_INF("%s: CLIP using CPU backend\n", __func__);
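
The clip.cpp changes do three things: update the vendor header to the new commit, make the LOG_* macros compile to no-ops when LLAVA_LOG_OFF is defined, and add SYCL to the backend probe in clip_model_load(). The probe is first-initialized-wins with a CPU fallback; a condensed sketch of the pattern, reduced to one accelerator backend for illustration and assuming at most one GGML_USE_* backend is compiled in:

#include "ggml-backend.h"
#include "ggml-cpu.h"
#ifdef GGML_USE_SYCL
#include "ggml-sycl.h"
#endif

// Sketch of the selection pattern used by clip_model_load(): each
// compiled-in backend tries to initialize, CPU is the final fallback.
static ggml_backend_t clip_pick_backend(void) {
    ggml_backend_t backend = NULL;
#ifdef GGML_USE_SYCL
    backend = ggml_backend_sycl_init(0); // device 0, as in the hunk above
#endif
    if (backend == NULL) {
        backend = ggml_backend_cpu_init(); // final fallback
    }
    return backend;
}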

+ 1 - 1
llama/clip.h

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

+ 218 - 313
llama/common.cpp

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *
@@ -38,6 +38,7 @@
 
 #include <algorithm>
 #include <cinttypes>
+#include <climits>
 #include <cmath>
 #include <codecvt>
 #include <cstdarg>
@@ -49,10 +50,10 @@
 #include <regex>
 #include <sstream>
 #include <string>
+#include <thread>
 #include <unordered_map>
 #include <unordered_set>
 #include <vector>
-#include <thread>
 
 #if defined(__APPLE__) && defined(__MACH__)
 #include <sys/types.h>
@@ -388,10 +389,10 @@ bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREAD
     return true;
 }
 
-void gpt_init() {
+void common_init() {
     llama_log_set([](ggml_log_level level, const char * text, void * /*user_data*/) {
-        if (LOG_DEFAULT_LLAMA <= gpt_log_verbosity_thold) {
-            gpt_log_add(gpt_log_main(), level, "%s", text);
+        if (LOG_DEFAULT_LLAMA <= common_log_verbosity_thold) {
+            common_log_add(common_log_main(), level, "%s", text);
         }
     }, NULL);
 
@@ -404,7 +405,7 @@ void gpt_init() {
     LOG_INF("build: %d (%s) with %s for %s%s\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT, LLAMA_COMPILER, LLAMA_BUILD_TARGET, build_type);
 }
 
-std::string gpt_params_get_system_info(const gpt_params & params) {
+std::string common_params_get_system_info(const common_params & params) {
     std::ostringstream os;
 
     os << "system_info: n_threads = " << params.cpuparams.n_threads;
@@ -426,17 +427,19 @@ std::string gpt_params_get_system_info(const gpt_params & params) {
 // String utils
 //
 
-std::vector<std::string> string_split(std::string input, char separator) {
-    std::vector<std::string> parts;
-    size_t separator_pos = input.find(separator);
-    while (separator_pos != std::string::npos) {
-        std::string part = input.substr(0, separator_pos);
-        parts.emplace_back(part);
-        input = input.substr(separator_pos + 1);
-        separator_pos = input.find(separator);
-    }
-    parts.emplace_back(input);
-    return parts;
+std::string string_format(const char * fmt, ...) {
+    va_list ap;
+    va_list ap2;
+    va_start(ap, fmt);
+    va_copy(ap2, ap);
+    int size = vsnprintf(NULL, 0, fmt, ap);
+    GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
+    std::vector<char> buf(size + 1);
+    int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
+    GGML_ASSERT(size2 == size);
+    va_end(ap2);
+    va_end(ap);
+    return std::string(buf.data(), size);
 }
 
 std::string string_strip(const std::string & str) {
@@ -519,7 +522,7 @@ std::string string_from(const struct llama_context * ctx, const std::vector<llam
             first = false;
         }
 
-        auto detokenized = llama_token_to_piece(ctx, token);
+        auto detokenized = common_token_to_piece(ctx, token);
 
         detokenized.erase(
             std::remove_if(
@@ -550,7 +553,7 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat
             first = false;
         }
 
-        auto detokenized = llama_token_to_piece(ctx, batch.token[i]);
+        auto detokenized = common_token_to_piece(ctx, batch.token[i]);
 
         detokenized.erase(
                 std::remove_if(
@@ -559,12 +562,12 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat
                     [](const unsigned char c) { return !std::isprint(c); }),
                 detokenized.end());
 
-        buf << "\n" << std::to_string(i)
-            << ":token '" << detokenized << "'"
-            << ":pos " << std::to_string(batch.pos[i])
-            << ":n_seq_id  " << std::to_string(batch.n_seq_id[i])
-            << ":seq_id " << std::to_string(batch.seq_id[i][0])
-            << ":logits " << std::to_string(batch.logits[i]);
+        buf << "\n"          << std::to_string(i)
+            << ", token '"   << detokenized << "'"
+            << ", pos "      << std::to_string(batch.pos[i])
+            << ", n_seq_id " << std::to_string(batch.n_seq_id[i])
+            << ", seq_id "   << std::to_string(batch.seq_id[i][0])
+            << ", logits "   << std::to_string(batch.logits[i]);
     }
 
     buf << " ]";
@@ -675,7 +678,17 @@ bool fs_validate_filename(const std::string & filename) {
 
     std::u32string filename_utf32;
     try {
+#if defined(__clang__)
+        // disable C++17 deprecation warning for std::codecvt_utf8
+#    pragma clang diagnostic push
+#    pragma clang diagnostic ignored "-Wdeprecated-declarations"
+#endif
         std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
+
+#if defined(__clang__)
+#    pragma clang diagnostic pop
+#endif
+
         filename_utf32 = converter.from_bytes(filename);
 
         // If the reverse conversion mismatches, it means overlong UTF-8 sequences were used,
@@ -845,16 +858,16 @@ std::string fs_get_cache_file(const std::string & filename) {
 //
 // Model utils
 //
-struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
-    llama_init_result iparams;
-    auto mparams = llama_model_params_from_gpt_params(params);
+struct common_init_result common_init_from_params(common_params & params) {
+    common_init_result iparams;
+    auto mparams = common_model_params_to_llama(params);
 
     llama_model * model = nullptr;
 
     if (!params.hf_repo.empty() && !params.hf_file.empty()) {
-        model = llama_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), params.hf_token.c_str(), mparams);
+        model = common_load_model_from_hf(params.hf_repo, params.hf_file, params.model, params.hf_token, mparams);
     } else if (!params.model_url.empty()) {
-        model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), params.hf_token.c_str(), mparams);
+        model = common_load_model_from_url(params.model_url, params.model, params.hf_token, mparams);
     } else {
         model = llama_load_model_from_file(params.model.c_str(), mparams);
     }
@@ -864,7 +877,32 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
         return iparams;
     }
 
-    auto cparams = llama_context_params_from_gpt_params(params);
+    if (params.reranking) {
+        bool ok = true;
+
+        if (llama_token_bos(model) == LLAMA_TOKEN_NULL) {
+            LOG_WRN("%s: warning: model does not have a  BOS token, reranking will not work\n", __func__);
+            ok = false;
+        }
+
+        if (llama_token_eos(model) == LLAMA_TOKEN_NULL) {
+            LOG_WRN("%s: warning: model does not have an EOS token, reranking will not work\n", __func__);
+            ok = false;
+        }
+
+        if (llama_token_sep(model) == LLAMA_TOKEN_NULL) {
+            LOG_WRN("%s: warning: model does not have a  SEP token, reranking will not work\n", __func__);
+            ok = false;
+        }
+
+        if (!ok) {
+            llama_free_model(model);
+
+            return iparams;
+        }
+    }
+
+    auto cparams = common_context_params_to_llama(params);
 
     llama_context * lctx = llama_new_context_with_model(model, cparams);
     if (lctx == NULL) {
@@ -873,14 +911,21 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
         return iparams;
     }
 
+    if (params.ctx_shift && !llama_kv_cache_can_shift(lctx)) {
+        LOG_ERR("%s: KV cache shifting is not supported for this model (--no-context-shift to disable)'\n", __func__);
+        llama_free_model(model);
+        return iparams;
+    }
+
     if (!params.control_vectors.empty()) {
         if (params.control_vector_layer_start <= 0) params.control_vector_layer_start = 1;
         if (params.control_vector_layer_end   <= 0) params.control_vector_layer_end   = llama_n_layer(model);
 
-        const auto cvec = llama_control_vector_load(params.control_vectors);
+        const auto cvec = common_control_vector_load(params.control_vectors);
         if (cvec.n_embd == -1) {
             llama_free(lctx);
             llama_free_model(model);
+
             return iparams;
         }
 
@@ -893,13 +938,14 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
         if (err) {
             llama_free(lctx);
             llama_free_model(model);
+
             return iparams;
         }
     }
 
     // load and optionally apply lora adapters
     for (auto & la : params.lora_adapters) {
-        llama_lora_adapter_container loaded_la;
+        common_lora_adapter_container loaded_la;
         loaded_la.path = la.path;
         loaded_la.scale = la.scale;
         loaded_la.adapter = llama_lora_adapter_init(model, la.path.c_str());
@@ -912,12 +958,12 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
         iparams.lora_adapters.push_back(loaded_la); // copy to list of loaded adapters
     }
     if (!params.lora_init_without_apply) {
-        llama_lora_adapters_apply(lctx, iparams.lora_adapters);
+        common_lora_adapters_apply(lctx, iparams.lora_adapters);
     }
 
-    if (params.sparams.ignore_eos && llama_token_eos(model) == -1) {
+    if (params.sampling.ignore_eos && llama_token_eos(model) == LLAMA_TOKEN_NULL) {
         LOG_WRN("%s: warning: model does not have an EOS token, ignoring --ignore-eos\n", __func__);
-        params.sparams.ignore_eos = false;
+        params.sampling.ignore_eos = false;
     }
 
     if (params.warmup) {
@@ -938,7 +984,7 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
         }
 
         if (llama_model_has_encoder(model)) {
-            llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0));
+            llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size()));
             llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
             if (decoder_start_token_id == -1) {
                 decoder_start_token_id = bos;
@@ -947,7 +993,7 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
             tmp.push_back(decoder_start_token_id);
         }
         if (llama_model_has_decoder(model)) {
-            llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
+            llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
         }
         llama_kv_cache_clear(lctx);
         llama_synchronize(lctx);
@@ -956,10 +1002,11 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
 
     iparams.model   = model;
     iparams.context = lctx;
+
     return iparams;
 }
 
-void llama_lora_adapters_apply(struct llama_context * ctx, std::vector<llama_lora_adapter_container> & lora_adapters) {
+void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_container> & lora_adapters) {
     llama_lora_adapter_clear(ctx);
     for (auto & la : lora_adapters) {
         if (la.scale != 0.0f) {
@@ -968,9 +1015,12 @@ void llama_lora_adapters_apply(struct llama_context * ctx, std::vector<llama_lor
     }
 }
 
-struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params) {
+struct llama_model_params common_model_params_to_llama(common_params & params) {
     auto mparams = llama_model_default_params();
 
+    if (!params.devices.empty()) {
+        mparams.devices = params.devices.data();
+    }
     if (params.n_gpu_layers != -1) {
         mparams.n_gpu_layers = params.n_gpu_layers;
     }
@@ -998,6 +1048,9 @@ static ggml_type kv_cache_type_from_str(const std::string & s) {
     if (s == "f16") {
         return GGML_TYPE_F16;
     }
+    if (s == "bf16") {
+        return GGML_TYPE_BF16;
+    }
     if (s == "q8_0") {
         return GGML_TYPE_Q8_0;
     }
@@ -1017,10 +1070,10 @@ static ggml_type kv_cache_type_from_str(const std::string & s) {
         return GGML_TYPE_Q5_1;
     }
 
-    throw std::runtime_error("Invalid cache type: " + s);
+    throw std::runtime_error("Unsupported cache type: " + s);
 }
 
-struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
+struct llama_context_params common_context_params_to_llama(const common_params & params) {
     auto cparams = llama_context_default_params();
 
     cparams.n_ctx             = params.n_ctx;
@@ -1029,7 +1082,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.n_ubatch          = params.n_ubatch;
     cparams.n_threads         = params.cpuparams.n_threads;
     cparams.n_threads_batch   = params.cpuparams_batch.n_threads == -1 ?
-                                    params.cpuparams.n_threads : params.cpuparams_batch.n_threads;
+                                params.cpuparams.n_threads : params.cpuparams_batch.n_threads;
     cparams.logits_all        = params.logits_all;
     cparams.embeddings        = params.embedding;
     cparams.rope_scaling_type = params.rope_scaling_type;
@@ -1110,7 +1163,7 @@ static bool curl_perform_with_retry(const std::string& url, CURL* curl, int max_
     return false;
 }
 
-static bool llama_download_file(const std::string & url, const std::string & path, const std::string & hf_token) {
+static bool common_download_file(const std::string & url, const std::string & path, const std::string & hf_token) {
 
     // Initialize libcurl
     std::unique_ptr<CURL, decltype(&curl_easy_cleanup)> curl(curl_easy_init(), &curl_easy_cleanup);
@@ -1180,15 +1233,15 @@ static bool llama_download_file(const std::string & url, const std::string & pat
     }
 
     // Send a HEAD request to retrieve the etag and last-modified headers
-    struct llama_load_model_from_url_headers {
+    struct common_load_model_from_url_headers {
         std::string etag;
         std::string last_modified;
     };
-    llama_load_model_from_url_headers headers;
+    common_load_model_from_url_headers headers;
     {
         typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *);
         auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t {
-            llama_load_model_from_url_headers *headers = (llama_load_model_from_url_headers *) userdata;
+            common_load_model_from_url_headers *headers = (common_load_model_from_url_headers *) userdata;
 
             static std::regex header_regex("([^:]+): (.*)\r\n");
             static std::regex etag_regex("ETag", std::regex_constants::icase);
@@ -1324,18 +1377,18 @@ static bool llama_download_file(const std::string & url, const std::string & pat
     return true;
 }
 
-struct llama_model * llama_load_model_from_url(
-        const char * model_url,
-        const char * path_model,
-        const char * hf_token,
+struct llama_model * common_load_model_from_url(
+        const std::string & model_url,
+        const std::string & local_path,
+        const std::string & hf_token,
         const struct llama_model_params & params) {
     // Basic validation of the model_url
-    if (!model_url || strlen(model_url) == 0) {
+    if (model_url.empty()) {
         LOG_ERR("%s: invalid model_url\n", __func__);
         return NULL;
     }
 
-    if (!llama_download_file(model_url, path_model, hf_token)) {
+    if (!common_download_file(model_url, local_path, hf_token)) {
         return NULL;
     }
 
@@ -1346,9 +1399,9 @@ struct llama_model * llama_load_model_from_url(
             /*.no_alloc = */ true,
             /*.ctx      = */ NULL,
         };
-        auto * ctx_gguf = gguf_init_from_file(path_model, gguf_params);
+        auto * ctx_gguf = gguf_init_from_file(local_path.c_str(), gguf_params);
         if (!ctx_gguf) {
-            LOG_ERR("\n%s:  failed to load input GGUF from %s\n", __func__, path_model);
+            LOG_ERR("\n%s:  failed to load input GGUF from %s\n", __func__, local_path.c_str());
             return NULL;
         }
 
@@ -1367,13 +1420,13 @@ struct llama_model * llama_load_model_from_url(
         // Verify the first split file format
         // and extract split URL and PATH prefixes
         {
-            if (!llama_split_prefix(split_prefix, sizeof(split_prefix), path_model, 0, n_split)) {
-                LOG_ERR("\n%s: unexpected model file name: %s n_split=%d\n", __func__, path_model, n_split);
+            if (!llama_split_prefix(split_prefix, sizeof(split_prefix), local_path.c_str(), 0, n_split)) {
+                LOG_ERR("\n%s: unexpected model file name: %s n_split=%d\n", __func__, local_path.c_str(), n_split);
                 return NULL;
             }
 
-            if (!llama_split_prefix(split_url_prefix, sizeof(split_url_prefix), model_url, 0, n_split)) {
-                LOG_ERR("\n%s: unexpected model url: %s n_split=%d\n", __func__, model_url, n_split);
+            if (!llama_split_prefix(split_url_prefix, sizeof(split_url_prefix), model_url.c_str(), 0, n_split)) {
+                LOG_ERR("\n%s: unexpected model url: %s n_split=%d\n", __func__, model_url.c_str(), n_split);
                 return NULL;
             }
         }
@@ -1388,7 +1441,7 @@ struct llama_model * llama_load_model_from_url(
                 char split_url[LLAMA_CURL_MAX_URL_LENGTH] = {0};
                 llama_split_path(split_url, sizeof(split_url), split_url_prefix, download_idx, n_split);
 
-                return llama_download_file(split_url, split_path, hf_token);
+                return common_download_file(split_url, split_path, hf_token);
             }, idx));
         }
 
@@ -1400,14 +1453,14 @@ struct llama_model * llama_load_model_from_url(
         }
     }
 
-    return llama_load_model_from_file(path_model, params);
+    return llama_load_model_from_file(local_path.c_str(), params);
 }
 
-struct llama_model * llama_load_model_from_hf(
-        const char * repo,
-        const char * model,
-        const char * path_model,
-        const char * hf_token,
+struct llama_model * common_load_model_from_hf(
+        const std::string & repo,
+        const std::string & remote_path,
+        const std::string & local_path,
+        const std::string & hf_token,
         const struct llama_model_params & params) {
     // construct hugging face model url:
     //
@@ -1421,27 +1474,27 @@ struct llama_model * llama_load_model_from_hf(
     std::string model_url = "https://huggingface.co/";
     model_url += repo;
     model_url += "/resolve/main/";
-    model_url += model;
+    model_url += remote_path;
 
-    return llama_load_model_from_url(model_url.c_str(), path_model, hf_token, params);
+    return common_load_model_from_url(model_url, local_path, hf_token, params);
 }
 
 #else
 
-struct llama_model * llama_load_model_from_url(
-        const char * /*model_url*/,
-        const char * /*path_model*/,
-        const char * /*hf_token*/,
+struct llama_model * common_load_model_from_url(
+        const std::string & /*model_url*/,
+        const std::string & /*local_path*/,
+        const std::string & /*hf_token*/,
         const struct llama_model_params & /*params*/) {
     LOG_WRN("%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
     return nullptr;
 }
 
-struct llama_model * llama_load_model_from_hf(
-        const char * /*repo*/,
-        const char * /*model*/,
-        const char * /*path_model*/,
-        const char * /*hf_token*/,
+struct llama_model * common_load_model_from_hf(
+        const std::string & /*repo*/,
+        const std::string & /*remote_path*/,
+        const std::string & /*local_path*/,
+        const std::string & /*hf_token*/,
         const struct llama_model_params & /*params*/) {
     LOG_WRN("%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__);
     return nullptr;
@@ -1453,11 +1506,11 @@ struct llama_model * llama_load_model_from_hf(
 // Batch utils
 //
 
-void llama_batch_clear(struct llama_batch & batch) {
+void common_batch_clear(struct llama_batch & batch) {
     batch.n_tokens = 0;
 }
 
-void llama_batch_add(
+void common_batch_add(
                  struct llama_batch & batch,
                         llama_token   id,
                           llama_pos   pos,
@@ -1476,19 +1529,79 @@ void llama_batch_add(
     batch.n_tokens++;
 }
 
+//
+// Token utils
+//
+
+size_t common_lcp(const llama_tokens & a, const llama_tokens & b) {
+    size_t i;
+    for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
+
+    return i;
+}
+
+size_t common_lcs(const llama_tokens & a, const llama_tokens & b) {
+    // check for empty sequences
+    if (a.empty() || b.empty()) {
+        return 0;
+    }
+
+    // get the lengths of the input sequences
+    size_t a_len = a.size();
+    size_t b_len = b.size();
+
+    // initialize the maximum length of the longest common subsequence (LCS)
+    size_t max_length = 0;
+
+    // use two rows instead of a 2D matrix to optimize space
+    std::vector<size_t> prev_row(b_len + 1, 0);
+    std::vector<size_t> curr_row(b_len + 1, 0);
+
+    // iterate through the elements of a
+    for (size_t i = 1; i <= a_len; i++) {
+        // iterate through the elements of b
+        for (size_t j = 1; j <= b_len; j++) {
+            // if elements at the current positions match
+            if (a[i - 1] == b[j - 1]) {
+                // if it's the first element of either sequences, set LCS length to 1
+                if (i == 1 || j == 1) {
+                    curr_row[j] = 1;
+                } else {
+                    // increment LCS length by 1 compared to the previous element
+                    curr_row[j] = prev_row[j - 1] + 1;
+                }
+
+                // update max_length if necessary
+                if (curr_row[j] > max_length) {
+                    max_length = curr_row[j];
+                }
+            } else {
+                // reset LCS length if elements don't match
+                curr_row[j] = 0;
+            }
+        }
+
+        // update the previous row for the next iteration
+        prev_row = curr_row;
+    }
+
+    // return the maximum length of the LCS
+    return max_length;
+}
+
 //
 // Vocab utils
 //
 
-std::vector<llama_token> llama_tokenize(
+std::vector<llama_token> common_tokenize(
   const struct llama_context * ctx,
            const std::string & text,
                         bool   add_special,
                         bool   parse_special) {
-    return llama_tokenize(llama_get_model(ctx), text, add_special, parse_special);
+    return common_tokenize(llama_get_model(ctx), text, add_special, parse_special);
 }
 
-std::vector<llama_token> llama_tokenize(
+std::vector<llama_token> common_tokenize(
     const struct llama_model * model,
            const std::string & text,
                         bool   add_special,
@@ -1507,7 +1620,7 @@ std::vector<llama_token> llama_tokenize(
     return result;
 }
 
-std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
+std::string common_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
     std::string piece;
     piece.resize(piece.capacity());  // using string internal cache, 15 bytes + '\n'
     const int n_chars = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special);
@@ -1523,7 +1636,7 @@ std::string llama_token_to_piece(const struct llama_context * ctx, llama_token t
     return piece;
 }
 
-std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token> & tokens, bool special) {
+std::string common_detokenize(llama_context * ctx, const std::vector<llama_token> & tokens, bool special) {
     std::string text;
     text.resize(std::max(text.capacity(), tokens.size()));
     int32_t n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
@@ -1543,15 +1656,15 @@ std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token>
 // Chat template utils
 //
 
-bool llama_chat_verify_template(const std::string & tmpl) {
+bool common_chat_verify_template(const std::string & tmpl) {
     llama_chat_message chat[] = {{"user", "test"}};
     int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0);
     return res >= 0;
 }
 
-std::string llama_chat_apply_template(const struct llama_model * model,
+std::string common_chat_apply_template(const struct llama_model * model,
         const std::string & tmpl,
-        const std::vector<llama_chat_msg> & msgs,
+        const std::vector<common_chat_msg> & msgs,
         bool add_ass) {
     int alloc_size = 0;
     bool fallback = false; // indicate if we must fallback to default chatml
@@ -1593,42 +1706,42 @@ std::string llama_chat_apply_template(const struct llama_model * model,
     return formatted_chat;
 }
 
-std::string llama_chat_format_single(const struct llama_model * model,
+std::string common_chat_format_single(const struct llama_model * model,
         const std::string & tmpl,
-        const std::vector<llama_chat_msg> & past_msg,
-        const llama_chat_msg & new_msg,
+        const std::vector<common_chat_msg> & past_msg,
+        const common_chat_msg & new_msg,
         bool add_ass) {
     std::ostringstream ss;
-    auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false);
-    std::vector<llama_chat_msg> chat_new(past_msg);
+    auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(model, tmpl, past_msg, false);
+    std::vector<common_chat_msg> chat_new(past_msg);
     // if the past_msg ends with a newline, we must preserve it in the formatted version
     if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
         ss << "\n";
     };
     // format chat with new_msg
     chat_new.push_back(new_msg);
-    auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass);
+    auto fmt_new_msg = common_chat_apply_template(model, tmpl, chat_new, add_ass);
     // get the diff part
     ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
     return ss.str();
 }
 
-std::string llama_chat_format_example(const struct llama_model * model,
+std::string common_chat_format_example(const struct llama_model * model,
         const std::string & tmpl) {
-    std::vector<llama_chat_msg> msgs = {
+    std::vector<common_chat_msg> msgs = {
         {"system",    "You are a helpful assistant"},
         {"user",      "Hello"},
         {"assistant", "Hi there"},
         {"user",      "How are you?"},
     };
-    return llama_chat_apply_template(model, tmpl, msgs, true);
+    return common_chat_apply_template(model, tmpl, msgs, true);
 }
 
 //
 // KV cache utils
 //
 
-void llama_kv_cache_dump_view(const llama_kv_cache_view & view, int row_size) {
+void common_kv_cache_dump_view(const llama_kv_cache_view & view, int row_size) {
     static const char slot_chars[] = ".123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz+";
 
     printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d",
@@ -1651,7 +1764,7 @@ void llama_kv_cache_dump_view(const llama_kv_cache_view & view, int row_size) {
     printf("\n=== Done dumping\n");
 }
 
-void llama_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_size) {
+void common_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_size) {
     static const char slot_chars[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
 
     printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d\n",
@@ -1703,7 +1816,7 @@ void llama_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_siz
 // Embedding utils
 //
 
-void llama_embd_normalize(const float * inp, float * out, int n, int embd_norm) {
+void common_embd_normalize(const float * inp, float * out, int n, int embd_norm) {
     double sum = 0.0;
 
     switch (embd_norm) {
@@ -1737,7 +1850,7 @@ void llama_embd_normalize(const float * inp, float * out, int n, int embd_norm)
     }
 }
 
-float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n){
+float common_embd_similarity_cos(const float * embd1, const float * embd2, int n){
     double sum  = 0.0;
     double sum1 = 0.0;
     double sum2 = 0.0;
@@ -1763,8 +1876,8 @@ float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n)
 // Control vector utils
 //
 
-static llama_control_vector_data llama_control_vector_load_one(const llama_control_vector_load_info & load_info) {
-    llama_control_vector_data result = { -1, {} };
+static common_control_vector_data common_control_vector_load_one(const common_control_vector_load_info & load_info) {
+    common_control_vector_data result = { -1, {} };
 
     ggml_context * ctx = nullptr;
     struct gguf_init_params meta_gguf_params = {
@@ -1848,11 +1961,11 @@ static llama_control_vector_data llama_control_vector_load_one(const llama_contr
     return result;
 }
 
-llama_control_vector_data llama_control_vector_load(const std::vector<llama_control_vector_load_info> & load_infos) {
-    llama_control_vector_data result = { -1, {} };
+common_control_vector_data common_control_vector_load(const std::vector<common_control_vector_load_info> & load_infos) {
+    common_control_vector_data result = { -1, {} };
 
     for (const auto & info : load_infos) {
-        auto cur = llama_control_vector_load_one(info);
+        auto cur = common_control_vector_load_one(info);
 
         if (cur.n_embd == -1) {
             result.n_embd = -1;
@@ -1882,211 +1995,3 @@ llama_control_vector_data llama_control_vector_load(const std::vector<llama_cont
     return result;
 }
 
-//
-// YAML utils
-//
-
-void yaml_dump_vector_float(FILE * stream, const char * prop_name, const std::vector<float> & data) {
-    if (data.empty()) {
-        fprintf(stream, "%s:\n", prop_name);
-        return;
-    }
-
-    fprintf(stream, "%s: [", prop_name);
-    for (size_t i = 0; i < data.size() - 1; ++i) {
-        fprintf(stream, "%e, ", data[i]);
-    }
-    fprintf(stream, "%e]\n", data.back());
-}
-
-void yaml_dump_vector_int(FILE * stream, const char * prop_name, const std::vector<int> & data) {
-    if (data.empty()) {
-        fprintf(stream, "%s:\n", prop_name);
-        return;
-    }
-
-    fprintf(stream, "%s: [", prop_name);
-    for (size_t i = 0; i < data.size() - 1; ++i) {
-        fprintf(stream, "%d, ", data[i]);
-    }
-    fprintf(stream, "%d]\n", data.back());
-}
-
-void yaml_dump_string_multiline(FILE * stream, const char * prop_name, const char * data) {
-    std::string data_str(data == NULL ? "" : data);
-
-    if (data_str.empty()) {
-        fprintf(stream, "%s:\n", prop_name);
-        return;
-    }
-
-    size_t pos_start = 0;
-    size_t pos_found = 0;
-
-    if (std::isspace(data_str[0]) || std::isspace(data_str.back())) {
-        data_str = std::regex_replace(data_str, std::regex("\n"), "\\n");
-        data_str = std::regex_replace(data_str, std::regex("\""), "\\\"");
-        data_str = std::regex_replace(data_str, std::regex(R"(\\[^n"])"), R"(\$&)");
-        data_str = "\"" + data_str + "\"";
-        fprintf(stream, "%s: %s\n", prop_name, data_str.c_str());
-        return;
-    }
-
-    if (data_str.find('\n') == std::string::npos) {
-        fprintf(stream, "%s: %s\n", prop_name, data_str.c_str());
-        return;
-    }
-
-    fprintf(stream, "%s: |\n", prop_name);
-    while ((pos_found = data_str.find('\n', pos_start)) != std::string::npos) {
-        fprintf(stream, "  %s\n", data_str.substr(pos_start, pos_found-pos_start).c_str());
-        pos_start = pos_found + 1;
-    }
-}
-
-void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const llama_context * lctx,
-                               const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc) {
-    const auto & sparams = params.sparams;
-
-    fprintf(stream, "build_commit: %s\n",        LLAMA_COMMIT);
-    fprintf(stream, "build_number: %d\n",        LLAMA_BUILD_NUMBER);
-    fprintf(stream, "cpu_has_arm_fma: %s\n",     ggml_cpu_has_arm_fma()     ? "true" : "false");
-    fprintf(stream, "cpu_has_avx: %s\n",         ggml_cpu_has_avx()         ? "true" : "false");
-    fprintf(stream, "cpu_has_avx_vnni: %s\n",    ggml_cpu_has_avx_vnni()    ? "true" : "false");
-    fprintf(stream, "cpu_has_avx2: %s\n",        ggml_cpu_has_avx2()        ? "true" : "false");
-    fprintf(stream, "cpu_has_avx512: %s\n",      ggml_cpu_has_avx512()      ? "true" : "false");
-    fprintf(stream, "cpu_has_avx512_vbmi: %s\n", ggml_cpu_has_avx512_vbmi() ? "true" : "false");
-    fprintf(stream, "cpu_has_avx512_vnni: %s\n", ggml_cpu_has_avx512_vnni() ? "true" : "false");
-    fprintf(stream, "cpu_has_cuda: %s\n",        ggml_cpu_has_cuda()        ? "true" : "false");
-    fprintf(stream, "cpu_has_vulkan: %s\n",      ggml_cpu_has_vulkan()      ? "true" : "false");
-    fprintf(stream, "cpu_has_kompute: %s\n",     ggml_cpu_has_kompute()     ? "true" : "false");
-    fprintf(stream, "cpu_has_fma: %s\n",         ggml_cpu_has_fma()         ? "true" : "false");
-    fprintf(stream, "cpu_has_gpublas: %s\n",     ggml_cpu_has_gpublas()     ? "true" : "false");
-    fprintf(stream, "cpu_has_neon: %s\n",        ggml_cpu_has_neon()        ? "true" : "false");
-    fprintf(stream, "cpu_has_sve: %s\n",         ggml_cpu_has_sve()         ? "true" : "false");
-    fprintf(stream, "cpu_has_f16c: %s\n",        ggml_cpu_has_f16c()        ? "true" : "false");
-    fprintf(stream, "cpu_has_fp16_va: %s\n",     ggml_cpu_has_fp16_va()     ? "true" : "false");
-    fprintf(stream, "cpu_has_riscv_v: %s\n",     ggml_cpu_has_riscv_v()     ? "true" : "false");
-    fprintf(stream, "cpu_has_wasm_simd: %s\n",   ggml_cpu_has_wasm_simd()   ? "true" : "false");
-    fprintf(stream, "cpu_has_blas: %s\n",        ggml_cpu_has_blas()        ? "true" : "false");
-    fprintf(stream, "cpu_has_sse3: %s\n",        ggml_cpu_has_sse3()        ? "true" : "false");
-    fprintf(stream, "cpu_has_vsx: %s\n",         ggml_cpu_has_vsx()         ? "true" : "false");
-    fprintf(stream, "cpu_has_matmul_int8: %s\n", ggml_cpu_has_matmul_int8() ? "true" : "false");
-
-#ifdef NDEBUG
-    fprintf(stream, "debug: false\n");
-#else
-    fprintf(stream, "debug: true\n");
-#endif // NDEBUG
-
-    fprintf(stream, "model_desc: %s\n", model_desc);
-    fprintf(stream, "n_vocab: %d  # output size of the final layer, 32001 for some models\n", llama_n_vocab(llama_get_model(lctx)));
-
-#ifdef __OPTIMIZE__
-    fprintf(stream, "optimize: true\n");
-#else
-    fprintf(stream, "optimize: false\n");
-#endif // __OPTIMIZE__
-
-    fprintf(stream, "time: %s\n", timestamp.c_str());
-
-    fprintf(stream, "\n");
-    fprintf(stream, "###############\n");
-    fprintf(stream, "# User Inputs #\n");
-    fprintf(stream, "###############\n");
-    fprintf(stream, "\n");
-
-    fprintf(stream, "alias: %s # default: unknown\n", params.model_alias.c_str());
-    fprintf(stream, "batch_size: %d # default: 512\n", params.n_batch);
-    fprintf(stream, "chunks: %d # default: -1 (unlimited)\n", params.n_chunks);
-    fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false");
-    fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
-    fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false");
-    fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n");
-    fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.penalty_freq);
-    yaml_dump_string_multiline(stream, "grammar", sparams.grammar.c_str());
-    fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n");
-    fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false");
-    fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks);
-    fprintf(stream, "ignore_eos: %s # default: false\n", sparams.ignore_eos ? "true" : "false");
-
-    yaml_dump_string_multiline(stream, "in_prefix", params.input_prefix.c_str());
-    fprintf(stream, "in_prefix_bos: %s # default: false\n", params.input_prefix_bos ? "true" : "false");
-    yaml_dump_string_multiline(stream, "in_suffix", params.input_prefix.c_str());
-    fprintf(stream, "interactive: %s # default: false\n", params.interactive ? "true" : "false");
-    fprintf(stream, "interactive_first: %s # default: false\n", params.interactive_first ? "true" : "false");
-    fprintf(stream, "keep: %d # default: 0\n", params.n_keep);
-    fprintf(stream, "logdir: %s # default: unset (no logging)\n", params.logdir.c_str());
-
-    fprintf(stream, "logit_bias:\n");
-    for (const auto & logit_bias : sparams.logit_bias) {
-        fprintf(stream, "  %d: %f", logit_bias.token, logit_bias.bias);
-    }
-
-    fprintf(stream, "lora:\n");
-    for (auto & la : params.lora_adapters) {
-        if (la.scale == 1.0f) {
-            fprintf(stream, "  - %s\n", la.path.c_str());
-        }
-    }
-    fprintf(stream, "lora_scaled:\n");
-    for (auto & la : params.lora_adapters) {
-        if (la.scale != 1.0f) {
-            fprintf(stream, "  - %s: %f\n", la.path.c_str(), la.scale);
-        }
-    }
-    fprintf(stream, "lora_init_without_apply: %s # default: false\n", params.lora_init_without_apply ? "true" : "false");
-    fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu);
-    fprintf(stream, "min_keep: %d # default: 0 (disabled)\n", sparams.min_keep);
-    fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat);
-    fprintf(stream, "mirostat_ent: %f # default: 5.0\n", sparams.mirostat_tau);
-    fprintf(stream, "mirostat_lr: %f # default: 0.1\n", sparams.mirostat_eta);
-    fprintf(stream, "mlock: %s # default: false\n", params.use_mlock ? "true" : "false");
-    fprintf(stream, "model: %s # default: %s\n", params.model.c_str(), DEFAULT_MODEL_PATH);
-    fprintf(stream, "model_draft: %s # default:\n", params.model_draft.c_str());
-    fprintf(stream, "multiline_input: %s # default: false\n", params.multiline_input ? "true" : "false");
-    fprintf(stream, "n_gpu_layers: %d # default: -1\n", params.n_gpu_layers);
-    fprintf(stream, "n_predict: %d # default: -1 (unlimited)\n", params.n_predict);
-    fprintf(stream, "n_probs: %d # only used by server binary, default: 0\n", sparams.n_probs);
-    fprintf(stream, "no_mmap: %s # default: false\n", !params.use_mmap ? "true" : "false");
-    fprintf(stream, "penalize_nl: %s # default: false\n", sparams.penalize_nl ? "true" : "false");
-    fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type);
-    fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride);
-    fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.penalty_present);
-    yaml_dump_string_multiline(stream, "prompt", params.prompt.c_str());
-    fprintf(stream, "prompt_cache: %s\n", params.path_prompt_cache.c_str());
-    fprintf(stream, "prompt_cache_all: %s # default: false\n", params.prompt_cache_all ? "true" : "false");
-    fprintf(stream, "prompt_cache_ro: %s # default: false\n", params.prompt_cache_ro ? "true" : "false");
-    yaml_dump_vector_int(stream, "prompt_tokens", prompt_tokens);
-    fprintf(stream, "repeat_penalty: %f # default: 1.1\n", sparams.penalty_repeat);
-
-    fprintf(stream, "reverse_prompt:\n");
-    for (std::string ap : params.antiprompt) {
-        size_t pos = 0;
-        while ((pos = ap.find('\n', pos)) != std::string::npos) {
-            ap.replace(pos, 1, "\\n");
-            pos += 1;
-        }
-
-        fprintf(stream, "  - %s\n", ap.c_str());
-    }
-
-    fprintf(stream, "rope_freq_base: %f # default: 10000.0\n", params.rope_freq_base);
-    fprintf(stream, "rope_freq_scale: %f # default: 1.0\n", params.rope_freq_scale);
-    fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false");
-    fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false");
-    fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false");
-    fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);
-
-    const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices());
-    yaml_dump_vector_float(stream, "tensor_split", tensor_split_vector);
-
-    fprintf(stream, "tfs: %f # default: 1.0\n", sparams.tfs_z);
-    fprintf(stream, "threads: %d # default: %u\n", params.cpuparams.n_threads, std::thread::hardware_concurrency());
-    fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k);
-    fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p);
-    fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p);
-    fprintf(stream, "typ_p: %f # default: 1.0\n", sparams.typ_p);
-    fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
-    fprintf(stream, "display_prompt: %s # default: true\n", params.display_prompt ? "true" : "false");
-}

+ 179 - 112
llama/common.h

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *
@@ -50,22 +50,24 @@
 
 #define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"
 
-struct llama_lora_adapter_info {
+struct common_lora_adapter_info {
     std::string path;
     float scale;
 };
 
-struct llama_lora_adapter_container : llama_lora_adapter_info {
+struct common_lora_adapter_container : common_lora_adapter_info {
     struct llama_lora_adapter * adapter;
 };
 
+using llama_tokens = std::vector<llama_token>;
+
 // build info
 extern int LLAMA_BUILD_NUMBER;
 extern char const * LLAMA_COMMIT;
 extern char const * LLAMA_COMPILER;
 extern char const * LLAMA_BUILD_TARGET;
 
-struct llama_control_vector_load_info;
+struct common_control_vector_load_info;
 
 //
 // CPU utils
@@ -108,14 +110,17 @@ enum llama_example {
     LLAMA_EXAMPLE_COUNT,
 };
 
-enum gpt_sampler_type {
-    GPT_SAMPLER_TYPE_NONE        = 0,
-    GPT_SAMPLER_TYPE_TOP_K       = 1,
-    GPT_SAMPLER_TYPE_TOP_P       = 2,
-    GPT_SAMPLER_TYPE_MIN_P       = 3,
-    GPT_SAMPLER_TYPE_TFS_Z       = 4,
-    GPT_SAMPLER_TYPE_TYPICAL_P   = 5,
-    GPT_SAMPLER_TYPE_TEMPERATURE = 6,
+enum common_sampler_type {
+    COMMON_SAMPLER_TYPE_NONE        = 0,
+    COMMON_SAMPLER_TYPE_DRY         = 1,
+    COMMON_SAMPLER_TYPE_TOP_K       = 2,
+    COMMON_SAMPLER_TYPE_TOP_P       = 3,
+    COMMON_SAMPLER_TYPE_MIN_P       = 4,
+  //COMMON_SAMPLER_TYPE_TFS_Z       = 5,
+    COMMON_SAMPLER_TYPE_TYPICAL_P   = 6,
+    COMMON_SAMPLER_TYPE_TEMPERATURE = 7,
+    COMMON_SAMPLER_TYPE_XTC         = 8,
+    COMMON_SAMPLER_TYPE_INFILL      = 9,
 };
 
 // dimensionality reduction methods, used by cvector-generator
@@ -124,39 +129,49 @@ enum dimre_method {
     DIMRE_METHOD_MEAN,
 };
 
-// sampler parameters
-struct gpt_sampler_params {
+// sampling parameters
+struct common_params_sampling {
     uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
 
-    int32_t n_prev            = 64;    // number of previous tokens to remember
-    int32_t n_probs           = 0;     // if greater than 0, output the probabilities of top n_probs tokens.
-    int32_t min_keep          = 0;     // 0 = disabled, otherwise samplers should return at least min_keep tokens
-    int32_t top_k             = 40;    // <= 0 to use vocab size
-    float   top_p             = 0.95f; // 1.0 = disabled
-    float   min_p             = 0.05f; // 0.0 = disabled
-    float   tfs_z             = 1.00f; // 1.0 = disabled
-    float   typ_p             = 1.00f; // typical_p, 1.0 = disabled
-    float   temp              = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
-    float   dynatemp_range    = 0.00f; // 0.0 = disabled
-    float   dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
-    int32_t penalty_last_n    = 64;    // last n tokens to penalize (0 = disable penalty, -1 = context size)
-    float   penalty_repeat    = 1.00f; // 1.0 = disabled
-    float   penalty_freq      = 0.00f; // 0.0 = disabled
-    float   penalty_present   = 0.00f; // 0.0 = disabled
-    int32_t mirostat          = 0;     // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
-    float   mirostat_tau      = 5.00f; // target entropy
-    float   mirostat_eta      = 0.10f; // learning rate
-    bool    penalize_nl       = false; // consider newlines as a repeatable token
-    bool    ignore_eos        = false;
-    bool    no_perf           = false; // disable performance metrics
-
-    std::vector<enum gpt_sampler_type> samplers = {
-        GPT_SAMPLER_TYPE_TOP_K,
-        GPT_SAMPLER_TYPE_TFS_Z,
-        GPT_SAMPLER_TYPE_TYPICAL_P,
-        GPT_SAMPLER_TYPE_TOP_P,
-        GPT_SAMPLER_TYPE_MIN_P,
-        GPT_SAMPLER_TYPE_TEMPERATURE
+    int32_t n_prev             = 64;    // number of previous tokens to remember
+    int32_t n_probs            = 0;     // if greater than 0, output the probabilities of top n_probs tokens.
+    int32_t min_keep           = 0;     // 0 = disabled, otherwise samplers should return at least min_keep tokens
+    int32_t top_k              = 40;    // <= 0 to use vocab size
+    float   top_p              = 0.95f; // 1.0 = disabled
+    float   min_p              = 0.05f; // 0.0 = disabled
+    float   xtc_probability    = 0.00f; // 0.0 = disabled
+    float   xtc_threshold      = 0.10f; // > 0.5 disables XTC
+    float   typ_p              = 1.00f; // typical_p, 1.0 = disabled
+    float   temp               = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
+    float   dynatemp_range     = 0.00f; // 0.0 = disabled
+    float   dynatemp_exponent  = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
+    int32_t penalty_last_n     = 64;    // last n tokens to penalize (0 = disable penalty, -1 = context size)
+    float   penalty_repeat     = 1.00f; // 1.0 = disabled
+    float   penalty_freq       = 0.00f; // 0.0 = disabled
+    float   penalty_present    = 0.00f; // 0.0 = disabled
+    float   dry_multiplier     = 0.0f;  // 0.0 = disabled;      DRY repetition penalty for tokens extending repetition:
+    float   dry_base           = 1.75f; // 0.0 = disabled;      multiplier * base ^ (length of sequence before token - allowed length)
+    int32_t dry_allowed_length = 2;     // tokens extending repetitions beyond this receive penalty
+    int32_t dry_penalty_last_n = -1;    // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
+    int32_t mirostat           = 0;     // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+    float   mirostat_tau       = 5.00f; // target entropy
+    float   mirostat_eta       = 0.10f; // learning rate
+    bool    penalize_nl        = false; // consider newlines as a repeatable token
+    bool    ignore_eos         = false;
+    bool    no_perf            = false; // disable performance metrics
+    bool    timing_per_token   = false;
+
+    std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"};     // default sequence breakers for DRY
+
+
+    std::vector<enum common_sampler_type> samplers = {
+        COMMON_SAMPLER_TYPE_DRY,
+        COMMON_SAMPLER_TYPE_TOP_K,
+        COMMON_SAMPLER_TYPE_TYPICAL_P,
+        COMMON_SAMPLER_TYPE_TOP_P,
+        COMMON_SAMPLER_TYPE_MIN_P,
+        COMMON_SAMPLER_TYPE_XTC,
+        COMMON_SAMPLER_TYPE_TEMPERATURE,
     };
 
     std::string grammar; // optional BNF-like grammar to constrain sampling
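
For reference, the DRY penalty described in the comments above grows geometrically with the length of the repetition a token would extend. A minimal sketch, not part of the vendored sources (the 0.8 multiplier is hypothetical, since the default of 0.0 disables DRY):

    #include <cmath>
    #include <cstdio>

    // penalty = multiplier * base ^ (n - allowed_length), applied only once a
    // token would extend a repetition of length n >= allowed_length
    static float dry_penalty(int n, float multiplier, float base, int allowed_length) {
        if (multiplier == 0.0f || n < allowed_length) {
            return 0.0f; // multiplier 0.0 disables DRY entirely
        }
        return multiplier * std::pow(base, (float) (n - allowed_length));
    }

    int main() {
        // with base = 1.75 and allowed_length = 2: 0.80, 1.40, 2.45, ~4.29
        for (int n = 2; n <= 5; ++n) {
            std::printf("n=%d penalty=%.3f\n", n, dry_penalty(n, 0.8f, 1.75f, 2));
        }
    }
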
@@ -167,21 +182,30 @@ struct gpt_sampler_params {
     std::string print() const;
 };
 
-struct gpt_params {
+struct common_params_speculative {
+    std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
+    int32_t n_ctx        =     0; // draft context size
+    int32_t n_max        =    16; // maximum number of tokens to draft during speculative decoding
+    int32_t n_min        =     5; // minimum number of draft tokens to use for speculative decoding
+    int32_t n_gpu_layers =    -1; // number of layers to store in VRAM for the draft model (-1 - use default)
+    float   p_split      =  0.1f; // speculative decoding split probability
+    float   p_min        =  0.9f; // minimum speculative decoding probability (greedy)
+
+    struct cpu_params cpuparams;
+    struct cpu_params cpuparams_batch;
+
+    std::string model = ""; // draft model for speculative decoding                          // NOLINT
+};
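
These nested fields replace the old top-level n_draft/p_split/model_draft knobs. A hedged wiring sketch (the draft model path is hypothetical):

    common_params params;
    params.speculative.model = "models/draft.gguf"; // hypothetical path
    params.speculative.n_max = 16;                  // draft at most 16 tokens per step
    params.speculative.n_min = 5;                   // skip drafting below this
    params.speculative.p_min = 0.9f;                // minimum draft probability (greedy)
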
+
+struct common_params {
     int32_t n_predict             =    -1; // new tokens to predict
-    int32_t n_ctx                 =     0; // context size
+    int32_t n_ctx                 =  4096; // context size
     int32_t n_batch               =  2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
     int32_t n_ubatch              =   512; // physical batch size for prompt processing (must be >=32 to use BLAS)
     int32_t n_keep                =     0; // number of tokens to keep from initial prompt
-    int32_t n_draft               =     5; // number of tokens to draft during speculative decoding
     int32_t n_chunks              =    -1; // max number of chunks to process (-1 = unlimited)
     int32_t n_parallel            =     1; // number of parallel sequences to decode
     int32_t n_sequences           =     1; // number of sequences to decode
-    float   p_split               =  0.1f; // speculative decoding split probability
-    int32_t n_gpu_layers          =    -1; // number of layers to store in VRAM (-1 - use default)
-    int32_t n_gpu_layers_draft    =    -1; // number of layers to store in VRAM for the draft model (-1 - use default)
-    int32_t main_gpu              =     0; // the GPU that is used for scratch and small tensors
-    float   tensor_split[128]     =   {0}; // how split tensors should be distributed across GPUs
     int32_t grp_attn_n            =     1; // group-attention factor
     int32_t grp_attn_w            =   512; // group-attention width
     int32_t n_print               =    -1; // print token count every n tokens (-1 = disabled)
@@ -192,27 +216,31 @@ struct gpt_params {
     float   yarn_beta_fast        = 32.0f; // YaRN low correction dim
     float   yarn_beta_slow        =  1.0f; // YaRN high correction dim
     int32_t yarn_orig_ctx         =     0; // YaRN original context length
-    float   defrag_thold          = -1.0f; // KV cache defragmentation threshold
+    float   defrag_thold          =  0.1f; // KV cache defragmentation threshold
+
+    // offload params
+    std::vector<ggml_backend_dev_t> devices;         // devices to use for offloading
+    int32_t n_gpu_layers                    =    -1; // number of layers to store in VRAM (-1 - use default)
+    int32_t main_gpu                        =     0; // the GPU that is used for scratch and small tensors
+    float   tensor_split[128]               =   {0}; // how split tensors should be distributed across GPUs
+    enum llama_split_mode        split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
 
     struct cpu_params cpuparams;
     struct cpu_params cpuparams_batch;
-    struct cpu_params draft_cpuparams;
-    struct cpu_params draft_cpuparams_batch;
 
     ggml_backend_sched_eval_callback cb_eval = nullptr;
     void * cb_eval_user_data                 = nullptr;
 
     ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;
 
-    enum llama_split_mode        split_mode        = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
     enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
     enum llama_pooling_type      pooling_type      = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
     enum llama_attention_type    attention_type    = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings
 
-    struct gpt_sampler_params sparams;
+    struct common_params_sampling sampling;
+    struct common_params_speculative speculative;
 
     std::string model                = ""; // model path                                                    // NOLINT
-    std::string model_draft          = ""; // draft model for speculative decoding                          // NOLINT
     std::string model_alias          = "unknown"; // model alias                                            // NOLINT
     std::string model_url            = ""; // model url to download                                         // NOLINT
     std::string hf_token             = ""; // HF token                                                      // NOLINT
@@ -223,7 +251,6 @@ struct gpt_params {
     std::string path_prompt_cache    = ""; // path to file for saving/loading prompt eval state             // NOLINT
     std::string input_prefix         = ""; // string to prefix user inputs with                             // NOLINT
     std::string input_suffix         = ""; // string to suffix user inputs with                             // NOLINT
-    std::string logdir               = ""; // directory in which to save YAML log files                     // NOLINT
     std::string lookup_cache_static  = ""; // path of static ngram cache file for lookup decoding           // NOLINT
     std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding          // NOLINT
     std::string logits_file          = ""; // file for saving *all* logits                                  // NOLINT
@@ -234,9 +261,9 @@ struct gpt_params {
     std::vector<llama_model_kv_override> kv_overrides;
 
     bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_lora_adapter_apply)
-    std::vector<llama_lora_adapter_info> lora_adapters; // lora adapter path with user defined scale
+    std::vector<common_lora_adapter_info> lora_adapters; // lora adapter path with user defined scale
 
-    std::vector<llama_control_vector_load_info> control_vectors; // control vector with user defined scale
+    std::vector<common_control_vector_load_info> control_vectors; // control vector with user defined scale
 
     int32_t verbosity                  = 0;
     int32_t control_vector_layer_start = -1; // layer range for control vector
@@ -294,21 +321,21 @@ struct gpt_params {
 
     // embedding
     bool embedding         = false; // get only sentence embedding
-    int32_t embd_normalize = 2;     // normalisation for embendings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
+    int32_t embd_normalize = 2;     // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
     std::string embd_out   = "";    // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
-    std::string embd_sep   = "\n";  // separator of embendings
+    std::string embd_sep   = "\n";  // separator of embeddings
     bool reranking         = false; // enable reranking support on server
 
     // server params
     int32_t port           = 8080;         // server listens on this network port
     int32_t timeout_read   = 600;          // http read timeout in seconds
     int32_t timeout_write  = timeout_read; // http write timeout in seconds
-    int     n_threads_http = -1;           // number of threads to process HTTP requests (TODO: support threadpool)
+    int32_t n_threads_http = -1;           // number of threads to process HTTP requests (TODO: support threadpool)
+    int32_t n_cache_reuse  = 0;            // min chunk size to reuse from the cache via KV shifting
 
     std::string hostname      = "127.0.0.1";
     std::string public_path   = "";                                                                         // NOLINT
     std::string chat_template = "";                                                                         // NOLINT
-    std::string system_prompt = "";                                                                         // NOLINT
     bool enable_chat_template = true;
 
     std::vector<std::string> api_keys;
@@ -316,7 +343,10 @@ struct gpt_params {
     std::string ssl_file_key  = "";                                                                         // NOLINT
     std::string ssl_file_cert = "";                                                                         // NOLINT
 
-    bool endpoint_slots   = true;
+    // "advanced" endpoints are disabled by default for better security
+    bool webui            = true;
+    bool endpoint_slots   = false;
+    bool endpoint_props   = false; // only control POST requests, not GET
     bool endpoint_metrics = false;
 
     bool log_json = false;
@@ -371,20 +401,31 @@ struct gpt_params {
 
 // call once at the start of a program if it uses libcommon
 // initializes the logging system and prints info about the build
-void gpt_init();
+void common_init();
 
-std::string gpt_params_get_system_info(const gpt_params & params);
+std::string common_params_get_system_info(const common_params & params);
 
-bool parse_cpu_range(const std::string& range, bool(&boolmask)[GGML_MAX_N_THREADS]);
-bool parse_cpu_mask(const std::string& mask, bool(&boolmask)[GGML_MAX_N_THREADS]);
-void postprocess_cpu_params(cpu_params& cpuparams, const cpu_params* role_model = nullptr);
+bool parse_cpu_range(const std::string & range, bool(&boolmask)[GGML_MAX_N_THREADS]);
+bool parse_cpu_mask(const std::string & mask, bool(&boolmask)[GGML_MAX_N_THREADS]);
+void postprocess_cpu_params(cpu_params & cpuparams, const cpu_params * role_model = nullptr);
 bool set_process_priority(enum ggml_sched_priority prio);
 
 //
 // String utils
 //
 
-std::vector<std::string> string_split(std::string input, char separator);
+#ifdef __GNUC__
+#ifdef __MINGW32__
+#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
+#else
+#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
+#endif
+#else
+#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...)
+#endif
+
+LLAMA_COMMON_ATTRIBUTE_FORMAT(1, 2)
+std::string string_format(const char * fmt, ...);
 
 std::string string_strip(const std::string & str);
 std::string string_get_sortable_timestamp();
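
The LLAMA_COMMON_ATTRIBUTE_FORMAT macro above marks string_format as printf-like so GCC and Clang can type-check format arguments at every call site. The header only declares it; here is a minimal sketch of such a helper, assuming the usual two-pass vsnprintf approach rather than the vendored definition:

    #include <cstdarg>
    #include <cstdio>
    #include <string>
    #include <vector>

    __attribute__((format(printf, 1, 2)))
    static std::string format_sketch(const char * fmt, ...) {
        va_list ap;
        va_start(ap, fmt);
        va_list ap2;
        va_copy(ap2, ap);
        const int n = vsnprintf(nullptr, 0, fmt, ap); // first pass: measure
        va_end(ap);
        std::vector<char> buf(n > 0 ? n + 1 : 1);
        vsnprintf(buf.data(), buf.size(), fmt, ap2);  // second pass: fill
        va_end(ap2);
        return std::string(buf.data());
    }

    // format_sketch("%d layers", 32) compiles; format_sketch("%d", "x") warns.
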
@@ -393,6 +434,7 @@ void string_replace_all(std::string & s, const std::string & search, const std::
 
 template<class T>
 static std::vector<T> string_split(const std::string & str, char delim) {
+    static_assert(!std::is_same<T, std::string>::value, "Please use the specialized version for std::string");
     std::vector<T> values;
     std::istringstream str_stream(str);
     std::string token;
@@ -405,6 +447,22 @@ static std::vector<T> string_split(const std::string & str, char delim) {
     return values;
 }
 
+template<>
+std::vector<std::string> string_split<std::string>(const std::string & input, char separator)
+{
+    std::vector<std::string> parts;
+    size_t begin_pos = 0;
+    size_t separator_pos = input.find(separator);
+    while (separator_pos != std::string::npos) {
+        std::string part = input.substr(begin_pos, separator_pos - begin_pos);
+        parts.emplace_back(part);
+        begin_pos = separator_pos + 1;
+        separator_pos = input.find(separator, begin_pos);
+    }
+    parts.emplace_back(input.substr(begin_pos, separator_pos - begin_pos));
+    return parts;
+}
+
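
The static_assert in the generic template steers std::string callers to this specialization, which keeps empty fields instead of funneling each field through operator>>. Usage sketch (not from the sources):

    auto parts = string_split<std::string>("a,,b", ',');   // {"a", "", "b"}
    auto nums  = string_split<int>("1,2,3", ',');          // {1, 2, 3}
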
 bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
 void string_process_escapes(std::string & input);
 
@@ -427,48 +485,69 @@ std::string fs_get_cache_file(const std::string & filename);
 // Model utils
 //
 
-struct llama_init_result {
+struct common_init_result {
     struct llama_model   * model   = nullptr;
     struct llama_context * context = nullptr;
-    std::vector<llama_lora_adapter_container> lora_adapters;
+    std::vector<common_lora_adapter_container> lora_adapters;
 };
 
-struct llama_init_result    llama_init_from_gpt_params(gpt_params & params);
+struct common_init_result     common_init_from_params(common_params & params);
 
-struct llama_model_params     llama_model_params_from_gpt_params    (const gpt_params & params);
-struct llama_context_params   llama_context_params_from_gpt_params  (const gpt_params & params);
+struct llama_model_params     common_model_params_to_llama  (      common_params & params);
+struct llama_context_params   common_context_params_to_llama(const common_params & params);
 struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params);
 
-struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, const char * hf_token, const struct llama_model_params & params);
-struct llama_model * llama_load_model_from_hf(const char * repo, const char * file, const char * path_model, const char * hf_token, const struct llama_model_params & params);
+struct llama_model * common_load_model_from_url(
+    const std::string & model_url,
+    const std::string & local_path,
+    const std::string & hf_token,
+    const struct llama_model_params & params);
+struct llama_model * common_load_model_from_hf(
+    const std::string & repo,
+    const std::string & remote_path,
+    const std::string & local_path,
+    const std::string & hf_token,
+    const struct llama_model_params & params);
 
 // clear LoRA adapters from context, then apply new list of adapters
-void llama_lora_adapters_apply(struct llama_context * ctx, std::vector<llama_lora_adapter_container> & lora_adapters);
+void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_container> & lora_adapters);
 
+//
 // Batch utils
+//
 
-void llama_batch_clear(struct llama_batch & batch);
+void common_batch_clear(struct llama_batch & batch);
 
-void llama_batch_add(
+void common_batch_add(
                  struct llama_batch & batch,
                         llama_token   id,
                           llama_pos   pos,
     const std::vector<llama_seq_id> & seq_ids,
                                bool   logits);
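
Typical usage of these batch helpers, as a hedged sketch (batch, new_token_id and n_past are assumed to come from the surrounding decode loop):

    // refill the batch for one decode step: sequence 0, position n_past,
    // logits requested for this token only
    common_batch_clear(batch);
    common_batch_add(batch, new_token_id, n_past, { 0 }, /*logits=*/true);
    // llama_decode(ctx, batch);
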
 
+//
+// Token utils
+//
+
+// longest common prefix
+size_t common_lcp(const llama_tokens & a, const llama_tokens & b);
+
+// longest common subsequence
+size_t common_lcs(const llama_tokens & a, const llama_tokens & b);
+
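
These are declarations only; a minimal sketch of the obvious common_lcp semantics (counting shared leading tokens, useful e.g. to decide how much of a cached prompt can be reused; the real definition lives in common.cpp):

    static size_t lcp_sketch(const llama_tokens & a, const llama_tokens & b) {
        size_t i = 0;
        while (i < a.size() && i < b.size() && a[i] == b[i]) {
            i++;
        }
        return i;
    }
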
 //
 // Vocab utils
 //
 
 // tokenizes a string into a vector of tokens
 // should work similar to Python's `tokenizer.encode`
-std::vector<llama_token> llama_tokenize(
+std::vector<llama_token> common_tokenize(
   const struct llama_context * ctx,
            const std::string & text,
                         bool   add_special,
                         bool   parse_special = false);
 
-std::vector<llama_token> llama_tokenize(
+std::vector<llama_token> common_tokenize(
     const struct llama_model * model,
            const std::string & text,
                         bool   add_special,
@@ -476,7 +555,7 @@ std::vector<llama_token> llama_tokenize(
 
 // tokenizes a token into a piece, optionally renders special/control tokens
 // should work similar to Python's `tokenizer.id_to_piece`
-std::string llama_token_to_piece(
+std::string common_token_to_piece(
         const struct llama_context * ctx,
                        llama_token   token,
                        bool          special = true);
@@ -484,7 +563,7 @@ std::string llama_token_to_piece(
 // detokenizes a vector of tokens into a string
 // should work similar to Python's `tokenizer.decode`
 // optionally renders special/control tokens
-std::string llama_detokenize(
+std::string common_detokenize(
                          llama_context * ctx,
         const std::vector<llama_token> & tokens,
                                   bool   special = true);
@@ -494,31 +573,31 @@ std::string llama_detokenize(
 //
 
 // same with llama_chat_message, but uses std::string
-struct llama_chat_msg {
+struct common_chat_msg {
     std::string role;
     std::string content;
 };
 
 // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
-bool llama_chat_verify_template(const std::string & tmpl);
+bool common_chat_verify_template(const std::string & tmpl);
 
 // CPP wrapper for llama_chat_apply_template
 // If the built-in template is not supported, we default to chatml
 // If the custom "tmpl" is not supported, we throw an error
-std::string llama_chat_apply_template(const struct llama_model * model,
+std::string common_chat_apply_template(const struct llama_model * model,
         const std::string & tmpl,
-        const std::vector<llama_chat_msg> & chat,
+        const std::vector<common_chat_msg> & chat,
         bool add_ass);
 
 // Format single message, while taking into account the position of that message in chat history
-std::string llama_chat_format_single(const struct llama_model * model,
+std::string common_chat_format_single(const struct llama_model * model,
         const std::string & tmpl,
-        const std::vector<llama_chat_msg> & past_msg,
-        const llama_chat_msg & new_msg,
+        const std::vector<common_chat_msg> & past_msg,
+        const common_chat_msg & new_msg,
         bool add_ass);
 
 // Returns an example of formatted chat
-std::string llama_chat_format_example(const struct llama_model * model,
+std::string common_chat_format_example(const struct llama_model * model,
         const std::string & tmpl);
 
 //
@@ -526,31 +605,31 @@ std::string llama_chat_format_example(const struct llama_model * model,
 //
 
 // Dump the KV cache view with the number of sequences per cell.
-void llama_kv_cache_dump_view(const llama_kv_cache_view & view, int row_size = 80);
+void common_kv_cache_dump_view(const llama_kv_cache_view & view, int row_size = 80);
 
 // Dump the KV cache view showing individual sequences in each cell (long output).
-void llama_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_size = 40);
+void common_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_size = 40);
 
 //
 // Embedding utils
 //
 
-void llama_embd_normalize(const float * inp, float * out, int n, int embd_norm = 2);
+void common_embd_normalize(const float * inp, float * out, int n, int embd_norm = 2);
 
-float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n);
+float common_embd_similarity_cos(const float * embd1, const float * embd2, int n);
 
 //
 // Control vector utils
 //
 
-struct llama_control_vector_data {
+struct common_control_vector_data {
     int n_embd;
 
     // stores data for layers [1, n_layer] where n_layer = data.size() / n_embd
     std::vector<float> data;
 };
 
-struct llama_control_vector_load_info {
+struct common_control_vector_load_info {
     float strength;
 
     std::string fname;
@@ -558,7 +637,7 @@ struct llama_control_vector_load_info {
 
 // Load control vectors, scale each by strength, and add them together.
 // On error, returns {-1, empty}
-llama_control_vector_data llama_control_vector_load(const std::vector<llama_control_vector_load_info> & load_infos);
+common_control_vector_data common_control_vector_load(const std::vector<common_control_vector_load_info> & load_infos);
 
 //
 // Split utils
@@ -567,15 +646,3 @@ llama_control_vector_data llama_control_vector_load(const std::vector<llama_cont
 static const char * const LLM_KV_SPLIT_NO            = "split.no";
 static const char * const LLM_KV_SPLIT_COUNT         = "split.count";
 static const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
-
-//
-// YAML utils
-//
-
-void yaml_dump_vector_float    (FILE * stream, const char * prop_name, const std::vector<float> & data);
-void yaml_dump_vector_int      (FILE * stream, const char * prop_name, const std::vector<int> & data);
-void yaml_dump_string_multiline(FILE * stream, const char * prop_name, const char * data);
-
-void yaml_dump_non_result_info(
-    FILE * stream, const gpt_params & params, const llama_context * lctx,
-    const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc);

+ 47 - 3127
llama/ggml-aarch64.c

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *
@@ -24,207 +24,52 @@
  * SOFTWARE.
  */
 
-// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
-// SPDX-License-Identifier: MIT
-//
-
-#define GGML_COMMON_IMPL_C
+#define GGML_COMMON_DECL_C
 #include "ggml-common.h"
 
-#include "ggml-quants.h"
+#include "ggml-aarch64.h"
 #include "ggml-impl.h"
-#include "ggml-cpu-impl.h"
-
-#include <math.h>
-#include <string.h>
+#include "ggml-quants.h"
 #include <assert.h>
-#include <float.h>
-#include <stdlib.h> // for qsort
-#include <stdio.h>  // for GGML_ASSERT
-
-#include "ggml-aarch64.h"
-
-#if defined(__GNUC__)
-#pragma GCC diagnostic ignored "-Woverlength-strings"
-#elif defined(_MSC_VER)
-#pragma warning(disable: 4244 4267) // possible loss of data
-#endif
 
 #define UNUSED GGML_UNUSED
 
-// Functions to create the interleaved data layout formats
-
-// interleave 4 block_q4_0s in blocks of blck_size_interleave
-// returns an interleaved block_q4_0x4
-// in the interleaved block_q4_0x4, place deltas for 4 block_q4_0 blocks
-// first, then interleave quants from 4 block_q4_0s in blocks of blck_size_interleave
-//
-// - in                  : an array of block_q4_0 pointers
-// - blck_size_interleave : the block_q4_0 quants bytes are interleaved in blocks of
-//                         blck_size_interleave bytes
-// - xor_mask            : the mask to convert the nibbles in block_q4_0 quants bytes
-//                         from bias offset form to pure sign form (this saves subtract
-//                         operations during unpacking)
-//
-#if defined(__AVX__)
-#if defined(__F16C__)
-#if defined(__AVX512F__)
-#define GGML_F32Cx8x2_LOAD(x, y)     _mm512_cvtph_ps(_mm256_set_m128i(_mm_loadu_si128((const __m128i *)(y)), _mm_loadu_si128((const __m128i *)(x))))
-#define GGML_F32Cx16_REPEAT_LOAD(x)  _mm512_cvtph_ps(_mm256_set_m128i(x, x))
-#endif
-// the  _mm256_cvt intrinsics require F16C
-#define GGML_F32Cx8_LOAD(x)     _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
-#define GGML_F32Cx8_REPEAT_LOAD(x, loadMask)     _mm256_cvtph_ps(_mm_shuffle_epi32(_mm_maskload_epi32((int const*)(x), loadMask), 68))
-#define GGML_F32Cx8_REARRANGE_LOAD(x, arrangeMask)     _mm256_cvtph_ps(_mm_shuffle_epi8(_mm_loadu_si128((const __m128i *) x), arrangeMask))
-#else
-#if defined(__AVX512F__)
-static inline __m512 __avx512_f32cx8x2_load(ggml_fp16_t *x, ggml_fp16_t *y) {
-    float tmp[16];
-
-    for (int i = 0; i < 8; i++) {
-        tmp[i] = GGML_FP16_TO_FP32(x[i]);
-    }
-
-    for (int i = 0; i < 8; i++) {
-        tmp[i + 8] = GGML_FP16_TO_FP32(y[i]);
-    }
-
-    return _mm512_loadu_ps(tmp);
-}
-static inline __m512 __avx512_repeat_f32cx16_load(__m128i x) {
-    float tmp[16];
-    uint16_t tmphalf[8];
-    _mm_storeu_si128((__m128i*)tmphalf, x);
-
-    for (int i = 0; i < 4; i++) {
-        tmp[i] = GGML_FP16_TO_FP32(tmphalf[i]);
-        tmp[i + 4] = GGML_FP16_TO_FP32(tmphalf[i]);
-        tmp[i + 8] = GGML_FP16_TO_FP32(tmphalf[i]);
-        tmp[i + 12] = GGML_FP16_TO_FP32(tmphalf[i]);
-    }
-
-    return _mm512_loadu_ps(tmp);
-}
-#endif
-static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
-    float tmp[8];
-
-    for (int i = 0; i < 8; i++) {
-        tmp[i] = GGML_FP16_TO_FP32(x[i]);
-    }
-
-    return _mm256_loadu_ps(tmp);
-}
-static inline __m256 __avx_repeat_f32cx8_load(ggml_fp16_t *x) {
-    float tmp[8];
-
-    for (int i = 0; i < 4; i++) {
-        tmp[i] = GGML_FP16_TO_FP32(x[i]);
-        tmp[i + 4] = GGML_FP16_TO_FP32(x[i]);
-    }
-
-    return _mm256_loadu_ps(tmp);
-}
-static inline __m256 __avx_rearranged_f32cx8_load(ggml_fp16_t *x, __m128i arrangeMask) {
-    uint16_t tmphalf[8];
-    float tmp[8];
-
-    _mm_storeu_si128((__m128i*)tmphalf, _mm_shuffle_epi8(_mm_loadu_si128((const __m128i *) x), arrangeMask));
-    for (int i = 0; i < 8; i++) {
-        tmp[i] = GGML_FP16_TO_FP32(tmphalf[i]);
-    }
-
-    return _mm256_loadu_ps(tmp);
-}
-
-#define GGML_F32Cx8_LOAD(x)     __avx_f32cx8_load(x)
-#define GGML_F32Cx8_REPEAT_LOAD(x, loadMask)     __avx_repeat_f32cx8_load(x)
-#define GGML_F32Cx8_REARRANGE_LOAD(x, arrangeMask)     __avx_rearranged_f32cx8_load(x, arrangeMask)
-#if defined(__AVX512F__)
-#define GGML_F32Cx8x2_LOAD(x, y)     __avx512_f32cx8x2_load(x, y)
-#define GGML_F32Cx16_REPEAT_LOAD(x)  __avx512_repeat_f32cx16_load(x)
-#endif
-#endif
-#endif
-
-
-#if defined(__AVX2__) || defined(__AVX512F__)
-#if defined(__AVX512F__)
-// add int16_t pairwise and return as 512 bit int vector
-static inline __m512i sum_i16_pairs_int_32x16(const __m512i x) {
-    const __m512i ones = _mm512_set1_epi16(1);
-    return _mm512_madd_epi16(ones, x);
-}
-
-static inline __m512i mul_sum_us8_pairs_int32x16(const __m512i ax, const __m512i sy) {
-#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
-    const __m512i zero = _mm512_setzero_si512();
-    return _mm512_dpbusd_epi32(zero, ax, sy);
-#else
-    // Perform multiplication and create 16-bit values
-    const __m512i dot = _mm512_maddubs_epi16(ax, sy);
-    return sum_i16_pairs_int_32x16(dot);
-#endif
-}
-
-// multiply int8_t, add results pairwise twice and return as 512 bit int vector
-static inline __m512i mul_sum_i8_pairs_int32x16(const __m512i x, const __m512i y) {
-    const __m512i zero = _mm512_setzero_si512();
-    // Get absolute values of x vectors
-    const __m512i ax = _mm512_abs_epi8(x);
-    // Sign the values of the y vectors
-    __mmask64 blt0 = _mm512_movepi8_mask(x);
-    const __m512i sy = _mm512_mask_sub_epi8(y, blt0, zero, y);
-    return mul_sum_us8_pairs_int32x16(ax, sy);
-}
-#endif
-
-// add int16_t pairwise and return as 256 bit int vector
-static inline __m256i sum_i16_pairs_int32x8(const __m256i x) {
-    const __m256i ones = _mm256_set1_epi16(1);
-    return _mm256_madd_epi16(ones, x);
-}
-
-static inline __m256i mul_sum_us8_pairs_int32x8(const __m256i ax, const __m256i sy) {
-#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
-    const __m256i zero = _mm256_setzero_si256();
-    return _mm256_dpbusd_epi32(zero, ax, sy);
-#else
-    // Perform multiplication and create 16-bit values
-    const __m256i dot = _mm256_maddubs_epi16(ax, sy);
-    return sum_i16_pairs_int32x8(dot);
-#endif
-}
-
-// Integer variant of the function defined in ggml-quants.c
-// multiply int8_t, add results pairwise twice and return as 256 bit int vector
-static inline __m256i mul_sum_i8_pairs_int32x8(const __m256i x, const __m256i y) {
-#if __AVXVNNIINT8__
-    const __m256i zero = _mm256_setzero_si256();
-    return _mm256_dpbssd_epi32(zero, x, y);
-#else
-    // Get absolute values of x vectors
-    const __m256i ax = _mm256_sign_epi8(x, x);
-    // Sign the values of the y vectors
-    const __m256i sy = _mm256_sign_epi8(y, x);
-    return mul_sum_us8_pairs_int32x8(ax, sy);
-#endif
-}
-#endif
-
-static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) {
+static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) {
     block_q4_0x4 out;
 
     for (int i = 0; i < 4; i++) {
         out.d[i] = in[i].d;
     }
 
-    for (int i = 0; i < QK4_0 * 2; i++) {
-        int src_offset = (i / (4 * blck_size_interleave)) * blck_size_interleave;
-        int src_id = (i % (4 * blck_size_interleave)) / blck_size_interleave;
-        src_offset += (i % blck_size_interleave);
+    const int end = QK4_0 * 2 / blck_size_interleave;
+
+    if (blck_size_interleave == 8) {
+        const uint64_t xor_mask = 0x8888888888888888ULL;
+        for (int i = 0; i < end; ++i) {
+            int src_id = i % 4;
+            int src_offset = (i / 4) * blck_size_interleave;
+            int dst_offset = i * blck_size_interleave;
+
+            uint64_t elems;
+            // Using memcpy to avoid unaligned memory accesses
+            memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
+            elems ^= xor_mask;
+            memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
+        }
+    } else if (blck_size_interleave == 4) {
+        const uint32_t xor_mask = 0x88888888;
+        for (int i = 0; i < end; ++i) {
+            int src_id = i % 4;
+            int src_offset = (i / 4) * blck_size_interleave;
+            int dst_offset = i * blck_size_interleave;
 
-        out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask;
+            uint32_t elems;
+            memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint32_t));
+            elems ^= xor_mask;
+            memcpy(&out.qs[dst_offset], &elems, sizeof(uint32_t));
+        }
+    } else {
+        GGML_ASSERT(false);
     }
 
     return out;
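
The replicated 0x88... masks above come from how q4_0 stores its 4-bit quants with a +8 bias (0..15 standing for -8..7): XORing each nibble with 0x8 flips the high bit, turning the offset-binary value into a signed two's-complement nibble so unpacking can skip a subtraction. A worked sketch, not from the sources:

    #include <cstdint>

    // biased 0x0 (-8) ^ 0x8 -> 0x8 -> -8;  0x8 (0) -> 0x0 -> 0;  0xF (+7) -> 0x7 -> +7
    static inline int8_t q4_unbias(uint8_t nib) {
        return (int8_t) ((nib ^ 0x8) << 4) >> 4; // sign-extend the low nibble
    }

Packed two nibbles per byte, the per-nibble mask replicates to 0x88 per byte, hence 0x88888888 over 4 bytes and 0x8888888888888888ULL over 8, as used in make_block_q4_0x4/x8.
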
@@ -234,343 +79,28 @@ static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_in
 // returns an interleaved block_q4_0x8
 // in the interleaved block_q4_0x8, place deltas for 8 block_q4_0 blocks
 // first, then interleave quants from 8 block_q4_0s in blocks of blck_size_interleave
-static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) {
+static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_interleave) {
     block_q4_0x8 out;
 
     for (int i = 0; i < 8; i++) {
         out.d[i] = in[i].d;
     }
 
-    for (int i = 0; i < QK4_0 * 4; i++) {
-        int src_offset = (i / (8 * blck_size_interleave)) * blck_size_interleave;
-        int src_id = (i % (8 * blck_size_interleave)) / blck_size_interleave;
-        src_offset += (i % blck_size_interleave);
-
-        out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask;
-    }
-
-    return out;
-}
-
-void quantize_q8_0_4x4(const float * restrict x, void * restrict vy, int64_t k) {
-    assert(QK8_0 == 32);
-    assert(k % QK8_0 == 0);
-    const int nb = k / QK8_0;
-
-    block_q8_0x4 * restrict y = (block_q8_0x4 *) vy;
-
-#if defined(__ARM_NEON)
-    float32x4_t srcv[4][8];
-    float id[4];
-
-    for (int i = 0; i < nb; i++) {
-        float32x4_t asrcv[8];
-        float32x4_t amaxv[8];
-
-        for (int row_iter = 0; row_iter < 4; row_iter++) {
-            for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j);
-            for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]);
-
-            for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]);
-            for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]);
-            for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]);
-
-            const float amax = vmaxvq_f32(amaxv[0]);
-
-            const float d = amax / ((1 << 7) - 1);
-            id[row_iter] = d ? 1.0f / d : 0.0f;
-
-            y[i].d[row_iter] = GGML_FP32_TO_FP16(d);
-        }
-
-        for (int j = 0; j < 8; j++) {
-            float32x4_t v = vmulq_n_f32(srcv[0][j], id[0]);
-            int32x4_t vi = vcvtnq_s32_f32(v);
-            y[i].qs[16 * j + 0] = vgetq_lane_s32(vi, 0);
-            y[i].qs[16 * j + 1] = vgetq_lane_s32(vi, 1);
-            y[i].qs[16 * j + 2] = vgetq_lane_s32(vi, 2);
-            y[i].qs[16 * j + 3] = vgetq_lane_s32(vi, 3);
-
-            v = vmulq_n_f32(srcv[1][j], id[1]);
-            vi = vcvtnq_s32_f32(v);
-            y[i].qs[16 * j + 4] = vgetq_lane_s32(vi, 0);
-            y[i].qs[16 * j + 5] = vgetq_lane_s32(vi, 1);
-            y[i].qs[16 * j + 6] = vgetq_lane_s32(vi, 2);
-            y[i].qs[16 * j + 7] = vgetq_lane_s32(vi, 3);
-
-            v = vmulq_n_f32(srcv[2][j], id[2]);
-            vi = vcvtnq_s32_f32(v);
-            y[i].qs[16 * j + 8] = vgetq_lane_s32(vi, 0);
-            y[i].qs[16 * j + 9] = vgetq_lane_s32(vi, 1);
-            y[i].qs[16 * j + 10] = vgetq_lane_s32(vi, 2);
-            y[i].qs[16 * j + 11] = vgetq_lane_s32(vi, 3);
-
-            v = vmulq_n_f32(srcv[3][j], id[3]);
-            vi = vcvtnq_s32_f32(v);
-            y[i].qs[16 * j + 12] = vgetq_lane_s32(vi, 0);
-            y[i].qs[16 * j + 13] = vgetq_lane_s32(vi, 1);
-            y[i].qs[16 * j + 14] = vgetq_lane_s32(vi, 2);
-            y[i].qs[16 * j + 15] = vgetq_lane_s32(vi, 3);
-        }
-    }
-#else
-    // scalar
-    const int blck_size_interleave = 4;
-    float srcv[4][QK8_0];
-    float id[4];
-
-    for (int i = 0; i < nb; i++) {
-        for (int row_iter = 0; row_iter < 4; row_iter++) {
-            float amax = 0.0f; // absolute max
-
-            for (int j = 0; j < QK8_0; j++) {
-                srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j];
-                amax = MAX(amax, fabsf(srcv[row_iter][j]));
-            }
-
-            const float d = amax / ((1 << 7) - 1);
-            id[row_iter] = d ? 1.0f / d : 0.0f;
-
-            y[i].d[row_iter] = GGML_FP32_TO_FP16(d);
-        }
-
-        for (int j = 0; j < QK8_0 * 4; j++) {
-            int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
-            int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
-            src_offset += (j % blck_size_interleave);
-
-            float x0 = srcv[src_id][src_offset] * id[src_id];
-            y[i].qs[j] = roundf(x0);
-        }
-    }
-#endif
-}
-
-void quantize_q8_0_4x8(const float * restrict x, void * restrict vy, int64_t k) {
-    assert(QK8_0 == 32);
-    assert(k % QK8_0 == 0);
-    const int nb = k / QK8_0;
-
-    block_q8_0x4 * restrict y = (block_q8_0x4 *) vy;
-
-#if defined(__ARM_NEON)
-    float32x4_t srcv[4][8];
-    float id[4];
-
-    for (int i = 0; i < nb; i++) {
-        float32x4_t asrcv[8];
-        float32x4_t amaxv[8];
-
-        for (int row_iter = 0; row_iter < 4; row_iter++) {
-            for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j);
-            for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]);
-
-            for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]);
-            for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]);
-            for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]);
-
-            const float amax = vmaxvq_f32(amaxv[0]);
-
-            const float d = amax / ((1 << 7) - 1);
-            id[row_iter] = d ? 1.0f / d : 0.0f;
-
-            y[i].d[row_iter] = GGML_FP32_TO_FP16(d);
-        }
-
-        for (int j = 0; j < 4; j++) {
-            float32x4_t v = vmulq_n_f32(srcv[0][2 * j], id[0]);
-            int32x4_t vi = vcvtnq_s32_f32(v);
-            y[i].qs[32 * j + 0] = vgetq_lane_s32(vi, 0);
-            y[i].qs[32 * j + 1] = vgetq_lane_s32(vi, 1);
-            y[i].qs[32 * j + 2] = vgetq_lane_s32(vi, 2);
-            y[i].qs[32 * j + 3] = vgetq_lane_s32(vi, 3);
-            v = vmulq_n_f32(srcv[0][2 * j + 1], id[0]);
-            vi = vcvtnq_s32_f32(v);
-            y[i].qs[32 * j + 4] = vgetq_lane_s32(vi, 0);
-            y[i].qs[32 * j + 5] = vgetq_lane_s32(vi, 1);
-            y[i].qs[32 * j + 6] = vgetq_lane_s32(vi, 2);
-            y[i].qs[32 * j + 7] = vgetq_lane_s32(vi, 3);
-
-            v = vmulq_n_f32(srcv[1][2 * j], id[1]);
-            vi = vcvtnq_s32_f32(v);
-            y[i].qs[32 * j + 8] = vgetq_lane_s32(vi, 0);
-            y[i].qs[32 * j + 9] = vgetq_lane_s32(vi, 1);
-            y[i].qs[32 * j + 10] = vgetq_lane_s32(vi, 2);
-            y[i].qs[32 * j + 11] = vgetq_lane_s32(vi, 3);
-            v = vmulq_n_f32(srcv[1][2 * j + 1], id[1]);
-            vi = vcvtnq_s32_f32(v);
-            y[i].qs[32 * j + 12] = vgetq_lane_s32(vi, 0);
-            y[i].qs[32 * j + 13] = vgetq_lane_s32(vi, 1);
-            y[i].qs[32 * j + 14] = vgetq_lane_s32(vi, 2);
-            y[i].qs[32 * j + 15] = vgetq_lane_s32(vi, 3);
-
-            v = vmulq_n_f32(srcv[2][2 * j], id[2]);
-            vi = vcvtnq_s32_f32(v);
-            y[i].qs[32 * j + 16] = vgetq_lane_s32(vi, 0);
-            y[i].qs[32 * j + 17] = vgetq_lane_s32(vi, 1);
-            y[i].qs[32 * j + 18] = vgetq_lane_s32(vi, 2);
-            y[i].qs[32 * j + 19] = vgetq_lane_s32(vi, 3);
-            v = vmulq_n_f32(srcv[2][2 * j + 1], id[2]);
-            vi = vcvtnq_s32_f32(v);
-            y[i].qs[32 * j + 20] = vgetq_lane_s32(vi, 0);
-            y[i].qs[32 * j + 21] = vgetq_lane_s32(vi, 1);
-            y[i].qs[32 * j + 22] = vgetq_lane_s32(vi, 2);
-            y[i].qs[32 * j + 23] = vgetq_lane_s32(vi, 3);
-
-            v = vmulq_n_f32(srcv[3][2 * j], id[3]);
-            vi = vcvtnq_s32_f32(v);
-            y[i].qs[32 * j + 24] = vgetq_lane_s32(vi, 0);
-            y[i].qs[32 * j + 25] = vgetq_lane_s32(vi, 1);
-            y[i].qs[32 * j + 26] = vgetq_lane_s32(vi, 2);
-            y[i].qs[32 * j + 27] = vgetq_lane_s32(vi, 3);
-            v = vmulq_n_f32(srcv[3][2 * j + 1], id[3]);
-            vi = vcvtnq_s32_f32(v);
-            y[i].qs[32 * j + 28] = vgetq_lane_s32(vi, 0);
-            y[i].qs[32 * j + 29] = vgetq_lane_s32(vi, 1);
-            y[i].qs[32 * j + 30] = vgetq_lane_s32(vi, 2);
-            y[i].qs[32 * j + 31] = vgetq_lane_s32(vi, 3);
-        }
-    }
-#elif defined(__AVX2__) || defined(__AVX__)
-    float id[4];
-    __m256 srcv[4][4];
-    __m256 idvec[4];
-
-    for (int i = 0; i < nb; i++) {
-        for (int row_iter = 0; row_iter < 4; row_iter++) {
-            // Load elements into 4 AVX vectors
-            __m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 32 );
-            __m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 32 + 8 );
-            __m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 32 + 16 );
-            __m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 32 + 24 );
-
-            // Compute max(abs(e)) for the block
-            const __m256 signBit = _mm256_set1_ps( -0.0f );
-            __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
-            maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
-            maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
-            maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
-
-            __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
-            max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
-            max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
-            const float maxScalar = _mm_cvtss_f32( max4 );
-
-            // Divided by 127.f to mirror results in quantize_row_q8_0
-            const float d = maxScalar  / 127.f;
-            id[row_iter] = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; //d ? 1.0f / d : 0.0f;
-
-            // Store the scale for the individual block
-            y[i].d[row_iter] = GGML_FP32_TO_FP16(d);
-
-            // Store the values in blocks of eight values - Aim is to use these later for block interleaving
-            srcv[row_iter][0] = v0;
-            srcv[row_iter][1] = v1;
-            srcv[row_iter][2] = v2;
-            srcv[row_iter][3] = v3;
-            idvec[row_iter] = _mm256_set1_ps(id[row_iter]);
-        }
-
-        // The loop iterates four times - The aim is to get 4 corresponding chunks of eight bytes from the original weight blocks that are interleaved
-        for (int j = 0; j < 4; j++) {
-            // Apply the multiplier
-            __m256 v0 = _mm256_mul_ps(srcv[0][j], idvec[0]);
-            __m256 v1 = _mm256_mul_ps(srcv[1][j], idvec[1]);
-            __m256 v2 = _mm256_mul_ps(srcv[2][j], idvec[2]);
-            __m256 v3 = _mm256_mul_ps(srcv[3][j], idvec[3]);
-
-            // Round to nearest integer
-            v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
-            v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
-            v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
-            v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
-
-            // Convert floats to integers
-            __m256i i0 = _mm256_cvtps_epi32( v0 );
-            __m256i i1 = _mm256_cvtps_epi32( v1 );
-            __m256i i2 = _mm256_cvtps_epi32( v2 );
-            __m256i i3 = _mm256_cvtps_epi32( v3 );
+    const int end = QK4_0 * 4 / blck_size_interleave;
+    const uint64_t xor_mask = 0x8888888888888888ULL;
 
-#if defined(__AVX2__)
-            // Convert int32 to int16
-            i0 = _mm256_packs_epi32( i0, i1 );
-            i2 = _mm256_packs_epi32( i2, i3 );
-            // Convert int16 to int8
-            i0 = _mm256_packs_epi16( i0, i2 );
+    for (int i = 0; i < end; ++i) {
+        int src_id = i % 8;
+        int src_offset = (i / 8) * blck_size_interleave;
+        int dst_offset = i * blck_size_interleave;
 
-            //  Permute and store the quantized weights in the required order after the pack instruction
-            const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
-            i0 = _mm256_permutevar8x32_epi32( i0, perm );
-
-            _mm256_storeu_si256((__m256i *)(y[i].qs + 32 * j), i0);
-#else
-            // Since we don't have in AVX some necessary functions,
-            // we split the registers in half and call AVX2 analogs from SSE
-            __m128i ni0 = _mm256_castsi256_si128( i0 );
-            __m128i ni1 = _mm256_extractf128_si256( i0, 1);
-            __m128i ni2 = _mm256_castsi256_si128( i1 );
-            __m128i ni3 = _mm256_extractf128_si256( i1, 1);
-            __m128i ni4 = _mm256_castsi256_si128( i2 );
-            __m128i ni5 = _mm256_extractf128_si256( i2, 1);
-            __m128i ni6 = _mm256_castsi256_si128( i3 );
-            __m128i ni7 = _mm256_extractf128_si256( i3, 1);
-
-            // Convert int32 to int16
-            ni0 = _mm_packs_epi32( ni0, ni1 );
-            ni2 = _mm_packs_epi32( ni2, ni3 );
-            ni4 = _mm_packs_epi32( ni4, ni5 );
-            ni6 = _mm_packs_epi32( ni6, ni7 );
-            // Convert int16 to int8
-            ni0 = _mm_packs_epi16( ni0, ni2 );
-            ni4 = _mm_packs_epi16( ni4, ni6 );
-            _mm_storeu_si128((__m128i *)(y[i].qs + 32 * j), ni0);
-            _mm_storeu_si128((__m128i *)(y[i].qs + 32 * j + 16), ni4);
-#endif
-        }
+        uint64_t elems;
+        memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
+        elems ^= xor_mask;
+        memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
     }
-#else
-    // scalar
-    const int blck_size_interleave = 8;
-    float srcv[4][QK8_0];
-    float id[4];
-
-    for (int i = 0; i < nb; i++) {
-        for (int row_iter = 0; row_iter < 4; row_iter++) {
-            float amax = 0.0f; // absolute max
-
-            for (int j = 0; j < QK8_0; j++) {
-                srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j];
-                amax = MAX(amax, fabsf(srcv[row_iter][j]));
-            }
-
-            const float d = amax / ((1 << 7) - 1);
-            id[row_iter] = d ? 1.0f / d : 0.0f;
 
-            y[i].d[row_iter] = GGML_FP32_TO_FP16(d);
-        }
-
-        for (int j = 0; j < QK8_0 * 4; j++) {
-            int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
-            int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
-            src_offset += (j % blck_size_interleave);
-
-            float x0 = srcv[src_id][src_offset] * id[src_id];
-            y[i].qs[j] = roundf(x0);
-        }
-    }
-#endif
-}
-
-void quantize_mat_q8_0(const float * restrict x, void * restrict vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) {
-    assert(nrow == 4);
-    UNUSED(nrow);
-    if (blck_size_interleave == 4) {
-        quantize_q8_0_4x4(x, vy, n_per_row);
-    } else if (blck_size_interleave == 8) {
-        quantize_q8_0_4x8(x, vy, n_per_row);
-    } else {
-        assert(false);
-    }
+    return out;
 }
 
 static size_t quantize_q4_0_nr_bl(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, int nrows_interleaved, int blck_size_interleave) {
@@ -596,11 +126,11 @@ static size_t quantize_q4_0_nr_bl(const float * restrict src, void * restrict ds
             }
 
             if (nrows_interleaved == 8) {
-                *(block_q4_0x8 *) out_ptr = make_block_q4_0x8(dst_tmp, blck_size_interleave, 0x88);
+                *(block_q4_0x8 *) out_ptr = make_block_q4_0x8(dst_tmp, blck_size_interleave);
                 out_ptr = (block_q4_0x8 *) out_ptr + 1;
             }
             else if (nrows_interleaved == 4) {
-                *(block_q4_0x4 *) out_ptr = make_block_q4_0x4(dst_tmp, blck_size_interleave, 0x88);
+                *(block_q4_0x4 *) out_ptr = make_block_q4_0x4(dst_tmp, blck_size_interleave);
                 out_ptr = (block_q4_0x4 *) out_ptr + 1;
             }
         }
@@ -623,2613 +153,3 @@ size_t quantize_q4_0_8x8(const float * restrict src, void * restrict dst, int64_
     UNUSED(quant_weights);
     return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 8, 8);
 }
-
-void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
-    const int qk = QK8_0;
-    const int nb = n / qk;
-    const int ncols_interleaved = 4;
-    const int blocklen = 4;
-
-    assert (n % qk == 0);
-    assert (nc % ncols_interleaved == 0);
-
-    UNUSED(s);
-    UNUSED(bs);
-    UNUSED(vx);
-    UNUSED(vy);
-    UNUSED(nr);
-    UNUSED(nc);
-    UNUSED(nb);
-    UNUSED(ncols_interleaved);
-    UNUSED(blocklen);
-
-#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
-    if (ggml_cpu_has_neon()) {
-        const void * b_ptr = vx;
-        const void * a_ptr = vy;
-        float * res_ptr = s;
-
-        __asm__ __volatile__(
-            "movi v31.16b, #0x4\n"
-            "movi v30.16b, #0xf0\n"
-            "add %x[b_ptr], %x[b_ptr], #0x8\n"
-            "1:"  // Column loop
-            "add x22, %x[a_ptr], #0x2\n"
-            "movi v29.16b, #0x0\n"
-            "mov x21, %x[nb]\n"
-            "2:"  // Block loop
-            "ldr q28, [%x[b_ptr], #0x0]\n"
-            "ldr q27, [x22, #0x0]\n"
-            "movi v26.4s, #0x0\n"
-            "sub x20, x22, #0x2\n"
-            "ldr q25, [x22, #0x10]\n"
-            "ldr q24, [%x[b_ptr], #0x10]\n"
-            "sub x21, x21, #0x1\n"
-            "add x22, x22, #0x22\n"
-            "ldr q23, [%x[b_ptr], #0x20]\n"
-            "ldr q22, [%x[b_ptr], #0x30]\n"
-            "ld1r { v21.8h }, [x20]\n"
-            "ldr q20, [%x[b_ptr], #-0x8]\n"
-            "sshl v16.16b, v28.16b, v31.16b\n"
-            "and v28.16b, v28.16b, v30.16b\n"
-            "sshl v19.16b, v24.16b, v31.16b\n"
-            "and v24.16b, v24.16b, v30.16b\n"
-            "add %x[b_ptr], %x[b_ptr], #0x48\n"
-            "sshl v18.16b, v23.16b, v31.16b\n"
-            "and v23.16b, v23.16b, v30.16b\n"
-            ".inst 0x4f9be21a  // sdot v26.4s, v16.16b, v27.4b[0]\n"
-            "sshl v17.16b, v22.16b, v31.16b\n"
-            "and v22.16b, v22.16b, v30.16b\n"
-            "fcvtl v21.4s, v21.4h\n"
-            "fcvtl v16.4s, v20.4h\n"
-            ".inst 0x4f99e39a  // sdot v26.4s, v28.16b, v25.4b[0]\n"
-            "fmul v16.4s, v16.4s, v21.4s\n"
-            ".inst 0x4fbbe27a  // sdot v26.4s, v19.16b, v27.4b[1]\n"
-            ".inst 0x4fb9e31a  // sdot v26.4s, v24.16b, v25.4b[1]\n"
-            ".inst 0x4f9bea5a  // sdot v26.4s, v18.16b, v27.4b[2]\n"
-            ".inst 0x4f99eafa  // sdot v26.4s, v23.16b, v25.4b[2]\n"
-            ".inst 0x4fbbea3a  // sdot v26.4s, v17.16b, v27.4b[3]\n"
-            ".inst 0x4fb9eada  // sdot v26.4s, v22.16b, v25.4b[3]\n"
-            "scvtf v26.4s, v26.4s, #0x4\n"
-            "fmla v29.4s, v26.4s, v16.4s\n"
-            "cbnz x21, 2b\n"
-            "sub %x[nc], %x[nc], #0x4\n"
-            "str q29, [%x[res_ptr], #0x0]\n"
-            "add %x[res_ptr], %x[res_ptr], #0x10\n"
-            "cbnz %x[nc], 1b\n"
-            : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc)
-            : [a_ptr] "r" (a_ptr), [nb] "r" (nb)
-            : "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22"
-            );
-        return;
-    }
-#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
-    float sumf[4];
-    int sumi;
-
-    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
-    for (int x = 0; x < nc / ncols_interleaved; x++) {
-        const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
-
-        for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
-        for (int l = 0; l < nb; l++) {
-            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
-                for (int j = 0; j < ncols_interleaved; j++) {
-                    sumi = 0;
-                    for (int i = 0; i < blocklen; ++i) {
-                        const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
-                        const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
-                        sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
-                    }
-                    sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d);
-                }
-            }
-        }
-        for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
-    }
-}
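
The scalar fallback above relies on one non-obvious identity: (int8_t)(qs << 4) is the low nibble times 16 and (int8_t)(qs & 0xF0) is the high nibble times 16, so the combined product is always a multiple of 16 and the single arithmetic shift at the end is exact. A worked standalone example (the input values are chosen here for illustration):

#include <stdint.h>
#include <assert.h>

int main(void) {
    uint8_t qs = 0x3D;            // low nibble -3 (0xD), high nibble +3, two's complement
    int8_t  a0 = 7, a1 = -2;      // two Q8_0 activations

    int v0 = (int8_t)(qs << 4);   // low nibble placed in the high bits: -3 * 16 = -48
    int v1 = (int8_t)(qs & 0xF0); // high nibble kept in the high bits:   3 * 16 =  48
    int sumi = ((v0 * a0) + (v1 * a1)) >> 4; // one shift undoes both x16 factors

    assert(v0 == -48 && v1 == 48);
    assert(sumi == (-3) * 7 + 3 * (-2)); // -27, same as unpacking the nibbles first
    return 0;
}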
-
-void ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
-    const int qk = QK8_0;
-    const int nb = n / qk;
-    const int ncols_interleaved = 4;
-    const int blocklen = 8;
-
-    assert (n % qk == 0);
-    assert (nc % ncols_interleaved == 0);
-
-    UNUSED(s);
-    UNUSED(bs);
-    UNUSED(vx);
-    UNUSED(vy);
-    UNUSED(nr);
-    UNUSED(nc);
-    UNUSED(nb);
-    UNUSED(ncols_interleaved);
-    UNUSED(blocklen);
-
-#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
-    if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
-        const void * b_ptr = vx;
-        const void * a_ptr = vy;
-        float * res_ptr = s;
-
-        __asm__ __volatile__(
-            "movi v2.16b, #0x4\n"
-            "movi v1.16b, #0xf0\n"
-            "add %x[b_ptr], %x[b_ptr], #0x8\n"
-            "1:"  // Column loop
-            "add x23, %x[a_ptr], #0x2\n"
-            "movi v0.16b, #0x0\n"
-            "mov x22, %x[nb]\n"
-            "2:"  // Block loop
-            "ldr q31, [%x[b_ptr], #0x0]\n"
-            "ldr q30, [%x[b_ptr], #0x10]\n"
-            "mov x21, x23\n"
-            "movi v29.4s, #0x0\n"
-            "ldr q28, [%x[b_ptr], #0x20]\n"
-            "ldr q27, [%x[b_ptr], #0x30]\n"
-            "movi v26.4s, #0x0\n"
-            "sub x20, x23, #0x2\n"
-            "ld1r { v25.8h }, [x20]\n"
-            "ldr q24, [%x[b_ptr], #-0x8]\n"
-            "sub x22, x22, #0x1\n"
-            "add x23, x23, #0x22\n"
-            "ld1r { v23.2d }, [x21], #0x8\n"
-            "sshl v22.16b, v31.16b, v2.16b\n"
-            "sshl v16.16b, v30.16b, v2.16b\n"
-            "add %x[b_ptr], %x[b_ptr], #0x48\n"
-            "ld1r { v21.2d }, [x21], #0x8\n"
-            "sshl v20.16b, v28.16b, v2.16b\n"
-            "sshl v19.16b, v27.16b, v2.16b\n"
-            "ld1r { v18.2d }, [x21], #0x8\n"
-            "ld1r { v17.2d }, [x21], #0x8\n"
-            "and v31.16b, v31.16b, v1.16b\n"
-            "and v30.16b, v30.16b, v1.16b\n"
-            ".inst 0x4e9796dd  // sdot v29.4s, v22.16b, v23.16b\n"
-            ".inst 0x4e97961a  // sdot v26.4s, v16.16b, v23.16b\n"
-            "and v28.16b, v28.16b, v1.16b\n"
-            "and v27.16b, v27.16b, v1.16b\n"
-            "fcvtl v25.4s, v25.4h\n"
-            "fcvtl v16.4s, v24.4h\n"
-            ".inst 0x4e95969d  // sdot v29.4s, v20.16b, v21.16b\n"
-            ".inst 0x4e95967a  // sdot v26.4s, v19.16b, v21.16b\n"
-            "fmul v16.4s, v16.4s, v25.4s\n"
-            ".inst 0x4e9297fd  // sdot v29.4s, v31.16b, v18.16b\n"
-            ".inst 0x4e9297da  // sdot v26.4s, v30.16b, v18.16b\n"
-            ".inst 0x4e91979d  // sdot v29.4s, v28.16b, v17.16b\n"
-            ".inst 0x4e91977a  // sdot v26.4s, v27.16b, v17.16b\n"
-            "addp v29.4s, v29.4s, v26.4s\n"
-            "scvtf v29.4s, v29.4s, #0x4\n"
-            "fmla v0.4s, v29.4s, v16.4s\n"
-            "cbnz x22, 2b\n"
-            "sub %x[nc], %x[nc], #0x4\n"
-            "str q0, [%x[res_ptr], #0x0]\n"
-            "add %x[res_ptr], %x[res_ptr], #0x10\n"
-            "cbnz %x[nc], 1b\n"
-            : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc)
-            : [a_ptr] "r" (a_ptr), [nb] "r" (nb)
-            : "memory", "v0", "v1", "v2", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23"
-        );
-        return;
-    }
-#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
-    float sumf[4];
-    int sumi;
-
-    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
-    for (int x = 0; x < nc / ncols_interleaved; x++) {
-        const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
-
-        for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
-        for (int l = 0; l < nb; l++) {
-            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
-                for (int j = 0; j < ncols_interleaved; j++) {
-                    sumi = 0;
-                    for (int i = 0; i < blocklen; ++i) {
-                        const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
-                        const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
-                        sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
-                    }
-                    sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d);
-                }
-            }
-        }
-        for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
-    }
-}
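
In the NEON blocks above, the same divide-by-16 correction is folded into the conversion: "scvtf v.4s, v.4s, #0x4" treats the integer accumulator as fixed-point with 4 fractional bits. A C-level reading of that step (a sketch; on AArch64 the intrinsic equivalent would be vcvtq_n_f32_s32(acc, 4)):

#include <stdint.h>

static inline float acc_to_float(int32_t acc_times_16) {
    // the accumulator still carries the x16 factor from the nibble unpacking;
    // the fixed-point convert divides it out during the int -> float conversion
    return (float)acc_times_16 * (1.0f / 16.0f);
}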
-
-void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
-    const int qk = QK8_0;
-    const int nb = n / qk;
-    const int ncols_interleaved = 8;
-    const int blocklen = 8;
-
-    assert (n % qk == 0);
-    assert (nc % ncols_interleaved == 0);
-
-    UNUSED(s);
-    UNUSED(bs);
-    UNUSED(vx);
-    UNUSED(vy);
-    UNUSED(nr);
-    UNUSED(nc);
-    UNUSED(nb);
-    UNUSED(ncols_interleaved);
-    UNUSED(blocklen);
-
-#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
-#if defined(__ARM_FEATURE_SVE)
-    if (ggml_cpu_has_sve() && ggml_cpu_get_sve_cnt() == QK8_0) {
-        const void * b_ptr = vx;
-        const void * a_ptr = vy;
-        float * res_ptr = s;
-
-        __asm__ __volatile__(
-            "ptrue p0.b\n"
-            "add %x[b_ptr], %x[b_ptr], #0x10\n"
-            "1:"  // Column loop
-            "add x22, %x[a_ptr], #0x2\n"
-            "mov z31.b, #0x0\n"
-            "mov x21, %x[nb]\n"
-            "2:"  // Block loop
-            "ld1b { z30.b }, p0/Z, [%x[b_ptr]]\n"
-            "ld1b { z29.b }, p0/Z, [%x[b_ptr], #1, MUL VL]\n"
-            "mov z28.s, #0x0\n"
-            "mov z27.s, #0x0\n"
-            "ld1rd { z26.d }, p0/Z, [x22]\n"
-            "ld1b { z25.b }, p0/Z, [%x[b_ptr], #2, MUL VL]\n"
-            "sub x20, x22, #0x2\n"
-            "sub x21, x21, #0x1\n"
-            "ld1b { z24.b }, p0/Z, [%x[b_ptr], #3, MUL VL]\n"
-            "ld1rd { z23.d }, p0/Z, [x22, #8]\n"
-            "lsl z22.b, z30.b, #0x4\n"
-            "lsl z16.b, z29.b, #0x4\n"
-            "and z30.b, z30.b, #0xf0\n"
-            "and z29.b, z29.b, #0xf0\n"
-            "ld1rd { z21.d }, p0/Z, [x22, #16]\n"
-            "ld1rd { z20.d }, p0/Z, [x22, #24]\n"
-            "lsl z19.b, z25.b, #0x4\n"
-            "and z25.b, z25.b, #0xf0\n"
-            "ld1rh { z17.h }, p0/Z, [x20]\n"
-            "ld1h { z18.s }, p0/Z, [%x[b_ptr], #-1, MUL VL]\n"
-            "sdot z28.s, z22.b, z26.b\n"
-            "sdot z27.s, z16.b, z26.b\n"
-            "lsl z16.b, z24.b, #0x4\n"
-            "add x22, x22, #0x22\n"
-            "and z24.b, z24.b, #0xf0\n"
-            "add %x[b_ptr], %x[b_ptr], #0x90\n"
-            "fcvt z17.s, p0/m, z17.h\n"
-            "fcvt z18.s, p0/m, z18.h\n"
-            "sdot z28.s, z19.b, z23.b\n"
-            "sdot z27.s, z16.b, z23.b\n"
-            "fmul z18.s, z18.s, z17.s\n"
-            "sdot z28.s, z30.b, z21.b\n"
-            "sdot z27.s, z29.b, z21.b\n"
-            "sdot z28.s, z25.b, z20.b\n"
-            "sdot z27.s, z24.b, z20.b\n"
-            "uzp1 z17.s, z28.s, z27.s\n"
-            "uzp2 z16.s, z28.s, z27.s\n"
-            "add z17.s, z17.s, z16.s\n"
-            "asr z17.s, z17.s, #0x4\n"
-            "scvtf z17.s, p0/m, z17.s\n"
-            "fmla z31.s, p0/M, z17.s, z18.s\n"
-            "cbnz x21, 2b\n"
-            "sub %x[nc], %x[nc], #0x8\n"
-            "st1w { z31.s }, p0, [%x[res_ptr]]\n"
-            "add %x[res_ptr], %x[res_ptr], #0x20\n"
-            "cbnz %x[nc], 1b\n"
-            : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc)
-            : [a_ptr] "r" (a_ptr), [nb] "r" (nb)
-            : "memory", "p0", "x20", "x21", "x22", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"
-        );
-        return;
-    }
-#endif // #if defined(__ARM_FEATURE_SVE)
-#elif defined(__AVX2__)
-    // Lookup table to convert signed nibbles to signed bytes
-    __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));
-    signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
-    // Permute masks used for easier vector processing at later stages
-    __m128i changemask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0);
-    __m256i finalpermutemask = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0);
-
-    // Mask to isolate the low nibble of each byte
-    const __m256i m4b = _mm256_set1_epi8(0x0F);
-
-    int64_t b_nb = n / QK4_0;
-
-    const block_q4_0x8 * b_ptr_start = (const block_q4_0x8 *)vx;
-    const block_q8_0 * a_ptr_start = (const block_q8_0 *)vy;
-
-    // Process the LHS rows (each a run of Q8_0 blocks) one by one
-    for (int64_t y = 0; y < nr; y++) {
-
-        // Pointers to LHS blocks of block_q8_0 format
-        const block_q8_0 * a_ptr = a_ptr_start + (y * nb);
-
-        // Take eight interleaved columns (one block_q4_0x8 per Q8_0 block) at each pass of the loop and perform the dot product
-        for (int64_t x = 0; x < nc / 8; x++) {
-
-            // Pointers to RHS blocks
-            const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb);
-
-            // Master FP accumulator
-            __m256 acc_row = _mm256_setzero_ps();
-
-            for (int64_t b = 0; b < nb; b++) {
-                // Load 8 blocks of Q4_0 interleaved as 8 bytes (B0 - B7)
-                const __m256i rhs_raw_vec_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs));
-                const __m256i rhs_raw_vec_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs) + 1);
-                const __m256i rhs_raw_vec_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs) + 2);
-                const __m256i rhs_raw_vec_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs) + 3);
-
-                // 4-bit -> 8-bit - Sign is maintained
-                const __m256i rhs_vec_0123_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_vec_0123_0, m4b)); // B0(0-7) B1(0-7) B2(0-7) B3(0-7)
-                const __m256i rhs_vec_4567_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_vec_4567_0, m4b)); // B4(0-7) B5(0-7) B6(0-7) B7(0-7)
-                const __m256i rhs_vec_0123_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_vec_0123_1, m4b)); // B0(8-15) B1(8-15) B2(8-15) B3(8-15)
-                const __m256i rhs_vec_4567_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_vec_4567_1, m4b)); // B4(8-15) B5(8-15) B6(8-15) B7(8-15)
-
-                const __m256i rhs_vec_0123_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_0, 4), m4b)); // B0(16-23) B1(16-23) B2(16-23) B3(16-23)
-                const __m256i rhs_vec_4567_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_0, 4), m4b)); // B4(16-23) B5(16-23) B6(16-23) B7(16-23)
-                const __m256i rhs_vec_0123_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 4), m4b)); // B0(24-31) B1(24-31) B2(24-31) B3(24-31)
-                const __m256i rhs_vec_4567_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 4), m4b)); // B4(24-31) B5(24-31) B6(24-31) B7(24-31)
-
-                // Load the scale values for the 8 blocks interleaved in block_q4_0x8
-                const __m256 col_scale_f32 = GGML_F32Cx8_REARRANGE_LOAD(b_ptr[b].d, changemask);
-
-                // Load the scale from block_q8_0 and convert it to FP32
-                const __m256 row_scale_f32 = _mm256_set1_ps(GGML_FP16_TO_FP32(a_ptr[b].d));
-
-                // Load the block values from block_q8_0 in batches of 16 bytes and replicate them across the 256-bit vector
-                __m256i lhs_vec_0 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)a_ptr[b].qs));
-                __m256i lhs_vec_1 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 16)));
-
-                lhs_vec_0 = _mm256_permute2f128_si256(lhs_vec_0, lhs_vec_0, 0); // A0(0-15) A0(0-15)
-                lhs_vec_1 = _mm256_permute2f128_si256(lhs_vec_1, lhs_vec_1, 0); // A0(16-31) A0(16-31)
-
-                __m256i iacc = _mm256_setzero_si256();
-
-                // Dot product done within 32 bit lanes and accumulated in the same vector
-                // B0(0-3) B4(0-3) B1(0-3) B5(0-3) B2(0-3) B6(0-3) B3(0-3) B7(0-3) with A0(0-3)
-                // B0(4-7) B4(4-7) B1(4-7) B5(4-7) B2(4-7) B6(4-7) B3(4-7) B7(4-7) with A0(4-7)
-                // ...........................................................................
-                // B0(28-31) B4(28-31) B1(28-31) B5(28-31) B2(28-31) B6(28-31) B3(28-31) B7(28-31) with A0(28-31)
-
-                iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_0 ,_mm256_shuffle_epi32(rhs_vec_4567_0, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 0)));
-                iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_0, 177) ,rhs_vec_4567_0, 170), _mm256_shuffle_epi32(lhs_vec_0, 85)));
-
-                iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_1 ,_mm256_shuffle_epi32(rhs_vec_4567_1, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 170)));
-                iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_1, 177) ,rhs_vec_4567_1, 170), _mm256_shuffle_epi32(lhs_vec_0, 255)));
-
-                iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_2 ,_mm256_shuffle_epi32(rhs_vec_4567_2, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 0)));
-                iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_2, 177) ,rhs_vec_4567_2, 170), _mm256_shuffle_epi32(lhs_vec_1, 85)));
-
-                iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_3 ,_mm256_shuffle_epi32(rhs_vec_4567_3, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 170)));
-                iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_3, 177) ,rhs_vec_4567_3, 170), _mm256_shuffle_epi32(lhs_vec_1, 255)));
-
-                // Accumulated values multiplied by the appropriate scales
-                acc_row = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc), _mm256_mul_ps(col_scale_f32, row_scale_f32), acc_row);
-            }
-
-            // Permute the accumulated output values so they are stored in the correct column order
-            acc_row = _mm256_permutevar8x32_ps(acc_row, finalpermutemask);
-            _mm256_storeu_ps(s + (y * nr + x * 8), acc_row);
-        }
-    }
-    return;
-#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
-    {
-        float sumf[8];
-        int sumi;
-
-        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
-        for (int x = 0; x < nc / ncols_interleaved; x++) {
-            const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
-
-            for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
-            for (int l = 0; l < nb; l++) {
-                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
-                    for (int j = 0; j < ncols_interleaved; j++) {
-                        sumi = 0;
-                        for (int i = 0; i < blocklen; ++i) {
-                            const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
-                            const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
-                            sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
-                        }
-                        sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d);
-                    }
-                }
-            }
-            for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
-        }
-    }
-}
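
A scalar model of the signextendlut/_mm256_shuffle_epi8 step in the AVX2 path above (the helper here is illustrative): the shuffle uses each masked nibble as an index into a 16-entry table whose entry i is the sign-extended 4-bit value, so a single shuffle sign-extends 32 nibbles with no shifts.

#include <stdint.h>

static int8_t sign_extend_nibble(uint8_t nib /* 0..15 */) {
    // same table the vectorized code builds with _mm_set_epi8:
    // 0..7 map to themselves, 8..15 map to -8..-1
    static const int8_t lut[16] = {
        0, 1, 2, 3, 4, 5, 6, 7, -8, -7, -6, -5, -4, -3, -2, -1
    };
    return lut[nib & 0x0F];
}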
-
-void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
-    const int qk = QK8_0;
-    const int nb = n / qk;
-    const int ncols_interleaved = 4;
-    const int blocklen = 4;
-
-    assert (n % qk == 0);
-    assert (nr % 4 == 0);
-    assert (nc % ncols_interleaved == 0);
-
-    UNUSED(s);
-    UNUSED(bs);
-    UNUSED(vx);
-    UNUSED(vy);
-    UNUSED(nr);
-    UNUSED(nc);
-    UNUSED(nb);
-    UNUSED(ncols_interleaved);
-    UNUSED(blocklen);
-
-#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
-    if (ggml_cpu_has_neon()) {
-        const void * b_ptr = vx;
-        const void * a_ptr = vy;
-        float * res_ptr = s;
-        size_t res_stride = bs * sizeof(float);
-
-        __asm__ __volatile__(
-            "mov x10, %x[nr]\n"
-            "mov x9, #0x88\n"
-            "cmp x10, #0x10\n"
-            "mul x9, %x[nb], x9\n"
-            "blt 4f\n"
-            "1:"  // Row loop
-            "add x28, %x[b_ptr], #0x8\n"
-            "mov x27, %x[nc]\n"
-            "add x26, %x[res_ptr], %x[res_stride], LSL #4\n"
-            "2:"  // Column loop
-            "add x25, %x[a_ptr], #0x8\n"
-            "movi v15.16b, #0x0\n"
-            "movi v19.16b, #0x0\n"
-            "mov x24, %x[nb]\n"
-            "add x23, x25, x9\n"
-            "movi v18.16b, #0x0\n"
-            "movi v14.16b, #0x0\n"
-            "add x22, x23, x9\n"
-            "movi v11.16b, #0x0\n"
-            "movi v13.16b, #0x0\n"
-            "add x21, x22, x9\n"
-            "movi v23.16b, #0x0\n"
-            "movi v16.16b, #0x0\n"
-            "movi v25.16b, #0x0\n"
-            "movi v7.16b, #0x0\n"
-            "movi v0.16b, #0x0\n"
-            "movi v4.16b, #0x0\n"
-            "movi v5.16b, #0x0\n"
-            "movi v21.16b, #0x0\n"
-            "movi v8.16b, #0x0\n"
-            "movi v1.16b, #0x0\n"
-            "3:"  // Block loop
-            "ldr q3, [x28, #0x0]\n"
-            "ldr q31, [x25, #0x0]\n"
-            "movi v28.16b, #0x4\n"
-            "movi v10.4s, #0x0\n"
-            "ldr q22, [x28, #0x10]\n"
-            "ldr q6, [x25, #0x10]\n"
-            "movi v29.4s, #0x0\n"
-            "movi v9.4s, #0x0\n"
-            "ldr q27, [x28, #0x20]\n"
-            "ldr q30, [x28, #0x30]\n"
-            "movi v20.4s, #0x0\n"
-            "movi v24.16b, #0xf0\n"
-            "ldr d2, [x25, #-0x8]\n"
-            "ldr d26, [x23, #-0x8]\n"
-            "sshl v12.16b, v3.16b, v28.16b\n"
-            "sub x20, x28, #0x8\n"
-            "ldr d17, [x20, #0x0]\n"
-            "and v3.16b, v3.16b, v24.16b\n"
-            "subs x24, x24, #0x1\n"
-            "add x28, x28, #0x48\n"
-            ".inst 0x4f9fe18a  // sdot v10.4s, v12.16b, v31.4b[0]\n"
-            ".inst 0x4fbfe19d  // sdot v29.4s, v12.16b, v31.4b[1]\n"
-            ".inst 0x4f9fe989  // sdot v9.4s, v12.16b, v31.4b[2]\n"
-            ".inst 0x4fbfe994  // sdot v20.4s, v12.16b, v31.4b[3]\n"
-            "sshl v31.16b, v22.16b, v28.16b\n"
-            "and v22.16b, v22.16b, v24.16b\n"
-            "fcvtl v17.4s, v17.4h\n"
-            "fcvtl v2.4s, v2.4h\n"
-            "fcvtl v26.4s, v26.4h\n"
-            ".inst 0x4f86e3ea  // sdot v10.4s, v31.16b, v6.4b[0]\n"
-            ".inst 0x4fa6e3fd  // sdot v29.4s, v31.16b, v6.4b[1]\n"
-            ".inst 0x4f86ebe9  // sdot v9.4s, v31.16b, v6.4b[2]\n"
-            ".inst 0x4fa6ebf4  // sdot v20.4s, v31.16b, v6.4b[3]\n"
-            "sshl v6.16b, v27.16b, v28.16b\n"
-            "sshl v28.16b, v30.16b, v28.16b\n"
-            "and v27.16b, v27.16b, v24.16b\n"
-            "and v30.16b, v30.16b, v24.16b\n"
-            "ldr q24, [x25, #0x20]\n"
-            ".inst 0x4f98e0ca  // sdot v10.4s, v6.16b, v24.4b[0]\n"
-            ".inst 0x4fb8e0dd  // sdot v29.4s, v6.16b, v24.4b[1]\n"
-            ".inst 0x4f98e8c9  // sdot v9.4s, v6.16b, v24.4b[2]\n"
-            ".inst 0x4fb8e8d4  // sdot v20.4s, v6.16b, v24.4b[3]\n"
-            "ldr q24, [x25, #0x30]\n"
-            ".inst 0x4f98e38a  // sdot v10.4s, v28.16b, v24.4b[0]\n"
-            ".inst 0x4fb8e39d  // sdot v29.4s, v28.16b, v24.4b[1]\n"
-            ".inst 0x4f98eb89  // sdot v9.4s, v28.16b, v24.4b[2]\n"
-            ".inst 0x4fb8eb94  // sdot v20.4s, v28.16b, v24.4b[3]\n"
-            "ldr q24, [x25, #0x40]\n"
-            ".inst 0x4f98e06a  // sdot v10.4s, v3.16b, v24.4b[0]\n"
-            ".inst 0x4fb8e07d  // sdot v29.4s, v3.16b, v24.4b[1]\n"
-            ".inst 0x4f98e869  // sdot v9.4s, v3.16b, v24.4b[2]\n"
-            ".inst 0x4fb8e874  // sdot v20.4s, v3.16b, v24.4b[3]\n"
-            "ldr q24, [x25, #0x50]\n"
-            ".inst 0x4f98e2ca  // sdot v10.4s, v22.16b, v24.4b[0]\n"
-            ".inst 0x4fb8e2dd  // sdot v29.4s, v22.16b, v24.4b[1]\n"
-            ".inst 0x4f98eac9  // sdot v9.4s, v22.16b, v24.4b[2]\n"
-            ".inst 0x4fb8ead4  // sdot v20.4s, v22.16b, v24.4b[3]\n"
-            "ldr q24, [x25, #0x60]\n"
-            ".inst 0x4f98e36a  // sdot v10.4s, v27.16b, v24.4b[0]\n"
-            ".inst 0x4fb8e37d  // sdot v29.4s, v27.16b, v24.4b[1]\n"
-            ".inst 0x4f98eb69  // sdot v9.4s, v27.16b, v24.4b[2]\n"
-            ".inst 0x4fb8eb74  // sdot v20.4s, v27.16b, v24.4b[3]\n"
-            "ldr q24, [x25, #0x70]\n"
-            "add x25, x25, #0x88\n"
-            ".inst 0x4f98e3ca  // sdot v10.4s, v30.16b, v24.4b[0]\n"
-            ".inst 0x4fb8e3dd  // sdot v29.4s, v30.16b, v24.4b[1]\n"
-            ".inst 0x4f98ebc9  // sdot v9.4s, v30.16b, v24.4b[2]\n"
-            ".inst 0x4fb8ebd4  // sdot v20.4s, v30.16b, v24.4b[3]\n"
-            "fmul v24.4s, v17.4s, v2.s[0]\n"
-            "scvtf v10.4s, v10.4s, #0x4\n"
-            "scvtf v29.4s, v29.4s, #0x4\n"
-            "scvtf v9.4s, v9.4s, #0x4\n"
-            "scvtf v20.4s, v20.4s, #0x4\n"
-            "fmla v15.4s, v10.4s, v24.4s\n"
-            "ldr q24, [x23, #0x0]\n"
-            "fmul v10.4s, v17.4s, v2.s[1]\n"
-            "fmla v19.4s, v29.4s, v10.4s\n"
-            "ldr q10, [x23, #0x10]\n"
-            "fmul v29.4s, v17.4s, v2.s[2]\n"
-            "fmul v2.4s, v17.4s, v2.s[3]\n"
-            "fmla v18.4s, v9.4s, v29.4s\n"
-            "movi v9.4s, #0x0\n"
-            "movi v29.4s, #0x0\n"
-            ".inst 0x4f98e189  // sdot v9.4s, v12.16b, v24.4b[0]\n"
-            ".inst 0x4fb8e19d  // sdot v29.4s, v12.16b, v24.4b[1]\n"
-            "fmla v14.4s, v20.4s, v2.4s\n"
-            "movi v20.4s, #0x0\n"
-            "movi v2.4s, #0x0\n"
-            ".inst 0x4f98e994  // sdot v20.4s, v12.16b, v24.4b[2]\n"
-            ".inst 0x4fb8e982  // sdot v2.4s, v12.16b, v24.4b[3]\n"
-            "ldr q24, [x23, #0x20]\n"
-            ".inst 0x4f8ae3e9  // sdot v9.4s, v31.16b, v10.4b[0]\n"
-            ".inst 0x4faae3fd  // sdot v29.4s, v31.16b, v10.4b[1]\n"
-            ".inst 0x4f8aebf4  // sdot v20.4s, v31.16b, v10.4b[2]\n"
-            ".inst 0x4faaebe2  // sdot v2.4s, v31.16b, v10.4b[3]\n"
-            "ldr q10, [x23, #0x30]\n"
-            ".inst 0x4f98e0c9  // sdot v9.4s, v6.16b, v24.4b[0]\n"
-            ".inst 0x4fb8e0dd  // sdot v29.4s, v6.16b, v24.4b[1]\n"
-            ".inst 0x4f98e8d4  // sdot v20.4s, v6.16b, v24.4b[2]\n"
-            ".inst 0x4fb8e8c2  // sdot v2.4s, v6.16b, v24.4b[3]\n"
-            "ldr q24, [x23, #0x40]\n"
-            ".inst 0x4f8ae389  // sdot v9.4s, v28.16b, v10.4b[0]\n"
-            ".inst 0x4faae39d  // sdot v29.4s, v28.16b, v10.4b[1]\n"
-            ".inst 0x4f8aeb94  // sdot v20.4s, v28.16b, v10.4b[2]\n"
-            ".inst 0x4faaeb82  // sdot v2.4s, v28.16b, v10.4b[3]\n"
-            "ldr q10, [x23, #0x50]\n"
-            ".inst 0x4f98e069  // sdot v9.4s, v3.16b, v24.4b[0]\n"
-            ".inst 0x4fb8e07d  // sdot v29.4s, v3.16b, v24.4b[1]\n"
-            ".inst 0x4f98e874  // sdot v20.4s, v3.16b, v24.4b[2]\n"
-            ".inst 0x4fb8e862  // sdot v2.4s, v3.16b, v24.4b[3]\n"
-            "ldr q24, [x23, #0x60]\n"
-            ".inst 0x4f8ae2c9  // sdot v9.4s, v22.16b, v10.4b[0]\n"
-            ".inst 0x4faae2dd  // sdot v29.4s, v22.16b, v10.4b[1]\n"
-            ".inst 0x4f8aead4  // sdot v20.4s, v22.16b, v10.4b[2]\n"
-            ".inst 0x4faaeac2  // sdot v2.4s, v22.16b, v10.4b[3]\n"
-            "ldr q10, [x23, #0x70]\n"
-            "add x23, x23, #0x88\n"
-            ".inst 0x4f98e369  // sdot v9.4s, v27.16b, v24.4b[0]\n"
-            ".inst 0x4fb8e37d  // sdot v29.4s, v27.16b, v24.4b[1]\n"
-            ".inst 0x4f98eb74  // sdot v20.4s, v27.16b, v24.4b[2]\n"
-            ".inst 0x4fb8eb62  // sdot v2.4s, v27.16b, v24.4b[3]\n"
-            "ldr q24, [x22, #0x0]\n"
-            ".inst 0x4f8ae3c9  // sdot v9.4s, v30.16b, v10.4b[0]\n"
-            ".inst 0x4faae3dd  // sdot v29.4s, v30.16b, v10.4b[1]\n"
-            ".inst 0x4f8aebd4  // sdot v20.4s, v30.16b, v10.4b[2]\n"
-            ".inst 0x4faaebc2  // sdot v2.4s, v30.16b, v10.4b[3]\n"
-            "fmul v10.4s, v17.4s, v26.s[0]\n"
-            "scvtf v9.4s, v9.4s, #0x4\n"
-            "scvtf v29.4s, v29.4s, #0x4\n"
-            "scvtf v20.4s, v20.4s, #0x4\n"
-            "scvtf v2.4s, v2.4s, #0x4\n"
-            "fmla v11.4s, v9.4s, v10.4s\n"
-            "ldr q9, [x22, #0x10]\n"
-            "fmul v10.4s, v17.4s, v26.s[1]\n"
-            "fmla v13.4s, v29.4s, v10.4s\n"
-            "ldr d29, [x22, #-0x8]\n"
-            "fmul v10.4s, v17.4s, v26.s[2]\n"
-            "fmul v26.4s, v17.4s, v26.s[3]\n"
-            "fcvtl v29.4s, v29.4h\n"
-            "fmla v23.4s, v20.4s, v10.4s\n"
-            "movi v20.4s, #0x0\n"
-            "movi v10.4s, #0x0\n"
-            "fmla v16.4s, v2.4s, v26.4s\n"
-            "movi v26.4s, #0x0\n"
-            "movi v2.4s, #0x0\n"
-            ".inst 0x4f98e194  // sdot v20.4s, v12.16b, v24.4b[0]\n"
-            ".inst 0x4fb8e18a  // sdot v10.4s, v12.16b, v24.4b[1]\n"
-            ".inst 0x4f98e99a  // sdot v26.4s, v12.16b, v24.4b[2]\n"
-            ".inst 0x4fb8e982  // sdot v2.4s, v12.16b, v24.4b[3]\n"
-            "ldr q24, [x22, #0x20]\n"
-            ".inst 0x4f89e3f4  // sdot v20.4s, v31.16b, v9.4b[0]\n"
-            ".inst 0x4fa9e3ea  // sdot v10.4s, v31.16b, v9.4b[1]\n"
-            ".inst 0x4f89ebfa  // sdot v26.4s, v31.16b, v9.4b[2]\n"
-            ".inst 0x4fa9ebe2  // sdot v2.4s, v31.16b, v9.4b[3]\n"
-            "ldr q9, [x22, #0x30]\n"
-            ".inst 0x4f98e0d4  // sdot v20.4s, v6.16b, v24.4b[0]\n"
-            ".inst 0x4fb8e0ca  // sdot v10.4s, v6.16b, v24.4b[1]\n"
-            ".inst 0x4f98e8da  // sdot v26.4s, v6.16b, v24.4b[2]\n"
-            ".inst 0x4fb8e8c2  // sdot v2.4s, v6.16b, v24.4b[3]\n"
-            "ldr q24, [x22, #0x40]\n"
-            ".inst 0x4f89e394  // sdot v20.4s, v28.16b, v9.4b[0]\n"
-            ".inst 0x4fa9e38a  // sdot v10.4s, v28.16b, v9.4b[1]\n"
-            ".inst 0x4f89eb9a  // sdot v26.4s, v28.16b, v9.4b[2]\n"
-            ".inst 0x4fa9eb82  // sdot v2.4s, v28.16b, v9.4b[3]\n"
-            "ldr q9, [x22, #0x50]\n"
-            ".inst 0x4f98e074  // sdot v20.4s, v3.16b, v24.4b[0]\n"
-            ".inst 0x4fb8e06a  // sdot v10.4s, v3.16b, v24.4b[1]\n"
-            ".inst 0x4f98e87a  // sdot v26.4s, v3.16b, v24.4b[2]\n"
-            ".inst 0x4fb8e862  // sdot v2.4s, v3.16b, v24.4b[3]\n"
-            "ldr q24, [x22, #0x60]\n"
-            ".inst 0x4f89e2d4  // sdot v20.4s, v22.16b, v9.4b[0]\n"
-            ".inst 0x4fa9e2ca  // sdot v10.4s, v22.16b, v9.4b[1]\n"
-            ".inst 0x4f89eada  // sdot v26.4s, v22.16b, v9.4b[2]\n"
-            ".inst 0x4fa9eac2  // sdot v2.4s, v22.16b, v9.4b[3]\n"
-            "ldr q9, [x22, #0x70]\n"
-            "add x22, x22, #0x88\n"
-            ".inst 0x4f98e374  // sdot v20.4s, v27.16b, v24.4b[0]\n"
-            ".inst 0x4fb8e36a  // sdot v10.4s, v27.16b, v24.4b[1]\n"
-            ".inst 0x4f98eb7a  // sdot v26.4s, v27.16b, v24.4b[2]\n"
-            ".inst 0x4fb8eb62  // sdot v2.4s, v27.16b, v24.4b[3]\n"
-            "ldr q24, [x21, #0x0]\n"
-            ".inst 0x4f89e3d4  // sdot v20.4s, v30.16b, v9.4b[0]\n"
-            ".inst 0x4fa9e3ca  // sdot v10.4s, v30.16b, v9.4b[1]\n"
-            ".inst 0x4f89ebda  // sdot v26.4s, v30.16b, v9.4b[2]\n"
-            ".inst 0x4fa9ebc2  // sdot v2.4s, v30.16b, v9.4b[3]\n"
-            "fmul v9.4s, v17.4s, v29.s[0]\n"
-            "scvtf v20.4s, v20.4s, #0x4\n"
-            "scvtf v10.4s, v10.4s, #0x4\n"
-            "scvtf v26.4s, v26.4s, #0x4\n"
-            "scvtf v2.4s, v2.4s, #0x4\n"
-            "fmla v25.4s, v20.4s, v9.4s\n"
-            "ldr q9, [x21, #0x10]\n"
-            "fmul v20.4s, v17.4s, v29.s[1]\n"
-            "fmla v7.4s, v10.4s, v20.4s\n"
-            "ldr d20, [x21, #-0x8]\n"
-            "fmul v10.4s, v17.4s, v29.s[2]\n"
-            "fmul v29.4s, v17.4s, v29.s[3]\n"
-            "fcvtl v20.4s, v20.4h\n"
-            "fmla v0.4s, v26.4s, v10.4s\n"
-            "movi v26.4s, #0x0\n"
-            "movi v10.4s, #0x0\n"
-            "fmla v4.4s, v2.4s, v29.4s\n"
-            "movi v2.4s, #0x0\n"
-            "movi v29.4s, #0x0\n"
-            ".inst 0x4f98e19a  // sdot v26.4s, v12.16b, v24.4b[0]\n"
-            ".inst 0x4fb8e18a  // sdot v10.4s, v12.16b, v24.4b[1]\n"
-            ".inst 0x4f98e982  // sdot v2.4s, v12.16b, v24.4b[2]\n"
-            ".inst 0x4fb8e99d  // sdot v29.4s, v12.16b, v24.4b[3]\n"
-            "ldr q12, [x21, #0x20]\n"
-            "fmul v24.4s, v17.4s, v20.s[0]\n"
-            ".inst 0x4f89e3fa  // sdot v26.4s, v31.16b, v9.4b[0]\n"
-            ".inst 0x4fa9e3ea  // sdot v10.4s, v31.16b, v9.4b[1]\n"
-            ".inst 0x4f89ebe2  // sdot v2.4s, v31.16b, v9.4b[2]\n"
-            ".inst 0x4fa9ebfd  // sdot v29.4s, v31.16b, v9.4b[3]\n"
-            "ldr q9, [x21, #0x30]\n"
-            "fmul v31.4s, v17.4s, v20.s[1]\n"
-            ".inst 0x4f8ce0da  // sdot v26.4s, v6.16b, v12.4b[0]\n"
-            ".inst 0x4face0ca  // sdot v10.4s, v6.16b, v12.4b[1]\n"
-            ".inst 0x4f8ce8c2  // sdot v2.4s, v6.16b, v12.4b[2]\n"
-            ".inst 0x4face8dd  // sdot v29.4s, v6.16b, v12.4b[3]\n"
-            "ldr q12, [x21, #0x40]\n"
-            "fmul v6.4s, v17.4s, v20.s[2]\n"
-            "fmul v20.4s, v17.4s, v20.s[3]\n"
-            ".inst 0x4f89e39a  // sdot v26.4s, v28.16b, v9.4b[0]\n"
-            ".inst 0x4fa9e38a  // sdot v10.4s, v28.16b, v9.4b[1]\n"
-            ".inst 0x4f89eb82  // sdot v2.4s, v28.16b, v9.4b[2]\n"
-            ".inst 0x4fa9eb9d  // sdot v29.4s, v28.16b, v9.4b[3]\n"
-            "ldr q9, [x21, #0x50]\n"
-            ".inst 0x4f8ce07a  // sdot v26.4s, v3.16b, v12.4b[0]\n"
-            ".inst 0x4face06a  // sdot v10.4s, v3.16b, v12.4b[1]\n"
-            ".inst 0x4f8ce862  // sdot v2.4s, v3.16b, v12.4b[2]\n"
-            ".inst 0x4face87d  // sdot v29.4s, v3.16b, v12.4b[3]\n"
-            "ldr q12, [x21, #0x60]\n"
-            ".inst 0x4f89e2da  // sdot v26.4s, v22.16b, v9.4b[0]\n"
-            ".inst 0x4fa9e2ca  // sdot v10.4s, v22.16b, v9.4b[1]\n"
-            ".inst 0x4f89eac2  // sdot v2.4s, v22.16b, v9.4b[2]\n"
-            ".inst 0x4fa9eadd  // sdot v29.4s, v22.16b, v9.4b[3]\n"
-            "ldr q17, [x21, #0x70]\n"
-            "add x21, x21, #0x88\n"
-            ".inst 0x4f8ce37a  // sdot v26.4s, v27.16b, v12.4b[0]\n"
-            ".inst 0x4face36a  // sdot v10.4s, v27.16b, v12.4b[1]\n"
-            ".inst 0x4f8ceb62  // sdot v2.4s, v27.16b, v12.4b[2]\n"
-            ".inst 0x4faceb7d  // sdot v29.4s, v27.16b, v12.4b[3]\n"
-            ".inst 0x4f91e3da  // sdot v26.4s, v30.16b, v17.4b[0]\n"
-            ".inst 0x4fb1e3ca  // sdot v10.4s, v30.16b, v17.4b[1]\n"
-            ".inst 0x4f91ebc2  // sdot v2.4s, v30.16b, v17.4b[2]\n"
-            ".inst 0x4fb1ebdd  // sdot v29.4s, v30.16b, v17.4b[3]\n"
-            "scvtf v26.4s, v26.4s, #0x4\n"
-            "scvtf v10.4s, v10.4s, #0x4\n"
-            "fmla v5.4s, v26.4s, v24.4s\n"
-            "scvtf v2.4s, v2.4s, #0x4\n"
-            "scvtf v29.4s, v29.4s, #0x4\n"
-            "fmla v21.4s, v10.4s, v31.4s\n"
-            "fmla v8.4s, v2.4s, v6.4s\n"
-            "fmla v1.4s, v29.4s, v20.4s\n"
-            "bgt 3b\n"
-            "mov x20, %x[res_ptr]\n"
-            "subs x27, x27, #0x4\n"
-            "add %x[res_ptr], %x[res_ptr], #0x10\n"
-            "str q15, [x20, #0x0]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "str q19, [x20, #0x0]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "str q18, [x20, #0x0]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "str q14, [x20, #0x0]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "str q11, [x20, #0x0]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "str q13, [x20, #0x0]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "str q23, [x20, #0x0]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "str q16, [x20, #0x0]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "str q25, [x20, #0x0]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "str q7, [x20, #0x0]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "str q0, [x20, #0x0]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "str q4, [x20, #0x0]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "str q5, [x20, #0x0]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "str q21, [x20, #0x0]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "str q8, [x20, #0x0]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "str q1, [x20, #0x0]\n"
-            "bne 2b\n"
-            "mov x20, #0x4\n"
-            "sub x10, x10, #0x10\n"
-            "cmp x10, #0x10\n"
-            "mov %x[res_ptr], x26\n"
-            "madd %x[a_ptr], x20, x9, %x[a_ptr]\n"
-            "bge 1b\n"
-            "4:"  // Row loop skip
-            "cbz x10, 9f\n"
-            "5:"  // Row tail: Row loop
-            "add x24, %x[b_ptr], #0x8\n"
-            "mov x23, %x[nc]\n"
-            "add x22, %x[res_ptr], %x[res_stride], LSL #2\n"
-            "6:"  // Row tail: Column loop
-            "movi v15.16b, #0x0\n"
-            "movi v19.16b, #0x0\n"
-            "add x25, %x[a_ptr], #0x8\n"
-            "mov x21, %x[nb]\n"
-            "movi v18.16b, #0x0\n"
-            "movi v14.16b, #0x0\n"
-            "7:"  // Row tail: Block loop
-            "ldr q7, [x24, #0x0]\n"
-            "ldr q5, [x25, #0x0]\n"
-            "movi v9.16b, #0x4\n"
-            "movi v4.4s, #0x0\n"
-            "ldr q3, [x24, #0x10]\n"
-            "ldr q2, [x25, #0x10]\n"
-            "movi v1.4s, #0x0\n"
-            "movi v0.4s, #0x0\n"
-            "ldr q13, [x24, #0x20]\n"
-            "ldr q31, [x25, #0x20]\n"
-            "movi v30.4s, #0x0\n"
-            "movi v29.16b, #0xf0\n"
-            "ldr q28, [x24, #0x30]\n"
-            "ldr q27, [x25, #0x30]\n"
-            "sshl v20.16b, v7.16b, v9.16b\n"
-            "sub x20, x24, #0x8\n"
-            "ldr q26, [x25, #0x40]\n"
-            "ldr q25, [x25, #0x50]\n"
-            "sshl v17.16b, v3.16b, v9.16b\n"
-            "and v7.16b, v7.16b, v29.16b\n"
-            "ldr q24, [x25, #0x60]\n"
-            "ldr q16, [x25, #0x70]\n"
-            "sshl v22.16b, v13.16b, v9.16b\n"
-            "and v3.16b, v3.16b, v29.16b\n"
-            "ldr d21, [x20, #0x0]\n"
-            "ldr d12, [x25, #-0x8]\n"
-            ".inst 0x4f85e284  // sdot v4.4s, v20.16b, v5.4b[0]\n"
-            ".inst 0x4fa5e281  // sdot v1.4s, v20.16b, v5.4b[1]\n"
-            ".inst 0x4f85ea80  // sdot v0.4s, v20.16b, v5.4b[2]\n"
-            ".inst 0x4fa5ea9e  // sdot v30.4s, v20.16b, v5.4b[3]\n"
-            "sshl v9.16b, v28.16b, v9.16b\n"
-            "subs x21, x21, #0x1\n"
-            "and v13.16b, v13.16b, v29.16b\n"
-            "and v28.16b, v28.16b, v29.16b\n"
-            "add x25, x25, #0x88\n"
-            "add x24, x24, #0x48\n"
-            "fcvtl v21.4s, v21.4h\n"
-            "fcvtl v12.4s, v12.4h\n"
-            ".inst 0x4f82e224  // sdot v4.4s, v17.16b, v2.4b[0]\n"
-            ".inst 0x4fa2e221  // sdot v1.4s, v17.16b, v2.4b[1]\n"
-            ".inst 0x4f82ea20  // sdot v0.4s, v17.16b, v2.4b[2]\n"
-            ".inst 0x4fa2ea3e  // sdot v30.4s, v17.16b, v2.4b[3]\n"
-            "fmul v11.4s, v21.4s, v12.s[0]\n"
-            "fmul v23.4s, v21.4s, v12.s[1]\n"
-            "fmul v17.4s, v21.4s, v12.s[2]\n"
-            ".inst 0x4f9fe2c4  // sdot v4.4s, v22.16b, v31.4b[0]\n"
-            "fmul v6.4s, v21.4s, v12.s[3]\n"
-            ".inst 0x4fbfe2c1  // sdot v1.4s, v22.16b, v31.4b[1]\n"
-            ".inst 0x4f9feac0  // sdot v0.4s, v22.16b, v31.4b[2]\n"
-            ".inst 0x4fbfeade  // sdot v30.4s, v22.16b, v31.4b[3]\n"
-            ".inst 0x4f9be124  // sdot v4.4s, v9.16b, v27.4b[0]\n"
-            ".inst 0x4fbbe121  // sdot v1.4s, v9.16b, v27.4b[1]\n"
-            ".inst 0x4f9be920  // sdot v0.4s, v9.16b, v27.4b[2]\n"
-            ".inst 0x4fbbe93e  // sdot v30.4s, v9.16b, v27.4b[3]\n"
-            ".inst 0x4f9ae0e4  // sdot v4.4s, v7.16b, v26.4b[0]\n"
-            ".inst 0x4fbae0e1  // sdot v1.4s, v7.16b, v26.4b[1]\n"
-            ".inst 0x4f9ae8e0  // sdot v0.4s, v7.16b, v26.4b[2]\n"
-            ".inst 0x4fbae8fe  // sdot v30.4s, v7.16b, v26.4b[3]\n"
-            ".inst 0x4f99e064  // sdot v4.4s, v3.16b, v25.4b[0]\n"
-            ".inst 0x4fb9e061  // sdot v1.4s, v3.16b, v25.4b[1]\n"
-            ".inst 0x4f99e860  // sdot v0.4s, v3.16b, v25.4b[2]\n"
-            ".inst 0x4fb9e87e  // sdot v30.4s, v3.16b, v25.4b[3]\n"
-            ".inst 0x4f98e1a4  // sdot v4.4s, v13.16b, v24.4b[0]\n"
-            ".inst 0x4fb8e1a1  // sdot v1.4s, v13.16b, v24.4b[1]\n"
-            ".inst 0x4f98e9a0  // sdot v0.4s, v13.16b, v24.4b[2]\n"
-            ".inst 0x4fb8e9be  // sdot v30.4s, v13.16b, v24.4b[3]\n"
-            ".inst 0x4f90e384  // sdot v4.4s, v28.16b, v16.4b[0]\n"
-            ".inst 0x4fb0e381  // sdot v1.4s, v28.16b, v16.4b[1]\n"
-            ".inst 0x4f90eb80  // sdot v0.4s, v28.16b, v16.4b[2]\n"
-            ".inst 0x4fb0eb9e  // sdot v30.4s, v28.16b, v16.4b[3]\n"
-            "scvtf v4.4s, v4.4s, #0x4\n"
-            "scvtf v1.4s, v1.4s, #0x4\n"
-            "scvtf v0.4s, v0.4s, #0x4\n"
-            "fmla v15.4s, v4.4s, v11.4s\n"
-            "scvtf v30.4s, v30.4s, #0x4\n"
-            "fmla v19.4s, v1.4s, v23.4s\n"
-            "fmla v18.4s, v0.4s, v17.4s\n"
-            "fmla v14.4s, v30.4s, v6.4s\n"
-            "bgt 7b\n"
-            "mov x20, %x[res_ptr]\n"
-            "cmp x10, #0x1\n"
-            "str q15, [x20, #0x0]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "ble 8f\n"
-            "cmp x10, #0x2\n"
-            "str q19, [x20, #0x0]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "ble 8f\n"
-            "cmp x10, #0x3\n"
-            "str q18, [x20, #0x0]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "ble 8f\n"
-            "str q14, [x20, #0x0]\n"
-            "8:"  // Row tail: Accumulator store skip
-            "subs x23, x23, #0x4\n"
-            "add %x[res_ptr], %x[res_ptr], #0x10\n"
-            "bne 6b\n"
-            "subs x10, x10, #0x4\n"
-            "add %x[a_ptr], %x[a_ptr], x9\n"
-            "mov %x[res_ptr], x22\n"
-            "bgt 5b\n"
-            "9:"  // Row tail: Row loop skip
-            : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr)
-            : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc)
-            : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
-        );
-        return;
-    }
-#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
-    {
-        float sumf[4][4];
-        int sumi;
-
-        for (int y = 0; y < nr / 4; y++) {
-            const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
-            for (int x = 0; x < nc / ncols_interleaved; x++) {
-                const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
-                for (int m = 0; m < 4; m++) {
-                    for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
-                }
-                for (int l = 0; l < nb; l++) {
-                    for (int k = 0; k < (qk / (2 * blocklen)); k++) {
-                        for (int m = 0; m < 4; m++) {
-                            for (int j = 0; j < ncols_interleaved; j++) {
-                                sumi = 0;
-                                for (int i = 0; i < blocklen; ++i) {
-                                    const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
-                                    const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
-                                    sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
-                                            (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
-                                }
-                                sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]);
-                            }
-                        }
-                    }
-                }
-                for (int m = 0; m < 4; m++) {
-                    for (int j = 0; j < ncols_interleaved; j++)
-                        s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
-                }
-            }
-        }
-    }
-}
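
Stripped of the quantization details, the scalar fallback above is a plain blocked matrix multiply: each (y, x) pair owns an independent 4x4 tile of accumulators, and bs is the row stride of the output, which is why nr and nc are asserted to be multiples of 4. A skeleton under those assumptions (the dot_fn indirection is illustrative only, not part of the patch):

#include <stddef.h>

typedef float (*dot_fn)(int row, int col); // one quantized row/column inner product

static void gemm_4x4_tiles(float *s, size_t bs, int nr, int nc, dot_fn dot) {
    for (int y = 0; y < nr / 4; y++) {      // four LHS rows per block_q8_0x4
        for (int x = 0; x < nc / 4; x++) {  // four interleaved RHS columns per block_q4_0x4
            for (int m = 0; m < 4; m++) {
                for (int j = 0; j < 4; j++) {
                    s[(size_t)(y * 4 + m) * bs + (size_t)(x * 4 + j)] = dot(y * 4 + m, x * 4 + j);
                }
            }
        }
    }
}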
-
-void ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
-    const int qk = QK8_0;
-    const int nb = n / qk;
-    const int ncols_interleaved = 4;
-    const int blocklen = 8;
-
-    assert (n % qk == 0);
-    assert (nr % 4 == 0);
-    assert (nc % ncols_interleaved == 0);
-
-    UNUSED(s);
-    UNUSED(bs);
-    UNUSED(vx);
-    UNUSED(vy);
-    UNUSED(nr);
-    UNUSED(nc);
-    UNUSED(nb);
-    UNUSED(ncols_interleaved);
-    UNUSED(blocklen);
-
-#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
-    if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
-        const void * b_ptr = vx;
-        const void * a_ptr = vy;
-        float * res_ptr = s;
-        size_t res_stride = bs * sizeof(float);
-
-        __asm__ __volatile__(
-            "mov x10, %x[nr]\n"
-            "mov x9, #0x88\n"
-            "cmp x10, #0x10\n"
-            "mul x9, %x[nb], x9\n"
-            "blt 4f\n"
-            "1:"  // Row loop
-            "add x28, %x[b_ptr], #0x8\n"
-            "mov x27, %x[nc]\n"
-            "add x26, %x[res_ptr], %x[res_stride], LSL #4\n"
-            "2:"  // Column loop
-            "add x25, %x[a_ptr], #0x8\n"
-            "movi v2.16b, #0x0\n"
-            "movi v10.16b, #0x0\n"
-            "mov x24, %x[nb]\n"
-            "add x23, x25, x9\n"
-            "movi v12.16b, #0x0\n"
-            "movi v28.16b, #0x0\n"
-            "add x22, x23, x9\n"
-            "movi v11.16b, #0x0\n"
-            "movi v13.16b, #0x0\n"
-            "add x21, x22, x9\n"
-            "movi v22.16b, #0x0\n"
-            "movi v23.16b, #0x0\n"
-            "movi v25.16b, #0x0\n"
-            "movi v5.16b, #0x0\n"
-            "movi v7.16b, #0x0\n"
-            "movi v4.16b, #0x0\n"
-            "movi v6.16b, #0x0\n"
-            "movi v30.16b, #0x0\n"
-            "movi v24.16b, #0x0\n"
-            "movi v14.16b, #0x0\n"
-            "3:"  // Block loop
-            "ldr q21, [x28, #0x0]\n"
-            "ldr q16, [x28, #0x10]\n"
-            "movi v1.16b, #0x4\n"
-            "movi v19.4s, #0x0\n"
-            "ldr q27, [x25, #0x0]\n"
-            "ldr q15, [x25, #0x10]\n"
-            "movi v26.4s, #0x0\n"
-            "movi v18.4s, #0x0\n"
-            "ldr q29, [x28, #0x20]\n"
-            "ldr q3, [x28, #0x30]\n"
-            "movi v17.4s, #0x0\n"
-            "movi v0.16b, #0xf0\n"
-            "ldr d20, [x25, #-0x8]\n"
-            "ldr d9, [x23, #-0x8]\n"
-            "sshl v8.16b, v21.16b, v1.16b\n"
-            "sshl v31.16b, v16.16b, v1.16b\n"
-            "and v21.16b, v21.16b, v0.16b\n"
-            "and v16.16b, v16.16b, v0.16b\n"
-            "sub x20, x28, #0x8\n"
-            "subs x24, x24, #0x1\n"
-            "add x28, x28, #0x48\n"
-            ".inst 0x4e88a773  // smmla v19.4s, v27.16b, v8.16b\n"
-            ".inst 0x4e9fa77a  // smmla v26.4s, v27.16b, v31.16b\n"
-            "ldr q27, [x25, #0x20]\n"
-            ".inst 0x4e88a5f2  // smmla v18.4s, v15.16b, v8.16b\n"
-            ".inst 0x4e9fa5f1  // smmla v17.4s, v15.16b, v31.16b\n"
-            "sshl v15.16b, v29.16b, v1.16b\n"
-            "sshl v1.16b, v3.16b, v1.16b\n"
-            "and v29.16b, v29.16b, v0.16b\n"
-            "and v3.16b, v3.16b, v0.16b\n"
-            "ldr q0, [x25, #0x30]\n"
-            "fcvtl v20.4s, v20.4h\n"
-            ".inst 0x4e8fa773  // smmla v19.4s, v27.16b, v15.16b\n"
-            "fcvtl v9.4s, v9.4h\n"
-            ".inst 0x4e81a77a  // smmla v26.4s, v27.16b, v1.16b\n"
-            "ldr q27, [x25, #0x40]\n"
-            ".inst 0x4e8fa412  // smmla v18.4s, v0.16b, v15.16b\n"
-            ".inst 0x4e81a411  // smmla v17.4s, v0.16b, v1.16b\n"
-            "ldr q0, [x25, #0x50]\n"
-            ".inst 0x4e95a773  // smmla v19.4s, v27.16b, v21.16b\n"
-            ".inst 0x4e90a77a  // smmla v26.4s, v27.16b, v16.16b\n"
-            "ldr q27, [x25, #0x60]\n"
-            ".inst 0x4e95a412  // smmla v18.4s, v0.16b, v21.16b\n"
-            ".inst 0x4e90a411  // smmla v17.4s, v0.16b, v16.16b\n"
-            "ldr q0, [x25, #0x70]\n"
-            "add x25, x25, #0x88\n"
-            ".inst 0x4e9da773  // smmla v19.4s, v27.16b, v29.16b\n"
-            ".inst 0x4e83a77a  // smmla v26.4s, v27.16b, v3.16b\n"
-            "ldr d27, [x20, #0x0]\n"
-            ".inst 0x4e9da412  // smmla v18.4s, v0.16b, v29.16b\n"
-            ".inst 0x4e83a411  // smmla v17.4s, v0.16b, v3.16b\n"
-            "fcvtl v27.4s, v27.4h\n"
-            "uzp1 v0.2d, v19.2d, v26.2d\n"
-            "uzp2 v26.2d, v19.2d, v26.2d\n"
-            "fmul v19.4s, v27.4s, v20.s[0]\n"
-            "scvtf v0.4s, v0.4s, #0x4\n"
-            "scvtf v26.4s, v26.4s, #0x4\n"
-            "fmla v2.4s, v0.4s, v19.4s\n"
-            "ldr q19, [x23, #0x0]\n"
-            "uzp1 v0.2d, v18.2d, v17.2d\n"
-            "uzp2 v18.2d, v18.2d, v17.2d\n"
-            "fmul v17.4s, v27.4s, v20.s[1]\n"
-            "scvtf v0.4s, v0.4s, #0x4\n"
-            "scvtf v18.4s, v18.4s, #0x4\n"
-            "fmla v10.4s, v26.4s, v17.4s\n"
-            "ldr q17, [x23, #0x10]\n"
-            "fmul v26.4s, v27.4s, v20.s[2]\n"
-            "fmul v20.4s, v27.4s, v20.s[3]\n"
-            "fmla v12.4s, v0.4s, v26.4s\n"
-            "ldr d0, [x22, #-0x8]\n"
-            "ldr d26, [x21, #-0x8]\n"
-            "fcvtl v0.4s, v0.4h\n"
-            "fmla v28.4s, v18.4s, v20.4s\n"
-            "movi v20.4s, #0x0\n"
-            "movi v18.4s, #0x0\n"
-            ".inst 0x4e88a674  // smmla v20.4s, v19.16b, v8.16b\n"
-            ".inst 0x4e9fa672  // smmla v18.4s, v19.16b, v31.16b\n"
-            "ldr q19, [x23, #0x20]\n"
-            "fcvtl v26.4s, v26.4h\n"
-            ".inst 0x4e8fa674  // smmla v20.4s, v19.16b, v15.16b\n"
-            ".inst 0x4e81a672  // smmla v18.4s, v19.16b, v1.16b\n"
-            "ldr q19, [x23, #0x40]\n"
-            ".inst 0x4e95a674  // smmla v20.4s, v19.16b, v21.16b\n"
-            ".inst 0x4e90a672  // smmla v18.4s, v19.16b, v16.16b\n"
-            "ldr q19, [x23, #0x60]\n"
-            ".inst 0x4e9da674  // smmla v20.4s, v19.16b, v29.16b\n"
-            ".inst 0x4e83a672  // smmla v18.4s, v19.16b, v3.16b\n"
-            "uzp1 v19.2d, v20.2d, v18.2d\n"
-            "scvtf v19.4s, v19.4s, #0x4\n"
-            "uzp2 v20.2d, v20.2d, v18.2d\n"
-            "fmul v18.4s, v27.4s, v9.s[0]\n"
-            "scvtf v20.4s, v20.4s, #0x4\n"
-            "fmla v11.4s, v19.4s, v18.4s\n"
-            "ldr q18, [x22, #0x0]\n"
-            "fmul v19.4s, v27.4s, v9.s[1]\n"
-            "fmla v13.4s, v20.4s, v19.4s\n"
-            "movi v19.4s, #0x0\n"
-            "movi v20.4s, #0x0\n"
-            ".inst 0x4e88a633  // smmla v19.4s, v17.16b, v8.16b\n"
-            ".inst 0x4e9fa634  // smmla v20.4s, v17.16b, v31.16b\n"
-            "ldr q17, [x23, #0x30]\n"
-            ".inst 0x4e8fa633  // smmla v19.4s, v17.16b, v15.16b\n"
-            ".inst 0x4e81a634  // smmla v20.4s, v17.16b, v1.16b\n"
-            "ldr q17, [x23, #0x50]\n"
-            ".inst 0x4e95a633  // smmla v19.4s, v17.16b, v21.16b\n"
-            ".inst 0x4e90a634  // smmla v20.4s, v17.16b, v16.16b\n"
-            "ldr q17, [x23, #0x70]\n"
-            "add x23, x23, #0x88\n"
-            ".inst 0x4e9da633  // smmla v19.4s, v17.16b, v29.16b\n"
-            ".inst 0x4e83a634  // smmla v20.4s, v17.16b, v3.16b\n"
-            "uzp1 v17.2d, v19.2d, v20.2d\n"
-            "scvtf v17.4s, v17.4s, #0x4\n"
-            "uzp2 v20.2d, v19.2d, v20.2d\n"
-            "fmul v19.4s, v27.4s, v9.s[2]\n"
-            "fmul v9.4s, v27.4s, v9.s[3]\n"
-            "scvtf v20.4s, v20.4s, #0x4\n"
-            "fmla v22.4s, v17.4s, v19.4s\n"
-            "ldr q17, [x22, #0x10]\n"
-            "movi v19.4s, #0x0\n"
-            ".inst 0x4e88a653  // smmla v19.4s, v18.16b, v8.16b\n"
-            "fmla v23.4s, v20.4s, v9.4s\n"
-            "movi v20.4s, #0x0\n"
-            "movi v9.4s, #0x0\n"
-            ".inst 0x4e9fa654  // smmla v20.4s, v18.16b, v31.16b\n"
-            "ldr q18, [x22, #0x20]\n"
-            ".inst 0x4e88a629  // smmla v9.4s, v17.16b, v8.16b\n"
-            ".inst 0x4e8fa653  // smmla v19.4s, v18.16b, v15.16b\n"
-            ".inst 0x4e81a654  // smmla v20.4s, v18.16b, v1.16b\n"
-            "ldr q18, [x22, #0x40]\n"
-            ".inst 0x4e95a653  // smmla v19.4s, v18.16b, v21.16b\n"
-            ".inst 0x4e90a654  // smmla v20.4s, v18.16b, v16.16b\n"
-            "ldr q18, [x22, #0x60]\n"
-            ".inst 0x4e9da653  // smmla v19.4s, v18.16b, v29.16b\n"
-            ".inst 0x4e83a654  // smmla v20.4s, v18.16b, v3.16b\n"
-            "movi v18.4s, #0x0\n"
-            ".inst 0x4e9fa632  // smmla v18.4s, v17.16b, v31.16b\n"
-            "ldr q17, [x22, #0x30]\n"
-            ".inst 0x4e8fa629  // smmla v9.4s, v17.16b, v15.16b\n"
-            ".inst 0x4e81a632  // smmla v18.4s, v17.16b, v1.16b\n"
-            "ldr q17, [x22, #0x50]\n"
-            ".inst 0x4e95a629  // smmla v9.4s, v17.16b, v21.16b\n"
-            ".inst 0x4e90a632  // smmla v18.4s, v17.16b, v16.16b\n"
-            "ldr q17, [x22, #0x70]\n"
-            "add x22, x22, #0x88\n"
-            ".inst 0x4e9da629  // smmla v9.4s, v17.16b, v29.16b\n"
-            ".inst 0x4e83a632  // smmla v18.4s, v17.16b, v3.16b\n"
-            "uzp1 v17.2d, v19.2d, v20.2d\n"
-            "uzp2 v20.2d, v19.2d, v20.2d\n"
-            "fmul v19.4s, v27.4s, v0.s[0]\n"
-            "scvtf v17.4s, v17.4s, #0x4\n"
-            "scvtf v20.4s, v20.4s, #0x4\n"
-            "fmla v25.4s, v17.4s, v19.4s\n"
-            "ldr q19, [x21, #0x0]\n"
-            "fmul v17.4s, v27.4s, v0.s[1]\n"
-            "fmla v5.4s, v20.4s, v17.4s\n"
-            "ldr q17, [x21, #0x10]\n"
-            "uzp1 v20.2d, v9.2d, v18.2d\n"
-            "uzp2 v9.2d, v9.2d, v18.2d\n"
-            "fmul v18.4s, v27.4s, v0.s[2]\n"
-            "fmul v0.4s, v27.4s, v0.s[3]\n"
-            "scvtf v20.4s, v20.4s, #0x4\n"
-            "scvtf v9.4s, v9.4s, #0x4\n"
-            "fmla v7.4s, v20.4s, v18.4s\n"
-            "movi v20.4s, #0x0\n"
-            "movi v18.4s, #0x0\n"
-            ".inst 0x4e88a674  // smmla v20.4s, v19.16b, v8.16b\n"
-            ".inst 0x4e9fa672  // smmla v18.4s, v19.16b, v31.16b\n"
-            "ldr q19, [x21, #0x20]\n"
-            "fmla v4.4s, v9.4s, v0.4s\n"
-            "movi v9.4s, #0x0\n"
-            "movi v0.4s, #0x0\n"
-            ".inst 0x4e88a629  // smmla v9.4s, v17.16b, v8.16b\n"
-            "fmul v8.4s, v27.4s, v26.s[0]\n"
-            ".inst 0x4e9fa620  // smmla v0.4s, v17.16b, v31.16b\n"
-            "ldr q17, [x21, #0x30]\n"
-            ".inst 0x4e8fa674  // smmla v20.4s, v19.16b, v15.16b\n"
-            "fmul v31.4s, v27.4s, v26.s[1]\n"
-            ".inst 0x4e81a672  // smmla v18.4s, v19.16b, v1.16b\n"
-            "ldr q19, [x21, #0x40]\n"
-            ".inst 0x4e8fa629  // smmla v9.4s, v17.16b, v15.16b\n"
-            "fmul v15.4s, v27.4s, v26.s[2]\n"
-            "fmul v27.4s, v27.4s, v26.s[3]\n"
-            ".inst 0x4e81a620  // smmla v0.4s, v17.16b, v1.16b\n"
-            "ldr q1, [x21, #0x50]\n"
-            ".inst 0x4e95a674  // smmla v20.4s, v19.16b, v21.16b\n"
-            ".inst 0x4e90a672  // smmla v18.4s, v19.16b, v16.16b\n"
-            "ldr q26, [x21, #0x60]\n"
-            ".inst 0x4e95a429  // smmla v9.4s, v1.16b, v21.16b\n"
-            ".inst 0x4e90a420  // smmla v0.4s, v1.16b, v16.16b\n"
-            "ldr q21, [x21, #0x70]\n"
-            "add x21, x21, #0x88\n"
-            ".inst 0x4e9da754  // smmla v20.4s, v26.16b, v29.16b\n"
-            ".inst 0x4e83a752  // smmla v18.4s, v26.16b, v3.16b\n"
-            ".inst 0x4e9da6a9  // smmla v9.4s, v21.16b, v29.16b\n"
-            ".inst 0x4e83a6a0  // smmla v0.4s, v21.16b, v3.16b\n"
-            "uzp1 v29.2d, v20.2d, v18.2d\n"
-            "uzp2 v21.2d, v20.2d, v18.2d\n"
-            "scvtf v29.4s, v29.4s, #0x4\n"
-            "uzp1 v18.2d, v9.2d, v0.2d\n"
-            "uzp2 v16.2d, v9.2d, v0.2d\n"
-            "scvtf v21.4s, v21.4s, #0x4\n"
-            "fmla v6.4s, v29.4s, v8.4s\n"
-            "scvtf v18.4s, v18.4s, #0x4\n"
-            "scvtf v16.4s, v16.4s, #0x4\n"
-            "fmla v30.4s, v21.4s, v31.4s\n"
-            "fmla v24.4s, v18.4s, v15.4s\n"
-            "fmla v14.4s, v16.4s, v27.4s\n"
-            "bgt 3b\n"
-            "mov x20, %x[res_ptr]\n"
-            "subs x27, x27, #0x4\n"
-            "add %x[res_ptr], %x[res_ptr], #0x10\n"
-            "str q2, [x20, #0x0]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "str q10, [x20, #0x0]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "str q12, [x20, #0x0]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "str q28, [x20, #0x0]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "str q11, [x20, #0x0]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "str q13, [x20, #0x0]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "str q22, [x20, #0x0]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "str q23, [x20, #0x0]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "str q25, [x20, #0x0]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "str q5, [x20, #0x0]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "str q7, [x20, #0x0]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "str q4, [x20, #0x0]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "str q6, [x20, #0x0]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "str q30, [x20, #0x0]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "str q24, [x20, #0x0]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "str q14, [x20, #0x0]\n"
-            "bne 2b\n"
-            "mov x20, #0x4\n"
-            "sub x10, x10, #0x10\n"
-            "cmp x10, #0x10\n"
-            "mov %x[res_ptr], x26\n"
-            "madd %x[a_ptr], x20, x9, %x[a_ptr]\n"
-            "bge 1b\n"
-            "4:"  // Row loop skip
-            "cbz x10, 9f\n"
-            "5:"  // Row tail: Row loop
-            "add x24, %x[b_ptr], #0x8\n"
-            "mov x23, %x[nc]\n"
-            "add x22, %x[res_ptr], %x[res_stride], LSL #2\n"
-            "6:"  // Row tail: Column loop
-            "movi v2.16b, #0x0\n"
-            "movi v10.16b, #0x0\n"
-            "add x25, %x[a_ptr], #0x8\n"
-            "mov x21, %x[nb]\n"
-            "movi v12.16b, #0x0\n"
-            "movi v28.16b, #0x0\n"
-            "7:"  // Row tail: Block loop
-            "ldr q6, [x24, #0x0]\n"
-            "ldr q5, [x24, #0x10]\n"
-            "movi v17.16b, #0x4\n"
-            "movi v8.4s, #0x0\n"
-            "ldr q4, [x25, #0x0]\n"
-            "ldr q13, [x25, #0x10]\n"
-            "movi v27.4s, #0x0\n"
-            "movi v0.4s, #0x0\n"
-            "ldr q31, [x24, #0x20]\n"
-            "ldr q14, [x24, #0x30]\n"
-            "movi v29.4s, #0x0\n"
-            "movi v22.16b, #0xf0\n"
-            "ldr q11, [x25, #0x20]\n"
-            "ldr q23, [x25, #0x30]\n"
-            "sshl v21.16b, v6.16b, v17.16b\n"
-            "sshl v16.16b, v5.16b, v17.16b\n"
-            "ldr q20, [x25, #0x40]\n"
-            "ldr q26, [x25, #0x50]\n"
-            "and v6.16b, v6.16b, v22.16b\n"
-            "and v5.16b, v5.16b, v22.16b\n"
-            "ldr q25, [x25, #0x60]\n"
-            "ldr q3, [x25, #0x70]\n"
-            "sshl v19.16b, v31.16b, v17.16b\n"
-            "sshl v18.16b, v14.16b, v17.16b\n"
-            "ldr d17, [x25, #-0x8]\n"
-            ".inst 0x4e95a488  // smmla v8.4s, v4.16b, v21.16b\n"
-            ".inst 0x4e90a49b  // smmla v27.4s, v4.16b, v16.16b\n"
-            "and v31.16b, v31.16b, v22.16b\n"
-            ".inst 0x4e95a5a0  // smmla v0.4s, v13.16b, v21.16b\n"
-            ".inst 0x4e90a5bd  // smmla v29.4s, v13.16b, v16.16b\n"
-            "and v14.16b, v14.16b, v22.16b\n"
-            "sub x20, x24, #0x8\n"
-            "ldr d16, [x20, #0x0]\n"
-            "subs x21, x21, #0x1\n"
-            "add x25, x25, #0x88\n"
-            "fcvtl v17.4s, v17.4h\n"
-            "add x24, x24, #0x48\n"
-            ".inst 0x4e93a568  // smmla v8.4s, v11.16b, v19.16b\n"
-            ".inst 0x4e92a57b  // smmla v27.4s, v11.16b, v18.16b\n"
-            ".inst 0x4e93a6e0  // smmla v0.4s, v23.16b, v19.16b\n"
-            ".inst 0x4e92a6fd  // smmla v29.4s, v23.16b, v18.16b\n"
-            "fcvtl v16.4s, v16.4h\n"
-            ".inst 0x4e86a688  // smmla v8.4s, v20.16b, v6.16b\n"
-            ".inst 0x4e85a69b  // smmla v27.4s, v20.16b, v5.16b\n"
-            "fmul v23.4s, v16.4s, v17.s[0]\n"
-            "fmul v21.4s, v16.4s, v17.s[1]\n"
-            "fmul v1.4s, v16.4s, v17.s[2]\n"
-            "fmul v20.4s, v16.4s, v17.s[3]\n"
-            ".inst 0x4e86a740  // smmla v0.4s, v26.16b, v6.16b\n"
-            ".inst 0x4e85a75d  // smmla v29.4s, v26.16b, v5.16b\n"
-            ".inst 0x4e9fa728  // smmla v8.4s, v25.16b, v31.16b\n"
-            ".inst 0x4e8ea73b  // smmla v27.4s, v25.16b, v14.16b\n"
-            ".inst 0x4e9fa460  // smmla v0.4s, v3.16b, v31.16b\n"
-            ".inst 0x4e8ea47d  // smmla v29.4s, v3.16b, v14.16b\n"
-            "uzp1 v19.2d, v8.2d, v27.2d\n"
-            "uzp2 v18.2d, v8.2d, v27.2d\n"
-            "scvtf v19.4s, v19.4s, #0x4\n"
-            "uzp1 v17.2d, v0.2d, v29.2d\n"
-            "uzp2 v16.2d, v0.2d, v29.2d\n"
-            "scvtf v18.4s, v18.4s, #0x4\n"
-            "fmla v2.4s, v19.4s, v23.4s\n"
-            "scvtf v17.4s, v17.4s, #0x4\n"
-            "scvtf v16.4s, v16.4s, #0x4\n"
-            "fmla v10.4s, v18.4s, v21.4s\n"
-            "fmla v12.4s, v17.4s, v1.4s\n"
-            "fmla v28.4s, v16.4s, v20.4s\n"
-            "bgt 7b\n"
-            "mov x20, %x[res_ptr]\n"
-            "cmp x10, #0x1\n"
-            "str q2, [x20, #0x0]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "ble 8f\n"
-            "cmp x10, #0x2\n"
-            "str q10, [x20, #0x0]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "ble 8f\n"
-            "cmp x10, #0x3\n"
-            "str q12, [x20, #0x0]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "ble 8f\n"
-            "str q28, [x20, #0x0]\n"
-            "8:"  // Row tail: Accumulator store skip
-            "subs x23, x23, #0x4\n"
-            "add %x[res_ptr], %x[res_ptr], #0x10\n"
-            "bne 6b\n"
-            "subs x10, x10, #0x4\n"
-            "add %x[a_ptr], %x[a_ptr], x9\n"
-            "mov %x[res_ptr], x22\n"
-            "bgt 5b\n"
-            "9:"  // Row tail: Row loop skip
-            : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr)
-            : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc)
-            : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
-        );
-        return;
-    }
-#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
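-    // Scalar fallback: plain-C reference path, taken when no SIMD kernel above was
-    // compiled in or selected at runtime.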
-    float sumf[4][4];
-    int sumi;
-
-    for (int y = 0; y < nr / 4; y++) {
-        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
-        for (int x = 0; x < nc / ncols_interleaved; x++) {
-            const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
-            for (int m = 0; m < 4; m++) {
-                for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
-            }
-            for (int l = 0; l < nb; l++) {
-                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
-                    for (int m = 0; m < 4; m++) {
-                        for (int j = 0; j < ncols_interleaved; j++) {
-                            sumi = 0;
-                            for (int i = 0; i < blocklen; ++i) {
-                                const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
-                                const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
-                                sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
-                                        (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
-                            }
-                            sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]);
-                        }
-                    }
-                }
-            }
-            for (int m = 0; m < 4; m++) {
-                for (int j = 0; j < ncols_interleaved; j++)
-                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
-            }
-        }
-    }
-}
-
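-// A minimal, hypothetical scalar sketch (not part of the vendored source) of the
-// nibble-decode trick these kernels rely on: the low nibble is sign-extended by
-// shifting it into the top of the byte (<< 4) and the high nibble by masking
-// (& 0xF0); both results are 16x the true int4 value, so a single >> 4 after the
-// multiply restores the scale of each q4*q8 pair:
-//
-//     static inline int q4q8_pair_dot(uint8_t packed, int8_t q8_lo, int8_t q8_hi) {
-//         const int v0 = (int8_t) (packed << 4);   // low  nibble, value * 16
-//         const int v1 = (int8_t) (packed & 0xF0); // high nibble, value * 16
-//         return ((v0 * q8_lo) + (v1 * q8_hi)) >> 4;
-//     }
-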
-void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
-    const int qk = QK8_0;
-    const int nb = n / qk;
-    const int ncols_interleaved = 8;
-    const int blocklen = 8;
-
-    assert (n % qk == 0);
-    assert (nr % 4 == 0);
-    assert (nc % ncols_interleaved == 0);
-
-    UNUSED(s);
-    UNUSED(bs);
-    UNUSED(vx);
-    UNUSED(vy);
-    UNUSED(nr);
-    UNUSED(nc);
-    UNUSED(nb);
-    UNUSED(ncols_interleaved);
-    UNUSED(blocklen);
-
-#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
-#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
-    if (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0) {
-        const void * b_ptr = vx;
-        const void * a_ptr = vy;
-        float * res_ptr = s;
-        size_t res_stride = bs * sizeof(float);
-
-        __asm__ __volatile__(
-            "mov x20, #0x4\n"
-            "mov x13, %x[nr]\n"
-            "mov z28.s, #-0x4\n"
-            "mov x12, #0x88\n"
-            "ptrue p1.b\n"
-            "whilelt p0.s, XZR, x20\n"
-            "cmp x13, #0x10\n"
-            "mul x12, %x[nb], x12\n"
-            "blt 4f\n"
-            "1:"  // Row loop
-            "add x11, %x[b_ptr], #0x10\n"
-            "mov x10, %x[nc]\n"
-            "add x9, %x[res_ptr], %x[res_stride], LSL #4\n"
-            "2:"  // Column loop
-            "add x28, %x[a_ptr], #0x8\n"
-            "mov z24.b, #0x0\n"
-            "mov z15.b, #0x0\n"
-            "mov x27, %x[nb]\n"
-            "add x26, x28, x12\n"
-            "mov z12.b, #0x0\n"
-            "mov z0.b, #0x0\n"
-            "add x25, x26, x12\n"
-            "mov z13.b, #0x0\n"
-            "mov z1.b, #0x0\n"
-            "add x24, x25, x12\n"
-            "mov z20.b, #0x0\n"
-            "mov z25.b, #0x0\n"
-            "mov z11.b, #0x0\n"
-            "mov z16.b, #0x0\n"
-            "mov z19.b, #0x0\n"
-            "mov z26.b, #0x0\n"
-            "mov z8.b, #0x0\n"
-            "mov z29.b, #0x0\n"
-            "mov z27.b, #0x0\n"
-            "mov z10.b, #0x0\n"
-            "3:"  // Block loop
-            "ld1b { z30.b }, p1/Z, [x11]\n"
-            "ld1b { z21.b }, p1/Z, [x11, #1, MUL VL]\n"
-            "mov z18.s, #0x0\n"
-            "mov z7.s, #0x0\n"
-            "ld1rqb { z3.b }, p1/Z, [x28]\n"
-            "ld1rqb { z5.b }, p1/Z, [x28, #16]\n"
-            "mov z9.s, #0x0\n"
-            "mov z22.s, #0x0\n"
-            "ld1b { z4.b }, p1/Z, [x11, #2, MUL VL]\n"
-            "ld1b { z17.b }, p1/Z, [x11, #3, MUL VL]\n"
-            "sub x20, x11, #0x10\n"
-            "sub x23, x28, #0x8\n"
-            "lsl z31.b, z30.b, #0x4\n"
-            "lsl z6.b, z21.b, #0x4\n"
-            "ld1h { z23.s }, p1/Z, [x20]\n"
-            "sub x22, x26, #0x8\n"
-            "and z30.b, z30.b, #0xf0\n"
-            "and z21.b, z21.b, #0xf0\n"
-            "sub x21, x25, #0x8\n"
-            "sub x20, x24, #0x8\n"
-            "lsl z14.b, z4.b, #0x4\n"
-            "lsl z2.b, z17.b, #0x4\n"
-            "subs x27, x27, #0x1\n"
-            "add x11, x11, #0x90\n"
-            ".inst 0x451f9872  // smmla z18.s, z3.b, z31.b\n"
-            ".inst 0x45069867  // smmla z7.s, z3.b, z6.b\n"
-            "ld1rqb { z3.b }, p1/Z, [x28, #32]\n"
-            "and z4.b, z4.b, #0xf0\n"
-            ".inst 0x451f98a9  // smmla z9.s, z5.b, z31.b\n"
-            ".inst 0x450698b6  // smmla z22.s, z5.b, z6.b\n"
-            "ld1rqb { z5.b }, p1/Z, [x28, #48]\n"
-            "and z17.b, z17.b, #0xf0\n"
-            "fcvt z23.s, p1/m, z23.h\n"
-            ".inst 0x450e9872  // smmla z18.s, z3.b, z14.b\n"
-            ".inst 0x45029867  // smmla z7.s, z3.b, z2.b\n"
-            "ld1rqb { z3.b }, p1/Z, [x28, #64]\n"
-            ".inst 0x450e98a9  // smmla z9.s, z5.b, z14.b\n"
-            ".inst 0x450298b6  // smmla z22.s, z5.b, z2.b\n"
-            "ld1rqb { z5.b }, p1/Z, [x28, #80]\n"
-            "fscale z23.s, p1/m, z23.s, z28.s\n"
-            ".inst 0x451e9872  // smmla z18.s, z3.b, z30.b\n"
-            ".inst 0x45159867  // smmla z7.s, z3.b, z21.b\n"
-            "ld1rqb { z3.b }, p1/Z, [x28, #96]\n"
-            ".inst 0x451e98a9  // smmla z9.s, z5.b, z30.b\n"
-            ".inst 0x451598b6  // smmla z22.s, z5.b, z21.b\n"
-            "ld1rqb { z5.b }, p1/Z, [x28, #112]\n"
-            "add x28, x28, #0x88\n"
-            ".inst 0x45049872  // smmla z18.s, z3.b, z4.b\n"
-            ".inst 0x45119867  // smmla z7.s, z3.b, z17.b\n"
-            "ld1h { z3.s }, p0/Z, [x23]\n"
-            ".inst 0x450498a9  // smmla z9.s, z5.b, z4.b\n"
-            ".inst 0x451198b6  // smmla z22.s, z5.b, z17.b\n"
-            "fcvt z3.s, p1/m, z3.h\n"
-            "uzp1 z5.d, z18.d, z7.d\n"
-            "uzp2 z18.d, z18.d, z7.d\n"
-            "mov z3.q, z3.q[0]\n"
-            "uzp1 z7.d, z9.d, z22.d\n"
-            "uzp2 z22.d, z9.d, z22.d\n"
-            "fmul z9.s, z23.s, z3.s[0]\n"
-            "scvtf z5.s, p1/m, z5.s\n"
-            "scvtf z18.s, p1/m, z18.s\n"
-            "scvtf z7.s, p1/m, z7.s\n"
-            "scvtf z22.s, p1/m, z22.s\n"
-            "fmla z24.s, p1/M, z5.s, z9.s\n"
-            "ld1rqb { z5.b }, p1/Z, [x26]\n"
-            "fmul z9.s, z23.s, z3.s[1]\n"
-            "fmla z15.s, p1/M, z18.s, z9.s\n"
-            "ld1rqb { z18.b }, p1/Z, [x26, #16]\n"
-            "fmul z9.s, z23.s, z3.s[2]\n"
-            "fmul z3.s, z23.s, z3.s[3]\n"
-            "fmla z12.s, p1/M, z7.s, z9.s\n"
-            "mov z9.s, #0x0\n"
-            "ld1h { z7.s }, p0/Z, [x22]\n"
-            ".inst 0x451f98a9  // smmla z9.s, z5.b, z31.b\n"
-            "fmla z0.s, p1/M, z22.s, z3.s\n"
-            "mov z22.s, #0x0\n"
-            "ld1h { z3.s }, p0/Z, [x21]\n"
-            ".inst 0x450698b6  // smmla z22.s, z5.b, z6.b\n"
-            "ld1rqb { z5.b }, p1/Z, [x26, #32]\n"
-            "fcvt z7.s, p1/m, z7.h\n"
-            "fcvt z3.s, p1/m, z3.h\n"
-            ".inst 0x450e98a9  // smmla z9.s, z5.b, z14.b\n"
-            ".inst 0x450298b6  // smmla z22.s, z5.b, z2.b\n"
-            "ld1rqb { z5.b }, p1/Z, [x26, #64]\n"
-            "mov z7.q, z7.q[0]\n"
-            "mov z3.q, z3.q[0]\n"
-            ".inst 0x451e98a9  // smmla z9.s, z5.b, z30.b\n"
-            ".inst 0x451598b6  // smmla z22.s, z5.b, z21.b\n"
-            "ld1rqb { z5.b }, p1/Z, [x26, #96]\n"
-            ".inst 0x450498a9  // smmla z9.s, z5.b, z4.b\n"
-            ".inst 0x451198b6  // smmla z22.s, z5.b, z17.b\n"
-            "uzp1 z5.d, z9.d, z22.d\n"
-            "scvtf z5.s, p1/m, z5.s\n"
-            "uzp2 z22.d, z9.d, z22.d\n"
-            "fmul z9.s, z23.s, z7.s[0]\n"
-            "scvtf z22.s, p1/m, z22.s\n"
-            "fmla z13.s, p1/M, z5.s, z9.s\n"
-            "ld1rqb { z9.b }, p1/Z, [x25]\n"
-            "fmul z5.s, z23.s, z7.s[1]\n"
-            "fmla z1.s, p1/M, z22.s, z5.s\n"
-            "mov z5.s, #0x0\n"
-            "mov z22.s, #0x0\n"
-            ".inst 0x451f9a45  // smmla z5.s, z18.b, z31.b\n"
-            ".inst 0x45069a56  // smmla z22.s, z18.b, z6.b\n"
-            "ld1rqb { z18.b }, p1/Z, [x26, #48]\n"
-            ".inst 0x450e9a45  // smmla z5.s, z18.b, z14.b\n"
-            ".inst 0x45029a56  // smmla z22.s, z18.b, z2.b\n"
-            "ld1rqb { z18.b }, p1/Z, [x26, #80]\n"
-            ".inst 0x451e9a45  // smmla z5.s, z18.b, z30.b\n"
-            ".inst 0x45159a56  // smmla z22.s, z18.b, z21.b\n"
-            "ld1rqb { z18.b }, p1/Z, [x26, #112]\n"
-            "add x26, x26, #0x88\n"
-            ".inst 0x45049a45  // smmla z5.s, z18.b, z4.b\n"
-            ".inst 0x45119a56  // smmla z22.s, z18.b, z17.b\n"
-            "uzp1 z18.d, z5.d, z22.d\n"
-            "scvtf z18.s, p1/m, z18.s\n"
-            "uzp2 z22.d, z5.d, z22.d\n"
-            "fmul z5.s, z23.s, z7.s[2]\n"
-            "fmul z7.s, z23.s, z7.s[3]\n"
-            "scvtf z22.s, p1/m, z22.s\n"
-            "fmla z20.s, p1/M, z18.s, z5.s\n"
-            "ld1rqb { z18.b }, p1/Z, [x25, #16]\n"
-            "ld1h { z5.s }, p0/Z, [x20]\n"
-            "fcvt z5.s, p1/m, z5.h\n"
-            "fmla z25.s, p1/M, z22.s, z7.s\n"
-            "mov z22.s, #0x0\n"
-            "mov z7.s, #0x0\n"
-            ".inst 0x451f9936  // smmla z22.s, z9.b, z31.b\n"
-            ".inst 0x45069927  // smmla z7.s, z9.b, z6.b\n"
-            "ld1rqb { z9.b }, p1/Z, [x25, #32]\n"
-            "mov z5.q, z5.q[0]\n"
-            ".inst 0x450e9936  // smmla z22.s, z9.b, z14.b\n"
-            ".inst 0x45029927  // smmla z7.s, z9.b, z2.b\n"
-            "ld1rqb { z9.b }, p1/Z, [x25, #64]\n"
-            ".inst 0x451e9936  // smmla z22.s, z9.b, z30.b\n"
-            ".inst 0x45159927  // smmla z7.s, z9.b, z21.b\n"
-            "ld1rqb { z9.b }, p1/Z, [x25, #96]\n"
-            ".inst 0x45049936  // smmla z22.s, z9.b, z4.b\n"
-            ".inst 0x45119927  // smmla z7.s, z9.b, z17.b\n"
-            "uzp1 z9.d, z22.d, z7.d\n"
-            "scvtf z9.s, p1/m, z9.s\n"
-            "uzp2 z22.d, z22.d, z7.d\n"
-            "fmul z7.s, z23.s, z3.s[0]\n"
-            "scvtf z22.s, p1/m, z22.s\n"
-            "fmla z11.s, p1/M, z9.s, z7.s\n"
-            "ld1rqb { z9.b }, p1/Z, [x24]\n"
-            "fmul z7.s, z23.s, z3.s[1]\n"
-            "fmla z16.s, p1/M, z22.s, z7.s\n"
-            "mov z22.s, #0x0\n"
-            "mov z7.s, #0x0\n"
-            ".inst 0x451f9a56  // smmla z22.s, z18.b, z31.b\n"
-            ".inst 0x45069a47  // smmla z7.s, z18.b, z6.b\n"
-            "ld1rqb { z18.b }, p1/Z, [x25, #48]\n"
-            ".inst 0x450e9a56  // smmla z22.s, z18.b, z14.b\n"
-            ".inst 0x45029a47  // smmla z7.s, z18.b, z2.b\n"
-            "ld1rqb { z18.b }, p1/Z, [x25, #80]\n"
-            ".inst 0x451e9a56  // smmla z22.s, z18.b, z30.b\n"
-            ".inst 0x45159a47  // smmla z7.s, z18.b, z21.b\n"
-            "ld1rqb { z18.b }, p1/Z, [x25, #112]\n"
-            "add x25, x25, #0x88\n"
-            ".inst 0x45049a56  // smmla z22.s, z18.b, z4.b\n"
-            ".inst 0x45119a47  // smmla z7.s, z18.b, z17.b\n"
-            "uzp1 z18.d, z22.d, z7.d\n"
-            "scvtf z18.s, p1/m, z18.s\n"
-            "uzp2 z7.d, z22.d, z7.d\n"
-            "fmul z22.s, z23.s, z3.s[2]\n"
-            "fmul z3.s, z23.s, z3.s[3]\n"
-            "scvtf z7.s, p1/m, z7.s\n"
-            "fmla z19.s, p1/M, z18.s, z22.s\n"
-            "ld1rqb { z18.b }, p1/Z, [x24, #16]\n"
-            "fmul z22.s, z23.s, z5.s[0]\n"
-            "fmla z26.s, p1/M, z7.s, z3.s\n"
-            "mov z3.s, #0x0\n"
-            "mov z7.s, #0x0\n"
-            ".inst 0x451f9923  // smmla z3.s, z9.b, z31.b\n"
-            ".inst 0x45069927  // smmla z7.s, z9.b, z6.b\n"
-            "ld1rqb { z9.b }, p1/Z, [x24, #32]\n"
-            ".inst 0x450e9923  // smmla z3.s, z9.b, z14.b\n"
-            ".inst 0x45029927  // smmla z7.s, z9.b, z2.b\n"
-            "mov z9.s, #0x0\n"
-            ".inst 0x451f9a49  // smmla z9.s, z18.b, z31.b\n"
-            "mov z31.s, #0x0\n"
-            ".inst 0x45069a5f  // smmla z31.s, z18.b, z6.b\n"
-            "ld1rqb { z6.b }, p1/Z, [x24, #48]\n"
-            "ld1rqb { z18.b }, p1/Z, [x24, #64]\n"
-            ".inst 0x450e98c9  // smmla z9.s, z6.b, z14.b\n"
-            "fmul z14.s, z23.s, z5.s[1]\n"
-            ".inst 0x450298df  // smmla z31.s, z6.b, z2.b\n"
-            "ld1rqb { z6.b }, p1/Z, [x24, #80]\n"
-            "fmul z2.s, z23.s, z5.s[2]\n"
-            "fmul z23.s, z23.s, z5.s[3]\n"
-            ".inst 0x451e9a43  // smmla z3.s, z18.b, z30.b\n"
-            ".inst 0x45159a47  // smmla z7.s, z18.b, z21.b\n"
-            "ld1rqb { z5.b }, p1/Z, [x24, #96]\n"
-            ".inst 0x451e98c9  // smmla z9.s, z6.b, z30.b\n"
-            ".inst 0x451598df  // smmla z31.s, z6.b, z21.b\n"
-            "ld1rqb { z18.b }, p1/Z, [x24, #112]\n"
-            "add x24, x24, #0x88\n"
-            ".inst 0x450498a3  // smmla z3.s, z5.b, z4.b\n"
-            ".inst 0x451198a7  // smmla z7.s, z5.b, z17.b\n"
-            ".inst 0x45049a49  // smmla z9.s, z18.b, z4.b\n"
-            ".inst 0x45119a5f  // smmla z31.s, z18.b, z17.b\n"
-            "uzp1 z18.d, z3.d, z7.d\n"
-            "uzp2 z5.d, z3.d, z7.d\n"
-            "scvtf z18.s, p1/m, z18.s\n"
-            "uzp1 z6.d, z9.d, z31.d\n"
-            "uzp2 z9.d, z9.d, z31.d\n"
-            "scvtf z5.s, p1/m, z5.s\n"
-            "fmla z8.s, p1/M, z18.s, z22.s\n"
-            "scvtf z6.s, p1/m, z6.s\n"
-            "scvtf z9.s, p1/m, z9.s\n"
-            "fmla z29.s, p1/M, z5.s, z14.s\n"
-            "fmla z27.s, p1/M, z6.s, z2.s\n"
-            "fmla z10.s, p1/M, z9.s, z23.s\n"
-            "bgt 3b\n"
-            "mov x20, %x[res_ptr]\n"
-            "subs x10, x10, #0x8\n"
-            "add %x[res_ptr], %x[res_ptr], #0x20\n"
-            "st1w { z24.s }, p1, [x20]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "st1w { z15.s }, p1, [x20]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "st1w { z12.s }, p1, [x20]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "st1w { z0.s }, p1, [x20]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "st1w { z13.s }, p1, [x20]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "st1w { z1.s }, p1, [x20]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "st1w { z20.s }, p1, [x20]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "st1w { z25.s }, p1, [x20]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "st1w { z11.s }, p1, [x20]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "st1w { z16.s }, p1, [x20]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "st1w { z19.s }, p1, [x20]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "st1w { z26.s }, p1, [x20]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "st1w { z8.s }, p1, [x20]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "st1w { z29.s }, p1, [x20]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "st1w { z27.s }, p1, [x20]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "st1w { z10.s }, p1, [x20]\n"
-            "bne 2b\n"
-            "mov x20, #0x4\n"
-            "sub x13, x13, #0x10\n"
-            "cmp x13, #0x10\n"
-            "mov %x[res_ptr], x9\n"
-            "madd %x[a_ptr], x20, x12, %x[a_ptr]\n"
-            "bge 1b\n"
-            "4:"  // Row loop skip
-            "cbz x13, 9f\n"
-            "5:"  // Row tail: Row loop
-            "add x25, %x[b_ptr], #0x10\n"
-            "mov x24, %x[nc]\n"
-            "add x23, %x[res_ptr], %x[res_stride], LSL #2\n"
-            "6:"  // Row tail: Column loop
-            "mov z24.b, #0x0\n"
-            "mov z15.b, #0x0\n"
-            "add x28, %x[a_ptr], #0x8\n"
-            "mov x22, %x[nb]\n"
-            "mov z12.b, #0x0\n"
-            "mov z0.b, #0x0\n"
-            "7:"  // Row tail: Block loop
-            "ld1b { z3.b }, p1/Z, [x25]\n"
-            "ld1b { z6.b }, p1/Z, [x25, #1, MUL VL]\n"
-            "mov z2.s, #0x0\n"
-            "mov z25.s, #0x0\n"
-            "ld1rqb { z26.b }, p1/Z, [x28]\n"
-            "ld1rqb { z21.b }, p1/Z, [x28, #16]\n"
-            "mov z27.s, #0x0\n"
-            "mov z19.s, #0x0\n"
-            "ld1b { z29.b }, p1/Z, [x25, #2, MUL VL]\n"
-            "ld1b { z16.b }, p1/Z, [x25, #3, MUL VL]\n"
-            "sub x21, x25, #0x10\n"
-            "sub x20, x28, #0x8\n"
-            "lsl z20.b, z3.b, #0x4\n"
-            "lsl z4.b, z6.b, #0x4\n"
-            "ld1rqb { z10.b }, p1/Z, [x28, #32]\n"
-            "ld1rqb { z23.b }, p1/Z, [x28, #48]\n"
-            "and z3.b, z3.b, #0xf0\n"
-            "and z6.b, z6.b, #0xf0\n"
-            "ld1rqb { z11.b }, p1/Z, [x28, #64]\n"
-            "ld1rqb { z7.b }, p1/Z, [x28, #80]\n"
-            "lsl z8.b, z29.b, #0x4\n"
-            "lsl z14.b, z16.b, #0x4\n"
-            "ld1rqb { z18.b }, p1/Z, [x28, #96]\n"
-            "ld1rqb { z30.b }, p1/Z, [x28, #112]\n"
-            ".inst 0x45149b42  // smmla z2.s, z26.b, z20.b\n"
-            ".inst 0x45049b59  // smmla z25.s, z26.b, z4.b\n"
-            "and z29.b, z29.b, #0xf0\n"
-            "ld1h { z17.s }, p1/Z, [x21]\n"
-            ".inst 0x45149abb  // smmla z27.s, z21.b, z20.b\n"
-            ".inst 0x45049ab3  // smmla z19.s, z21.b, z4.b\n"
-            "and z16.b, z16.b, #0xf0\n"
-            "ld1h { z4.s }, p0/Z, [x20]\n"
-            "subs x22, x22, #0x1\n"
-            "add x28, x28, #0x88\n"
-            "fcvt z17.s, p1/m, z17.h\n"
-            "add x25, x25, #0x90\n"
-            ".inst 0x45089942  // smmla z2.s, z10.b, z8.b\n"
-            ".inst 0x450e9959  // smmla z25.s, z10.b, z14.b\n"
-            "fcvt z4.s, p1/m, z4.h\n"
-            ".inst 0x45089afb  // smmla z27.s, z23.b, z8.b\n"
-            ".inst 0x450e9af3  // smmla z19.s, z23.b, z14.b\n"
-            "fscale z17.s, p1/m, z17.s, z28.s\n"
-            "mov z4.q, z4.q[0]\n"
-            ".inst 0x45039962  // smmla z2.s, z11.b, z3.b\n"
-            ".inst 0x45069979  // smmla z25.s, z11.b, z6.b\n"
-            "fmul z23.s, z17.s, z4.s[0]\n"
-            "fmul z9.s, z17.s, z4.s[1]\n"
-            "fmul z21.s, z17.s, z4.s[2]\n"
-            "fmul z4.s, z17.s, z4.s[3]\n"
-            ".inst 0x450398fb  // smmla z27.s, z7.b, z3.b\n"
-            ".inst 0x450698f3  // smmla z19.s, z7.b, z6.b\n"
-            ".inst 0x451d9a42  // smmla z2.s, z18.b, z29.b\n"
-            ".inst 0x45109a59  // smmla z25.s, z18.b, z16.b\n"
-            ".inst 0x451d9bdb  // smmla z27.s, z30.b, z29.b\n"
-            ".inst 0x45109bd3  // smmla z19.s, z30.b, z16.b\n"
-            "uzp1 z31.d, z2.d, z25.d\n"
-            "uzp2 z13.d, z2.d, z25.d\n"
-            "scvtf z31.s, p1/m, z31.s\n"
-            "uzp1 z17.d, z27.d, z19.d\n"
-            "uzp2 z18.d, z27.d, z19.d\n"
-            "scvtf z13.s, p1/m, z13.s\n"
-            "fmla z24.s, p1/M, z31.s, z23.s\n"
-            "scvtf z17.s, p1/m, z17.s\n"
-            "scvtf z18.s, p1/m, z18.s\n"
-            "fmla z15.s, p1/M, z13.s, z9.s\n"
-            "fmla z12.s, p1/M, z17.s, z21.s\n"
-            "fmla z0.s, p1/M, z18.s, z4.s\n"
-            "bgt 7b\n"
-            "mov x20, %x[res_ptr]\n"
-            "cmp x13, #0x1\n"
-            "st1w { z24.s }, p1, [x20]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "ble 8f\n"
-            "cmp x13, #0x2\n"
-            "st1w { z15.s }, p1, [x20]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "ble 8f\n"
-            "cmp x13, #0x3\n"
-            "st1w { z12.s }, p1, [x20]\n"
-            "add x20, x20, %x[res_stride]\n"
-            "ble 8f\n"
-            "st1w { z0.s }, p1, [x20]\n"
-            "8:"  // Row tail: Accumulator store skip
-            "subs x24, x24, #0x8\n"
-            "add %x[res_ptr], %x[res_ptr], #0x20\n"
-            "bne 6b\n"
-            "subs x13, x13, #0x4\n"
-            "add %x[a_ptr], %x[a_ptr], x12\n"
-            "mov %x[res_ptr], x23\n"
-            "bgt 5b\n"
-            "9:"  // Row tail: Row loop skip
-            : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr)
-            : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc)
-            : "cc", "memory", "p0", "p1", "x9", "x10", "x11", "x12", "x13", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"
-        );
-        return;
-    }
-#endif // #if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
-#elif defined(__AVX2__) || defined(__AVX512F__)
-    {
-        const block_q4_0x8 * b_ptr_start = (const block_q4_0x8 *)vx;
-        const block_q8_0x4 * a_ptr_start = (const block_q8_0x4 *)vy;
-        int64_t b_nb = n / QK4_0;
-        int64_t y = 0;
-        // Mask used to extract the low nibbles from the packed bytes
-        const __m256i m4b = _mm256_set1_epi8(0x0F);
-        const __m128i loadMask = _mm_blend_epi32(_mm_setzero_si128(), _mm_set1_epi32(0xFFFFFFFF), 3);
-        // Lookup table to convert signed nibbles to signed bytes
-        __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));
-        signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
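-        // _mm256_shuffle_epi8 uses each masked nibble (0..15) as an index into this
-        // table; entries 0..7 map to 0..7 and entries 8..15 map to -8..-1, i.e. the
-        // table sign-extends a 4-bit two's-complement value (for example nibble 0xF,
-        // int4 value -1, selects entry 15 = -1).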
-        // Permute mask that swaps the two 128-bit halves of a 256-bit vector, used to regroup rows below
-        __m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4);
-        int64_t xstart = 0;
-        int anr = nr - nr % 16; // nr rounded down to a multiple of 16
-    #ifdef __AVX512F__
-        int anc = nc - nc % 16; // nc rounded down to a multiple of 16
-        // Mask to mask out nibbles from packed bytes expanded to 512 bit length
-        const __m512i m4bexpanded = _mm512_set1_epi8(0x0F);
-        // Lookup table to convert signed nibbles to signed bytes expanded to 512 bit length
-        __m512i signextendlutexpanded = _mm512_inserti32x8(_mm512_castsi256_si512(signextendlut), signextendlut, 1);
-
-        // Take a group of four block_q8_0x4 structures on each pass of the loop and perform the dot product operation
-        for (; y < anr / 4; y += 4) {
-
-            const block_q8_0x4 * a_ptrs[4];
-
-            a_ptrs[0] = a_ptr_start + (y * nb);
-            for (int i = 0; i < 3; ++i) {
-                a_ptrs[i + 1] = a_ptrs[i] + nb;
-            }
-
-            // Take a group of two block_q4_0x8 structures on each pass of the loop and perform the dot product operation
-            for (int64_t x = 0; x < anc / 8; x += 2) {
-
-                const block_q4_0x8 * b_ptr_0 = b_ptr_start + ((x)     * b_nb);
-                const block_q4_0x8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb);
-
-                // Master FP accumulators
-                __m512 acc_rows[16];
-                for (int i = 0; i < 16; i++) {
-                    acc_rows[i] = _mm512_setzero_ps();
-                }
-
-                for (int64_t b = 0; b < nb; b++) {
-                    // Load the sixteen block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....BE,BF
-                    const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs));
-                    const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 32));
-                    const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 64));
-                    const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 96));
-
-                    const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs));
-                    const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 32));
-                    const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 64));
-                    const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 96));
-
-                    // Save the values in the following vectors in the formats B0B1B4B5B8B9BCBD, B2B3B6B7BABBBEBF for further processing and storing of values
-                    const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
-                    const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
-                    const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
-                    const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
-
-                    const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240);
-                    const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240);
-                    const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240);
-                    const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240);
-
-                    const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1);
-                    const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1);
-                    const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1);
-                    const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1);
-
-                    // 4-bit -> 8-bit - Sign is maintained
-                    const __m512i rhs_mat_014589CD_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) B8(0-7) B9(0-7) BC(0-7) BD(0-7)
-                    const __m512i rhs_mat_2367ABEF_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) BA(0-7) BB(0-7) BE(0-7) BF(0-7)
-
-                    const __m512i rhs_mat_014589CD_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) B8(8-15) B9(8-15) BC(8-15) BD(8-15)
-                    const __m512i rhs_mat_2367ABEF_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) BA(8-15) BB(8-15) BE(8-15) BF(8-15)
-
-                    const __m512i rhs_mat_014589CD_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) B8(16-23) B9(16-23) BC(16-23) BD(16-23)
-                    const __m512i rhs_mat_2367ABEF_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) BA(16-23) BB(16-23) BE(16-23) BF(16-23)
-
-                    const __m512i rhs_mat_014589CD_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) B8(24-31) B9(24-31) BC(24-31) BD(24-31)
-                    const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31)
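-                    // Within each packed byte the low nibble holds one of elements 0-15 and the
-                    // high nibble one of elements 16-31, hence the extra >> 4 for the _2/_3 halves.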
-
-                    // Shuffle pattern one - right side input
-                    const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
-                    const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
-
-                    const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
-                    const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
-
-                    const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
-                    const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
-
-                    const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
-                    const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
-
-                    // Shuffle pattern two - right side input
-
-                    const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
-                    const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
-
-                    const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
-                    const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
-
-                    const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
-                    const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
-
-                    const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
-                    const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
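-                    // imm 136 (0b10001000) replicates the even 32-bit groups of B as (0,2,0,2) and
-                    // imm 221 (0b11011101) the odd ones as (1,3,1,3); combined with the LHS patterns
-                    // 160/245 below, every 4-byte group of A meets the B group it must be dotted
-                    // with, and adding the sp1 and sp2 results covers all 32 bytes of the block.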
-
-                    // Scale values - Load the weight scale values of two block_q4_0x8
-                    const __m512 col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
-
-                    // Process LHS in pairs of rows
-                    for (int rp = 0; rp < 4; rp++) {
-
-                        // Load the four block_q8_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
-                        // Loaded as a set of 128-bit vectors, repeated into a 256-bit vector and then repeated again into a 512-bit vector
-                        __m256i lhs_mat_ymm_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs)));
-                        __m256i lhs_mat_ymm_01_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 0);
-                        __m256i lhs_mat_ymm_23_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 17);
-                        __m256i lhs_mat_ymm_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 32)));
-                        __m256i lhs_mat_ymm_01_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 0);
-                        __m256i lhs_mat_ymm_23_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 17);
-                        __m256i lhs_mat_ymm_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 64)));
-                        __m256i lhs_mat_ymm_01_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 0);
-                        __m256i lhs_mat_ymm_23_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 17);
-                        __m256i lhs_mat_ymm_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 96)));
-                        __m256i lhs_mat_ymm_01_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 0);
-                        __m256i lhs_mat_ymm_23_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 17);
-
-                        __m512i lhs_mat_01_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_0), lhs_mat_ymm_01_0, 1);
-                        __m512i lhs_mat_23_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_0), lhs_mat_ymm_23_0, 1);
-                        __m512i lhs_mat_01_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_1), lhs_mat_ymm_01_1, 1);
-                        __m512i lhs_mat_23_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_1), lhs_mat_ymm_23_1, 1);
-                        __m512i lhs_mat_01_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_2), lhs_mat_ymm_01_2, 1);
-                        __m512i lhs_mat_23_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_2), lhs_mat_ymm_23_2, 1);
-                        __m512i lhs_mat_01_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_3), lhs_mat_ymm_01_3, 1);
-                        __m512i lhs_mat_23_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_3), lhs_mat_ymm_23_3, 1);
-
-                        // Shuffle pattern one - left side input
-
-                        const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, 160);  //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
-                        const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, 160);  //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
-
-                        const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, 160);  //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
-                        const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, 160);  //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
-
-                        const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, 160);  //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
-                        const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, 160);  //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
-
-                        const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, 160);  //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
-                        const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, 160);  //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
-
-                        // Shuffle pattern two - left side input
-
-                        const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, 245);  //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
-                        const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, 245);  //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
-
-                        const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, 245);  //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
-                        const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, 245);  //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
-
-                        const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, 245);  //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
-                        const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, 245);  //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
-
-                        const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, 245);  //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
-                        const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, 245);  //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
-
-                        // The shuffled values are combined with a dot product within each 32-bit lane: corresponding bytes are multiplied and the products accumulated into one 32-bit integer per lane
-                        // Resembles the MMLA accumulation into 2x2 matrices in the ARM version
-                        __m512i iacc_mat_00_sp1 =
-                            _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1));
-                        __m512i iacc_mat_01_sp1 =
-                            _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1));
-                        __m512i iacc_mat_10_sp1 =
-                            _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1));
-                        __m512i iacc_mat_11_sp1 =
-                            _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1));
-                        __m512i iacc_mat_00_sp2 =
-                            _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2));
-                        __m512i iacc_mat_01_sp2 =
-                            _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2));
-                        __m512i iacc_mat_10_sp2 =
-                            _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2));
-                        __m512i iacc_mat_11_sp2 =
-                            _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2));
-
-                        // Outputs of both shuffle patterns are added in order to sum the dot products over all 32 values in the block
-                        __m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
-                        __m512i iacc_mat_01 = _mm512_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2);
-                        __m512i iacc_mat_10 = _mm512_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2);
-                        __m512i iacc_mat_11 = _mm512_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2);
-
-                        // Straighten out to make 4 row vectors
-                        __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, 78));
-                        __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01);
-                        __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, 78));
-                        __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11);
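-                        // Blend mask 0xCCCC alternates 64-bit pairs and imm 78 (0b01001110) swaps
-                        // the 64-bit halves of each 128-bit lane, interleaving the 2x2 tile
-                        // outputs back into per-row vectors.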
-
-                        // Load the scale (d) values of all four Q8_0 blocks and repeat them across the lanes
-                        const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptrs[rp][b].d), loadMask), 68);
-                        const __m512 row_scale_f32 = GGML_F32Cx16_REPEAT_LOAD(row_scale_f16);
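-                        // loadMask selects only the low two dwords, so the maskload reads just the
-                        // four fp16 d values; shuffle imm 68 (0b01000100) then repeats that 64-bit
-                        // pair across the 128-bit register.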
-
-                        // Multiply with the appropriate scales and accumulate
-                        acc_rows[rp * 4]     = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)),   acc_rows[rp * 4]);
-                        acc_rows[rp * 4 + 1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)),  acc_rows[rp * 4 + 1]);
-                        acc_rows[rp * 4 + 2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);
-                        acc_rows[rp * 4 + 3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]);
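-                        // Per the Q4_0/Q8_0 dequantization, each int32 dot product contributes
-                        // d_B * d_A * (q_B . q_A), which is exactly the combined
-                        // col_scale * row_scale factor applied in the fmadds above.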
-                    }
-                }
-
-                // Store the accumulated values
-                for (int i = 0; i < 16; i++) {
-                    _mm512_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);
-                }
-            }
-        }
-        // Take one block_q8_0x4 structure on each pass of the loop and perform the dot product operation
-        for (; y < nr / 4; y ++) {
-
-            const block_q8_0x4 * a_ptr = a_ptr_start + (y * nb);
-
-            // Take a group of two block_q4_0x8 structures on each pass of the loop and perform the dot product operation
-            for (int64_t x = 0; x < anc / 8; x += 2) {
-
-                const block_q4_0x8 * b_ptr_0 = b_ptr_start + ((x)     * b_nb);
-                const block_q4_0x8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb);
-
-                // Master FP accumulators
-                __m512 acc_rows[4];
-                for (int i = 0; i < 4; i++) {
-                    acc_rows[i] = _mm512_setzero_ps();
-                }
-
-                for (int64_t b = 0; b < nb; b++) {
-                    // Load the sixteen block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....BE,BF
-                    const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs));
-                    const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 32));
-                    const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 64));
-                    const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 96));
-
-                    const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs));
-                    const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 32));
-                    const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 64));
-                    const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 96));
-
-                    // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values
-                    const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
-                    const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
-                    const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
-                    const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
-
-                    const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240);
-                    const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240);
-                    const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240);
-                    const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240);
-
-                    const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1);
-                    const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1);
-                    const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1);
-                    const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1);
-
-                    // 4-bit -> 8-bit - Sign is maintained
-                    const __m512i rhs_mat_014589CD_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) B8(0-7) B9(0-7) BC(0-7) BD(0-7)
-                    const __m512i rhs_mat_2367ABEF_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) BA(0-7) BB(0-7) BE(0-7) BF(0-7)
-
-                    const __m512i rhs_mat_014589CD_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) B8(8-15) B9(8-15) BC(8-15) BD(8-15)
-                    const __m512i rhs_mat_2367ABEF_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) BA(8-15) BB(8-15) BE(8-15) BF(8-15)
-
-                    const __m512i rhs_mat_014589CD_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) B8(16-23) B9(16-23) BC(16-23) BD(16-23)
-                    const __m512i rhs_mat_2367ABEF_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) BA(16-23) BB(16-23) BE(16-23) BF(16-23)
-
-                    const __m512i rhs_mat_014589CD_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) B8(24-31) B9(24-31) BC(24-31) BD(24-31)
-                    const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31)
-
-                    // Shuffle pattern one - right side input
-                    const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
-                    const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
-
-                    const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
-                    const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
-
-                    const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
-                    const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
-
-                    const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
-                    const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
-
-                    // Shuffle pattern two - right side input
-
-                    const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
-                    const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
-
-                    const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
-                    const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
-
-                    const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
-                    const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
-
-                    const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
-                    const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
-
-
-                    // Scale values - Load the weight scale values of two block_q4_0x8
-                    const __m512 col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
-
-                    // Load the four block_q8_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
-                    // Loaded as a set of 128-bit vectors, repeated into a 256-bit vector, and then again into a 512-bit vector
-                    __m256i lhs_mat_ymm_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs)));
-                    __m256i lhs_mat_ymm_01_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 0);
-                    __m256i lhs_mat_ymm_23_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 17);
-                    __m256i lhs_mat_ymm_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 32)));
-                    __m256i lhs_mat_ymm_01_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 0);
-                    __m256i lhs_mat_ymm_23_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 17);
-                    __m256i lhs_mat_ymm_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 64)));
-                    __m256i lhs_mat_ymm_01_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 0);
-                    __m256i lhs_mat_ymm_23_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 17);
-                    __m256i lhs_mat_ymm_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 96)));
-                    __m256i lhs_mat_ymm_01_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 0);
-                    __m256i lhs_mat_ymm_23_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 17);
-
-                    __m512i lhs_mat_01_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_0), lhs_mat_ymm_01_0, 1);
-                    __m512i lhs_mat_23_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_0), lhs_mat_ymm_23_0, 1);
-                    __m512i lhs_mat_01_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_1), lhs_mat_ymm_01_1, 1);
-                    __m512i lhs_mat_23_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_1), lhs_mat_ymm_23_1, 1);
-                    __m512i lhs_mat_01_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_2), lhs_mat_ymm_01_2, 1);
-                    __m512i lhs_mat_23_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_2), lhs_mat_ymm_23_2, 1);
-                    __m512i lhs_mat_01_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_3), lhs_mat_ymm_01_3, 1);
-                    __m512i lhs_mat_23_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_3), lhs_mat_ymm_23_3, 1);
-
-                    // Shuffle pattern one - left side input
-
-                    const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, 160);  //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
-                    const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, 160);  //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
-
-                    const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, 160);  //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
-                    const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, 160);  //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
-
-                    const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, 160);  //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
-                    const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, 160);  //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
-
-                    const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, 160);  //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
-                    const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, 160);  //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
-
-                    // Shuffle pattern two - left side input
-
-                    const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, 245);  //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
-                    const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, 245);  //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
-
-                    const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, 245);  //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
-                    const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, 245);  //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
-
-                    const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, 245);  //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
-                    const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, 245);  //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
-
-                    const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, 245);  //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
-                    const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, 245);  //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
-
-                    // The values arranged in the shuffle patterns are combined with a dot product within each 32-bit lane, i.e. corresponding bytes are multiplied and the products accumulated into a 32-bit integer per lane
-                    // Resembles the MMLA operations on 2x2 matrices in the ARM version
-                    __m512i iacc_mat_00_sp1 =
-                        _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1));
-                    __m512i iacc_mat_01_sp1 =
-                        _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1));
-                    __m512i iacc_mat_10_sp1 =
-                        _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1));
-                    __m512i iacc_mat_11_sp1 =
-                        _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1));
-                    __m512i iacc_mat_00_sp2 =
-                        _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2));
-                    __m512i iacc_mat_01_sp2 =
-                        _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2));
-                    __m512i iacc_mat_10_sp2 =
-                        _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2));
-                    __m512i iacc_mat_11_sp2 =
-                        _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2));
-
-                    // Outputs of both shuffle patterns are added in order to sum the dot product outputs of all 32 values in the block
-                    __m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
-                    __m512i iacc_mat_01 = _mm512_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2);
-                    __m512i iacc_mat_10 = _mm512_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2);
-                    __m512i iacc_mat_11 = _mm512_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2);
-
-
-                    // Straighten out to make 4 row vectors
-                    __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, 78));
-                    __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01);
-                    __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, 78));
-                    __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11);
-
-                    // Load the scale values for all 4 Q8_0 blocks and repeat them across lanes
-                    const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptr[b].d), loadMask), 68);
-                    const __m512 row_scale_f32 = GGML_F32Cx16_REPEAT_LOAD(row_scale_f16);
-
-                    // Multiply with appropriate scales and accumulate
-                    acc_rows[0] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)),   acc_rows[0]);
-                    acc_rows[1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)),  acc_rows[1]);
-                    acc_rows[2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]);
-                    acc_rows[3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]);
-                }
-
-                // Store the accumulated values
-                for (int i = 0; i < 4; i++) {
-                    _mm512_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);
-                }
-            }
-        }
-        if (anc != nc) {
-            xstart = anc/8;
-            y = 0;
-        }
-    #endif // __AVX512F__
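Both the AVX-512 path above and the AVX2 path below lean on the mul_sum_i8_pairs_int32x16/x8 helpers. As a rough scalar model (an illustration of the per-lane arithmetic, not the vendored vector implementation): each 32-bit lane of the two operands holds four signed bytes, and the helper multiplies corresponding bytes and sums the four products into one int32 per lane.

#include <stdint.h>

/* Illustrative per-lane model of mul_sum_i8_pairs_int32 (the per-lane view is
   an assumption for clarity; the real helpers operate on whole 256/512-bit
   vectors): multiply four pairs of signed bytes and accumulate into one int32. */
static int32_t mul_sum_i8_pairs_lane(const int8_t a[4], const int8_t b[4]) {
    int32_t acc = 0;
    for (int i = 0; i < 4; i++) {
        acc += (int32_t) a[i] * (int32_t) b[i];
    }
    return acc;
}

The two shuffle immediates (136 and 221) select complementary 4-byte groups of each block, which is why the sp1 and sp2 partial sums are added afterwards to cover all 32 values in a block.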
-
-        // Take a group of four block_q8_0x4 structures at each pass of the loop and perform the dot product operation
-
-        for (; y < anr / 4; y += 4) {
-            const block_q8_0x4 * a_ptrs[4];
-
-            a_ptrs[0] = a_ptr_start + (y * nb);
-            for (int i = 0; i < 3; ++i) {
-                a_ptrs[i + 1] = a_ptrs[i] + nb;
-            }
-
-            // Take a block_q4_0x8 structure (eight interleaved q4_0 blocks) at each pass of the loop and perform the dot product operation
-            for (int64_t x = xstart; x < nc / 8; x++) {
-
-                const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb);
-
-                // Master FP accumulators
-                __m256 acc_rows[16];
-                for (int i = 0; i < 16; i++) {
-                    acc_rows[i] = _mm256_setzero_ps();
-                }
-
-                for (int64_t b = 0; b < nb; b++) {
-                    // Load the eight block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7
-                    const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs));
-                    const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32));
-                    const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64));
-                    const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96));
-
-                    // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values
-                    const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
-                    const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
-                    const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
-                    const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
-
-                    // 4-bit -> 8-bit - Sign is maintained
-                    const __m256i rhs_mat_0145_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_0, m4b)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7)
-                    const __m256i rhs_mat_2367_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_0, m4b)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7)
-
-                    const __m256i rhs_mat_0145_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_1, m4b)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15)
-                    const __m256i rhs_mat_2367_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_1, m4b)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15)
-
-                    const __m256i rhs_mat_0145_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23)
-                    const __m256i rhs_mat_2367_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23)
-
-                    const __m256i rhs_mat_0145_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31)
-                    const __m256i rhs_mat_2367_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31)
-
-                    // Shuffle pattern one - right side input
-                    const __m256i rhs_mat_0145_0_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_0, 136);  //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3)
-                    const __m256i rhs_mat_2367_0_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_0, 136);  //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3)
-
-                    const __m256i rhs_mat_0145_1_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_1, 136);  //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11)
-                    const __m256i rhs_mat_2367_1_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_1, 136);  //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11)
-
-                    const __m256i rhs_mat_0145_2_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_2, 136);  //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19)
-                    const __m256i rhs_mat_2367_2_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_2, 136);  //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19)
-
-                    const __m256i rhs_mat_0145_3_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_3, 136);  //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27)
-                    const __m256i rhs_mat_2367_3_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_3, 136);  //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27)
-
-                    // Shuffle pattern two - right side input
-
-                    const __m256i rhs_mat_0145_0_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_0, 221);  //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7)
-                    const __m256i rhs_mat_2367_0_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_0, 221);  //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7)
-
-                    const __m256i rhs_mat_0145_1_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_1, 221);  //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15)
-                    const __m256i rhs_mat_2367_1_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_1, 221);  //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15)
-
-                    const __m256i rhs_mat_0145_2_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_2, 221);  //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23)
-                    const __m256i rhs_mat_2367_2_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_2, 221);  //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23)
-
-                    const __m256i rhs_mat_0145_3_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_3, 221);  //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31)
-                    const __m256i rhs_mat_2367_3_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_3, 221);  //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31)
-
-                    // Scale values - Load the weight scale values of block_q4_0x8
-                    const __m256 col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);
-
-                    // Process LHS in groups of four
-                    for (int rp = 0; rp < 4; rp++) {
-                        // Load the four block_q8_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
-                        // Loaded as a set of 128-bit vectors and repeated into a 256-bit vector
-                        __m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs)));
-                        __m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0);
-                        __m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17);
-                        __m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 32)));
-                        __m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0);
-                        __m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17);
-                        __m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 64)));
-                        __m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0);
-                        __m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17);
-                        __m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 96)));
-                        __m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0);
-                        __m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17);
-
-                        // Shuffle pattern one - left side input
-                        const __m256i lhs_mat_01_0_sp1 = _mm256_shuffle_epi32(lhs_mat_01_0, 160);  //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
-                        const __m256i lhs_mat_23_0_sp1 = _mm256_shuffle_epi32(lhs_mat_23_0, 160);  //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
-
-                        const __m256i lhs_mat_01_1_sp1 = _mm256_shuffle_epi32(lhs_mat_01_1, 160);  //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
-                        const __m256i lhs_mat_23_1_sp1 = _mm256_shuffle_epi32(lhs_mat_23_1, 160);  //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
-
-                        const __m256i lhs_mat_01_2_sp1 = _mm256_shuffle_epi32(lhs_mat_01_2, 160);  //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
-                        const __m256i lhs_mat_23_2_sp1 = _mm256_shuffle_epi32(lhs_mat_23_2, 160);  //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
-
-                        const __m256i lhs_mat_01_3_sp1 = _mm256_shuffle_epi32(lhs_mat_01_3, 160);  //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
-                        const __m256i lhs_mat_23_3_sp1 = _mm256_shuffle_epi32(lhs_mat_23_3, 160);  //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
-
-                        // Shuffle pattern two - left side input
-                        const __m256i lhs_mat_01_0_sp2 = _mm256_shuffle_epi32(lhs_mat_01_0, 245);  //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
-                        const __m256i lhs_mat_23_0_sp2 = _mm256_shuffle_epi32(lhs_mat_23_0, 245);  //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
-
-                        const __m256i lhs_mat_01_1_sp2 = _mm256_shuffle_epi32(lhs_mat_01_1, 245);  //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
-                        const __m256i lhs_mat_23_1_sp2 = _mm256_shuffle_epi32(lhs_mat_23_1, 245);  //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
-
-                        const __m256i lhs_mat_01_2_sp2 = _mm256_shuffle_epi32(lhs_mat_01_2, 245);  //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
-                        const __m256i lhs_mat_23_2_sp2 = _mm256_shuffle_epi32(lhs_mat_23_2, 245);  //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
-
-                        const __m256i lhs_mat_01_3_sp2 = _mm256_shuffle_epi32(lhs_mat_01_3, 245);  //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
-                        const __m256i lhs_mat_23_3_sp2 = _mm256_shuffle_epi32(lhs_mat_23_3, 245);  //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
-
-                        // The values arranged in the shuffle patterns are combined with a dot product within each 32-bit lane, i.e. corresponding bytes are multiplied and the products accumulated into a 32-bit integer per lane
-                        // Resembles the MMLA operations on 2x2 matrices in the ARM version
-                        __m256i iacc_mat_00_sp1 =
-                            _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1));
-                        __m256i iacc_mat_01_sp1 =
-                            _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1));
-                        __m256i iacc_mat_10_sp1 =
-                            _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1));
-                        __m256i iacc_mat_11_sp1 =
-                            _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1));
-                        __m256i iacc_mat_00_sp2 =
-                            _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2));
-                        __m256i iacc_mat_01_sp2 =
-                            _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2));
-                        __m256i iacc_mat_10_sp2 =
-                            _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2));
-                        __m256i iacc_mat_11_sp2 =
-                            _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2));
-
-                        // Outputs of both shuffle patterns are added in order to sum the dot product outputs of all 32 values in the block
-                        __m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
-                        __m256i iacc_mat_01 = _mm256_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2);
-                        __m256i iacc_mat_10 = _mm256_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2);
-                        __m256i iacc_mat_11 = _mm256_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2);
-
-                        // Straighten out to make 4 row vectors
-                        __m256i iacc_row_0 = _mm256_blend_epi32(iacc_mat_00, _mm256_shuffle_epi32(iacc_mat_01, 78), 204);
-                        __m256i iacc_row_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01, 204);
-                        __m256i iacc_row_2 = _mm256_blend_epi32(iacc_mat_10, _mm256_shuffle_epi32(iacc_mat_11, 78), 204);
-                        __m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204);
-
-                        // Load the scale values for all 4 Q8_0 blocks and repeat them across lanes
-                        const __m256 row_scale_f32 = GGML_F32Cx8_REPEAT_LOAD(a_ptrs[rp][b].d, loadMask);
-
-                        // Multiply with appropriate scales and accumulate
-                        acc_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]);
-                        acc_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]);
-                        acc_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);
-                        acc_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32,  255)), acc_rows[rp * 4 + 3]);
-                    }
-                }
-
-                // Store the accumulated values
-                for (int i = 0; i < 16; i++) {
-                    _mm256_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);
-                }
-            }
-        }
-
-        // Take one block_q8_0x4 structure at each pass of the loop and perform the dot product operation
-        for (; y < nr / 4; y ++) {
-
-            const block_q8_0x4 * a_ptr = a_ptr_start + (y * nb);
-
-            // Take a block_q4_0x8 structure (eight interleaved q4_0 blocks) at each pass of the loop and perform the dot product operation
-            for (int64_t x = xstart; x < nc / 8; x++) {
-
-                const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb);
-
-                // Master FP accumulators
-                __m256 acc_rows[4];
-                for (int i = 0; i < 4; i++) {
-                    acc_rows[i] = _mm256_setzero_ps();
-                }
-
-                for (int64_t b = 0; b < nb; b++) {
-                    // Load the eight block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7
-                    const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs));
-                    const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32));
-                    const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64));
-                    const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96));
-
-                    // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values
-                    const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
-                    const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
-                    const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
-                    const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
-
-                    // 4-bit -> 8-bit - Sign is maintained
-                    const __m256i rhs_mat_0145_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_0, m4b));  //B0(0-7) B1(0-7) B4(0-7) B5(0-7)
-                    const __m256i rhs_mat_2367_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_0, m4b));  //B2(0-7) B3(0-7) B6(0-7) B7(0-7)
-
-                    const __m256i rhs_mat_0145_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_1, m4b));  //B0(8-15) B1(8-15) B4(8-15) B5(8-15)
-                    const __m256i rhs_mat_2367_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_1, m4b));  //B2(8-15) B3(8-15) B6(8-15) B7(8-15)
-
-                    const __m256i rhs_mat_0145_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b));  //B0(16-23) B1(16-23) B4(16-23) B5(16-23)
-                    const __m256i rhs_mat_2367_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b));  //B2(16-23) B3(16-23) B6(16-23) B7(16-23)
-
-                    const __m256i rhs_mat_0145_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b));  //B0(24-31) B1(24-31) B4(24-31) B5(24-31)
-                    const __m256i rhs_mat_2367_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b));  //B2(24-31) B3(24-31) B6(24-31) B7(24-31)
-
-                    // Shuffle pattern one - right side input
-                    const __m256i rhs_mat_0145_0_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_0, 136);  //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3)
-                    const __m256i rhs_mat_2367_0_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_0, 136);  //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3)
-
-                    const __m256i rhs_mat_0145_1_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_1, 136);  //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11)
-                    const __m256i rhs_mat_2367_1_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_1, 136);  //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11)
-
-                    const __m256i rhs_mat_0145_2_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_2, 136);  //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19)
-                    const __m256i rhs_mat_2367_2_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_2, 136);  //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19)
-
-                    const __m256i rhs_mat_0145_3_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_3, 136);  //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27)
-                    const __m256i rhs_mat_2367_3_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_3, 136);  //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27)
-
-                    // Shuffle pattern two - right side input
-
-                    const __m256i rhs_mat_0145_0_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_0, 221);  //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7)
-                    const __m256i rhs_mat_2367_0_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_0, 221);  //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7)
-
-                    const __m256i rhs_mat_0145_1_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_1, 221);  //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15)
-                    const __m256i rhs_mat_2367_1_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_1, 221);  //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15)
-
-                    const __m256i rhs_mat_0145_2_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_2, 221);  //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23)
-                    const __m256i rhs_mat_2367_2_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_2, 221);  //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23)
-
-                    const __m256i rhs_mat_0145_3_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_3, 221);  //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31)
-                    const __m256i rhs_mat_2367_3_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_3, 221);  //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31)
-
-                    // Scale values - Load the weight scale values of block_q4_0x8
-                    const __m256 col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);
-
-                    // Load the four block_q8_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
-                    // Loaded as a set of 128-bit vectors and repeated into a 256-bit vector
-                    __m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs)));
-                    __m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0);
-                    __m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17);
-                    __m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 32)));
-                    __m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0);
-                    __m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17);
-                    __m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 64)));
-                    __m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0);
-                    __m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17);
-                    __m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 96)));
-                    __m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0);
-                    __m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17);
-
-                    // Shuffle pattern one - left side input
-
-                    const __m256i lhs_mat_01_0_sp1 = _mm256_shuffle_epi32(lhs_mat_01_0, 160);  //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
-                    const __m256i lhs_mat_23_0_sp1 = _mm256_shuffle_epi32(lhs_mat_23_0, 160);  //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
-
-                    const __m256i lhs_mat_01_1_sp1 = _mm256_shuffle_epi32(lhs_mat_01_1, 160);  //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
-                    const __m256i lhs_mat_23_1_sp1 = _mm256_shuffle_epi32(lhs_mat_23_1, 160);  //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
-
-                    const __m256i lhs_mat_01_2_sp1 = _mm256_shuffle_epi32(lhs_mat_01_2, 160);  //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
-                    const __m256i lhs_mat_23_2_sp1 = _mm256_shuffle_epi32(lhs_mat_23_2, 160);  //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
-
-                    const __m256i lhs_mat_01_3_sp1 = _mm256_shuffle_epi32(lhs_mat_01_3, 160);  //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
-                    const __m256i lhs_mat_23_3_sp1 = _mm256_shuffle_epi32(lhs_mat_23_3, 160);  //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
-
-                    // Shuffle pattern two - left side input
-
-                    const __m256i lhs_mat_01_0_sp2 = _mm256_shuffle_epi32(lhs_mat_01_0, 245);  //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
-                    const __m256i lhs_mat_23_0_sp2 = _mm256_shuffle_epi32(lhs_mat_23_0, 245);  //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
-
-                    const __m256i lhs_mat_01_1_sp2 = _mm256_shuffle_epi32(lhs_mat_01_1, 245);  //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
-                    const __m256i lhs_mat_23_1_sp2 = _mm256_shuffle_epi32(lhs_mat_23_1, 245);  //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
-
-                    const __m256i lhs_mat_01_2_sp2 = _mm256_shuffle_epi32(lhs_mat_01_2, 245);  //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
-                    const __m256i lhs_mat_23_2_sp2 = _mm256_shuffle_epi32(lhs_mat_23_2, 245);  //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
-
-                    const __m256i lhs_mat_01_3_sp2 = _mm256_shuffle_epi32(lhs_mat_01_3, 245);  //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
-                    const __m256i lhs_mat_23_3_sp2 = _mm256_shuffle_epi32(lhs_mat_23_3, 245);  //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
-
-                    // The values arranged in the shuffle patterns are combined with a dot product within each 32-bit lane, i.e. corresponding bytes are multiplied and the products accumulated into a 32-bit integer per lane
-                    // Resembles the MMLA operations on 2x2 matrices in the ARM version
-                    __m256i iacc_mat_00_sp1 =
-                        _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1));
-                    __m256i iacc_mat_01_sp1 =
-                        _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1));
-                    __m256i iacc_mat_10_sp1 =
-                        _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1));
-                    __m256i iacc_mat_11_sp1 =
-                        _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1));
-                    __m256i iacc_mat_00_sp2 =
-                        _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2));
-                    __m256i iacc_mat_01_sp2 =
-                        _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2));
-                    __m256i iacc_mat_10_sp2 =
-                        _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2));
-                    __m256i iacc_mat_11_sp2 =
-                        _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2));
-
-                    // Outputs of both shuffle patterns are added in order to sum the dot product outputs of all 32 values in the block
-                    __m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
-                    __m256i iacc_mat_01 = _mm256_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2);
-                    __m256i iacc_mat_10 = _mm256_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2);
-                    __m256i iacc_mat_11 = _mm256_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2);
-
-
-                    // Straighten out to make 4 row vectors
-                    __m256i iacc_row_0 = _mm256_blend_epi32(iacc_mat_00, _mm256_shuffle_epi32(iacc_mat_01, 78), 204);
-                    __m256i iacc_row_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01, 204);
-                    __m256i iacc_row_2 = _mm256_blend_epi32(iacc_mat_10, _mm256_shuffle_epi32(iacc_mat_11, 78), 204);
-                    __m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204);
-
-                    // Load the scale values for all 4 Q8_0 blocks and repeat them across lanes
-                    const __m256 row_scale_f32 = GGML_F32Cx8_REPEAT_LOAD(a_ptr[b].d, loadMask);
-
-                    // Multiply with appropriate scales and accumulate
-                    acc_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]);
-                    acc_rows[1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]);
-                    acc_rows[2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]);
-                    acc_rows[3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]);
-                }
-
-                // Store the accumulated values
-                for (int i = 0; i < 4; i++) {
-                    _mm256_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);
-                }
-            }
-        }
-        return;
-    }
-#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
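A note on the "4-bit -> 8-bit - Sign is maintained" step used throughout the SIMD paths above: signextendlut together with the byte shuffles acts as a 16-entry lookup table mapping each masked nibble to its two's-complement int8 value. A minimal scalar sketch of that mapping (assuming the repacked nibbles are stored in two's-complement form, the same interpretation the scalar reference path below relies on):

#include <stdint.h>

/* Sign-extend a 4-bit two's-complement value n in [0,15] to an int8 in [-8,7];
   this is what the shuffle-based lookup table computes for every byte. */
static int8_t sign_extend_nibble(uint8_t n) {
    return (int8_t) (n << 4) >> 4;  /* shift into the high half, arithmetic shift back */
}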
-    float sumf[4][8];
-    int sumi;
-
-    for (int y = 0; y < nr / 4; y++) {
-        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
-        for (int x = 0; x < nc / ncols_interleaved; x++) {
-            const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
-            for (int m = 0; m < 4; m++) {
-                for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
-            }
-            for (int l = 0; l < nb; l++) {
-                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
-                    for (int m = 0; m < 4; m++) {
-                        for (int j = 0; j < ncols_interleaved; j++) {
-                            sumi = 0;
-                            for (int i = 0; i < blocklen; ++i) {
-                                const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
-                                const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
-                                sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
-                                         (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
-                            }
-                            sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]);
-                        }
-                    }
-                }
-            }
-            for (int m = 0; m < 4; m++) {
-                for (int j = 0; j < ncols_interleaved; j++)
-                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
-            }
-        }
-    }
-}
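
Editor's note: the removed reference kernel above (relocated to llama/ggml-cpu-aarch64.c in this commit) unpacks two 4-bit weights per byte with a sign-extension trick: shifting left then casting to int8_t yields the low nibble times 16, masking with 0xF0 yields the high nibble times 16, and the final >> 4 removes that common factor after the int8 multiplies. A minimal standalone sketch of just that bit manipulation, with hypothetical values rather than real Q4_0 block data:

#include <cstdint>
#include <cstdio>

int main() {
    uint8_t packed = 0x2D;           // low nibble 0xD, high nibble 0x2 (hypothetical)
    int8_t act_lo = 7, act_hi = -3;  // hypothetical q8 activations

    const int v0 = (int8_t) (packed << 4);   // 0xD0 -> -48 (low nibble * 16, sign-extended)
    const int v1 = (int8_t) (packed & 0xF0); // 0x20 ->  32 (high nibble * 16)

    const int sumi = ((v0 * act_lo) + (v1 * act_hi)) >> 4; // undo the *16 scaling
    printf("sumi = %d\n", sumi);             // (-336 - 96) >> 4 = -27
    return 0;
}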

+ 1 - 21
llama/ggml-aarch64.h

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *
@@ -24,12 +24,8 @@
  * SOFTWARE.
  */
 
-// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd.
 #pragma once
 
-#define GGML_COMMON_DECL_C
-#include "ggml-common.h"
-
 #include "ggml.h"
 
 // GGML internal header
@@ -38,27 +34,11 @@
 extern "C" {
 #endif
 
-// Quantization
-void quantize_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
-void quantize_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
-
-void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nrows, int64_t n_per_row, int64_t blck_size_interleave);
-
 // Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
 size_t quantize_q4_0_4x4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
 size_t quantize_q4_0_4x8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
 size_t quantize_q4_0_8x8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
 
-// GEMV
-void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-
-// GEMM
-void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-
 #ifdef __cplusplus
 }
 #endif
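
Editor's note: the header now exposes only the repacked quantization entry points; the GEMV/GEMM kernels and q8 helpers moved behind the CPU backend. A hedged sketch of calling one of the remaining functions (the weight pointer and geometry are hypothetical; `out` must hold nrows * ggml_row_size(GGML_TYPE_Q4_0, n_per_row) bytes, and callers without an importance matrix pass NULL):

#include "ggml.h"
#include "ggml-aarch64.h"

// Quantize a row-major f32 matrix into the 8x8-interleaved Q4_0 layout.
// Returns the number of bytes written into `out`.
size_t quantize_repacked(const float * weights, void * out,
                         int64_t nrows, int64_t n_per_row) {
    return quantize_q4_0_8x8(weights, out, nrows, n_per_row, /*imatrix=*/NULL);
}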

+ 24 - 28
llama/ggml-alloc.c

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *
@@ -40,7 +40,7 @@
 
 //#define GGML_ALLOCATOR_DEBUG
 
-//#define AT_PRINTF(...) fprintf(stderr, __VA_ARGS__)
+//#define AT_PRINTF(...) GGML_LOG_DEBUG(__VA_ARGS__)
 #define AT_PRINTF(...)
 
 
@@ -115,7 +115,7 @@ void ggml_tallocr_alloc(struct ggml_tallocr * talloc, struct ggml_tensor * tenso
     size = GGML_PAD(size, talloc->alignment);
 
     if (talloc->offset + size > ggml_backend_buffer_get_size(talloc->buffer)) {
-        fprintf(stderr, "%s: not enough space in the buffer to allocate %s (needed %zu, available %zu)\n",
+        GGML_LOG_ERROR("%s: not enough space in the buffer to allocate %s (needed %zu, available %zu)\n",
                 __func__, tensor->name, size, ggml_backend_buffer_get_size(talloc->buffer) - talloc->offset);
         GGML_ABORT("not enough space in the buffer");
     }
@@ -198,7 +198,7 @@ static size_t ggml_dyn_tallocr_alloc(struct ggml_dyn_tallocr * alloc, size_t siz
             best_fit_block = alloc->n_free_blocks - 1;
         } else {
             // this should never happen
-            fprintf(stderr, "%s: not enough space in the buffer to allocate %zu bytes, largest block available %zu bytes\n",
+            GGML_LOG_ERROR("%s: not enough space in the buffer to allocate %zu bytes, largest block available %zu bytes\n",
                     __func__, size, max_avail);
             GGML_ABORT("not enough space in the buffer");
         }
@@ -235,16 +235,16 @@ static size_t ggml_dyn_tallocr_alloc(struct ggml_dyn_tallocr * alloc, size_t siz
                 }
             }
         }
-        fprintf(stderr, "max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0);
+        GGML_LOG_DEBUG("max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0);
         for (int i = 0; i < 1024; i++) {
             if (alloc->allocated_tensors[i].tensor) {
-                fprintf(stderr, "%s [%zx-%zx] (%.2f MB) ", alloc->allocated_tensors[i].tensor->name,
+                GGML_LOG_DEBUG("%s [%zx-%zx] (%.2f MB) ", alloc->allocated_tensors[i].tensor->name,
                     alloc->allocated_tensors[i].offset,
                     alloc->allocated_tensors[i].offset + ggml_nbytes(alloc->allocated_tensors[i].tensor),
                     ggml_nbytes(alloc->allocated_tensors[i].tensor) / 1024.0 / 1024.0);
             }
         }
-        fprintf(stderr, "\n");
+        GGML_LOG_DEBUG("\n");
     }
 #endif
 
@@ -374,7 +374,6 @@ struct tensor_alloc {
 };
 
 struct leaf_alloc {
-    int buffer_id;
     struct tensor_alloc leaf;
 };
 
@@ -493,18 +492,12 @@ static bool ggml_gallocr_is_own(ggml_gallocr_t galloc, struct ggml_tensor * t) {
     return ggml_gallocr_hash_get(galloc, t)->allocated;
 }
 
-static void ggml_gallocr_set_node_offset(ggml_gallocr_t galloc, struct ggml_tensor * node, int buffer_id, size_t offset) {
-    struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);
-    hn->buffer_id = buffer_id;
-    hn->offset = offset;
-    hn->allocated = true;
-}
-
 static bool ggml_gallocr_is_allocated(ggml_gallocr_t galloc, struct ggml_tensor * t) {
     return t->data != NULL || ggml_gallocr_hash_get(galloc, t)->allocated;
 }
 
 static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor * node, int buffer_id) {
+    GGML_ASSERT(buffer_id >= 0);
     struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);
 
     if (!ggml_gallocr_is_allocated(galloc, node) && !ggml_is_view(node)) {
@@ -766,7 +759,6 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c
     for (int i = 0; i < graph->n_leafs; i++) {
         struct ggml_tensor * leaf = graph->leafs[i];
         struct hash_node * hn = ggml_gallocr_hash_get(galloc, leaf);
-        galloc->leaf_allocs[i].buffer_id = hn->buffer_id;
         if (leaf->view_src || leaf->data) {
             galloc->leaf_allocs[i].leaf.buffer_id = -1;
             galloc->leaf_allocs[i].leaf.offset = SIZE_MAX;
@@ -794,13 +786,13 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c
         // even if there are no tensors allocated in this buffer, we still need to allocate it to initialize views
         if (new_size > cur_size || galloc->buffers[i] == NULL) {
 #ifndef NDEBUG
-            fprintf(stderr, "%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
+            GGML_LOG_DEBUG("%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
 #endif
 
             ggml_backend_buffer_free(galloc->buffers[i]);
             galloc->buffers[i] = ggml_backend_buft_alloc_buffer(galloc->bufts[i], new_size);
             if (galloc->buffers[i] == NULL) {
-                fprintf(stderr, "%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), new_size);
+                GGML_LOG_ERROR("%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), new_size);
                 return false;
             }
             ggml_backend_buffer_set_usage(galloc->buffers[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE);
@@ -844,21 +836,25 @@ static void ggml_gallocr_init_tensor(ggml_gallocr_t galloc, struct ggml_tensor *
 }
 
 static bool ggml_gallocr_node_needs_realloc(ggml_gallocr_t galloc, struct ggml_tensor * node, struct tensor_alloc * talloc) {
-    size_t node_size = (node->data || node->view_src) ? 0 : ggml_backend_buft_get_alloc_size(galloc->bufts[talloc->buffer_id], node);
+    size_t node_size = 0;
+    if (!node->data && !node->view_src) {
+        GGML_ASSERT(talloc->buffer_id >= 0); // prevent segfault when misusing the API
+        node_size = ggml_backend_buft_get_alloc_size(galloc->bufts[talloc->buffer_id], node);
+    }
     return talloc->size_max >= node_size;
 }
 
 static bool ggml_gallocr_needs_realloc(ggml_gallocr_t galloc, struct ggml_cgraph * graph) {
     if (galloc->n_nodes != graph->n_nodes) {
 #ifndef NDEBUG
-        fprintf(stderr, "%s: graph has different number of nodes\n", __func__);
+        GGML_LOG_DEBUG("%s: graph has different number of nodes\n", __func__);
 #endif
         return true;
     }
 
     if (galloc->n_leafs != graph->n_leafs) {
 #ifndef NDEBUG
-        fprintf(stderr, "%s: graph has different number of leafs\n", __func__);
+        GGML_LOG_DEBUG("%s: graph has different number of leafs\n", __func__);
 #endif
         return true;
     }
@@ -869,7 +865,7 @@ static bool ggml_gallocr_needs_realloc(ggml_gallocr_t galloc, struct ggml_cgraph
 
         if (!ggml_gallocr_node_needs_realloc(galloc, node, &node_alloc->dst)) {
 #ifndef NDEBUG
-            fprintf(stderr, "%s: node %s is not valid\n", __func__, node->name);
+            GGML_LOG_DEBUG("%s: node %s is not valid\n", __func__, node->name);
 #endif
             return true;
         }
@@ -881,7 +877,7 @@ static bool ggml_gallocr_needs_realloc(ggml_gallocr_t galloc, struct ggml_cgraph
             }
             if (!ggml_gallocr_node_needs_realloc(galloc, src, &node_alloc->src[j])) {
 #ifndef NDEBUG
-                fprintf(stderr, "%s: src %d (%s) of node %s is not valid\n", __func__, j, src->name, node->name);
+                GGML_LOG_DEBUG("%s: src %d (%s) of node %s is not valid\n", __func__, j, src->name, node->name);
 #endif
                 return true;
             }
@@ -895,14 +891,14 @@ bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph * graph)
     if (ggml_gallocr_needs_realloc(galloc, graph)) {
         if (galloc->n_buffers == 1) {
 #ifndef NDEBUG
-            fprintf(stderr, "%s: reallocating buffers automatically\n", __func__);
+            GGML_LOG_DEBUG("%s: reallocating buffers automatically\n", __func__);
 #endif
             if (!ggml_gallocr_reserve(galloc, graph)) {
                 return false;
             }
         } else {
 #ifndef NDEBUG
-            fprintf(stderr, "%s: cannot reallocate multi buffer graph automatically, call reserve\n", __func__);
+            GGML_LOG_DEBUG("%s: cannot reallocate multi buffer graph automatically, call reserve\n", __func__);
 #endif
             return false;
         }
@@ -966,7 +962,7 @@ static bool alloc_tensor_range(struct ggml_context * ctx,
     ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, size);
     if (buffer == NULL) {
 #ifndef NDEBUG
-        fprintf(stderr, "%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(buft), size);
+        GGML_LOG_DEBUG("%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(buft), size);
 #endif
         for (size_t i = 0; i < *n_buffers; i++) {
             ggml_backend_buffer_free((*buffers)[i]);
@@ -1016,7 +1012,7 @@ ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_conte
         }
 
         if (this_size > max_size) {
-            fprintf(stderr, "%s: tensor %s is too large to fit in a %s buffer (tensor size: %zu, max buffer size: %zu)\n",
+            GGML_LOG_ERROR("%s: tensor %s is too large to fit in a %s buffer (tensor size: %zu, max buffer size: %zu)\n",
                     __func__, t->name,
                     ggml_backend_buft_name(buft),
                     this_size, max_size);
@@ -1048,7 +1044,7 @@ ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_conte
 
     if (n_buffers == 0) {
 #ifndef NDEBUG
-        fprintf(stderr, "%s: all tensors in the context are already allocated\n", __func__);
+        GGML_LOG_DEBUG("%s: all tensors in the context are already allocated\n", __func__);
 #endif
         return NULL;
     }
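
Editor's note: beyond the logging migration, the hunks above tighten ggml_backend_alloc_ctx_tensors_from_buft, which backs the common pattern of building a metadata-only context and placing all of its tensors in one backend buffer. A minimal sketch under hypothetical sizes:

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"

// Allocate every tensor of a no_alloc context into buffers of type `buft`.
// Returns NULL on allocation failure (now reported via GGML_LOG_*).
ggml_backend_buffer_t alloc_weights(ggml_backend_buffer_type_t buft) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead() * 8, // metadata only
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,                       // data lives in the backend buffer
    };
    struct ggml_context * ctx = ggml_init(params);
    ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4096, 4096); // hypothetical weight

    return ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); // caller keeps ctx alive
}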

+ 2 - 2
llama/ggml-alloc.h

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *
@@ -50,7 +50,7 @@ GGML_API void                ggml_tallocr_alloc(struct ggml_tallocr * talloc, st
 // Graph allocator
 /*
   Example usage:
-    ggml_gallocr_t galloc = ggml_gallocr_new(ggml_bacckend_cpu_buffer_type());
+    ggml_gallocr_t galloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type());
 
     // optional: create a worst-case graph and reserve the buffers to avoid reallocations
     ggml_gallocr_reserve(galloc, build_graph(max_batch));
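
Editor's note: for context, the corrected example continues roughly as sketched below; galloc, backend, and the per-iteration graph are placeholders the surrounding application provides:

#include "ggml-alloc.h"
#include "ggml-backend.h"
#include <cstdio>

static void run_iteration(ggml_gallocr_t galloc, ggml_backend_t backend,
                          struct ggml_cgraph * graph) {
    ggml_gallocr_alloc_graph(galloc, graph); // reallocates only if the graph outgrew the reserve
    printf("compute buffer size: %zu bytes\n", ggml_gallocr_get_buffer_size(galloc, 0));
    ggml_backend_graph_compute(backend, graph);
}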

+ 186 - 84
llama/ggml-backend-impl.h

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *
@@ -34,146 +34,248 @@
 extern "C" {
 #endif
 
+    #define GGML_BACKEND_API_VERSION 1
+
     //
-    // Backend buffer
+    // Backend buffer type
     //
 
-    // buffer type
-    typedef void * ggml_backend_buffer_type_context_t;
-
     struct ggml_backend_buffer_type_i {
-        const char *          (*GGML_CALL get_name)        (ggml_backend_buffer_type_t buft);
+        const char *          (*get_name)      (ggml_backend_buffer_type_t buft);
         // allocate a buffer of this type
-        ggml_backend_buffer_t (*GGML_CALL alloc_buffer)    (ggml_backend_buffer_type_t buft, size_t size);
+        ggml_backend_buffer_t (*alloc_buffer)  (ggml_backend_buffer_type_t buft, size_t size);
         // tensor alignment
-        size_t                (*GGML_CALL get_alignment)   (ggml_backend_buffer_type_t buft);
-        // max buffer size that can be allocated
-        size_t                (*GGML_CALL get_max_size)    (ggml_backend_buffer_type_t buft);
-        // data size needed to allocate the tensor, including padding
-        size_t                (*GGML_CALL get_alloc_size)  (ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor);
-        // check if tensor data is in host memory
-        bool                  (*GGML_CALL is_host)         (ggml_backend_buffer_type_t buft);
+        size_t                (*get_alignment) (ggml_backend_buffer_type_t buft);
+        // (optional) max buffer size that can be allocated (defaults to SIZE_MAX)
+        size_t                (*get_max_size)  (ggml_backend_buffer_type_t buft);
+        // (optional) data size needed to allocate the tensor, including padding (defaults to ggml_nbytes)
+        size_t                (*get_alloc_size)(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor);
+        // (optional) check if tensor data is in host memory and uses standard ggml tensor layout (defaults to false)
+        bool                  (*is_host)       (ggml_backend_buffer_type_t buft);
     };
 
     struct ggml_backend_buffer_type {
         struct ggml_backend_buffer_type_i  iface;
-        ggml_backend_buffer_type_context_t context;
+        ggml_backend_dev_t device;
+        void * context;
     };
 
-    // buffer
-    typedef void * ggml_backend_buffer_context_t;
+    //
+    // Backend buffer
+    //
 
     struct ggml_backend_buffer_i {
-        const char * (*GGML_CALL get_name)      (ggml_backend_buffer_t buffer);
-        void         (*GGML_CALL free_buffer)   (ggml_backend_buffer_t buffer);
-        void *       (*GGML_CALL get_base)      (ggml_backend_buffer_t buffer);
-        void         (*GGML_CALL init_tensor)   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
-        void         (*GGML_CALL memset_tensor) (ggml_backend_buffer_t buffer,       struct ggml_tensor * tensor,     uint8_t value, size_t offset, size_t size);
-        void         (*GGML_CALL set_tensor)    (ggml_backend_buffer_t buffer,       struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
-        void         (*GGML_CALL get_tensor)    (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
-        bool         (*GGML_CALL cpy_tensor)    (ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst); // dst is in the buffer, src may be in any buffer
-        void         (*GGML_CALL clear)         (ggml_backend_buffer_t buffer, uint8_t value);
-        void         (*GGML_CALL reset)         (ggml_backend_buffer_t buffer); // reset any internal state due to tensor initialization, such as tensor extras
+        // (optional) free the buffer
+        void         (*free_buffer)  (ggml_backend_buffer_t buffer);
+        // base address of the buffer
+        void *       (*get_base)     (ggml_backend_buffer_t buffer);
+        // (optional) initialize a tensor in the buffer (eg. add tensor extras)
+        void         (*init_tensor)  (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+        // tensor data access
+        void         (*memset_tensor)(ggml_backend_buffer_t buffer,       struct ggml_tensor * tensor,     uint8_t value, size_t offset, size_t size);
+        void         (*set_tensor)   (ggml_backend_buffer_t buffer,       struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+        void         (*get_tensor)   (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
+        // (optional) tensor copy: dst is in the buffer, src may be in any buffer, including buffers from a different backend (return false if not supported)
+        bool         (*cpy_tensor)   (ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst);
+        // clear the entire buffer
+        void         (*clear)        (ggml_backend_buffer_t buffer, uint8_t value);
+        // (optional) reset any internal state due to tensor initialization, such as tensor extras
+        void         (*reset)        (ggml_backend_buffer_t buffer);
     };
 
     struct ggml_backend_buffer {
         struct ggml_backend_buffer_i  iface;
         ggml_backend_buffer_type_t    buft;
-        ggml_backend_buffer_context_t context;
+        void * context;
         size_t size;
         enum ggml_backend_buffer_usage usage;
     };
 
-    GGML_CALL ggml_backend_buffer_t ggml_backend_buffer_init(
-                   ggml_backend_buffer_type_t      buft,
-            struct ggml_backend_buffer_i           iface,
-                   ggml_backend_buffer_context_t   context,
-                   size_t                          size);
+    GGML_API ggml_backend_buffer_t ggml_backend_buffer_init(
+                   ggml_backend_buffer_type_t buft,
+            struct ggml_backend_buffer_i      iface,
+                   void *                     context,
+                   size_t                     size);
 
     // do not use directly, use ggml_backend_tensor_copy instead
-    bool ggml_backend_buffer_copy_tensor(const struct ggml_tensor * src, struct ggml_tensor * dst);
+    GGML_API bool ggml_backend_buffer_copy_tensor(const struct ggml_tensor * src, struct ggml_tensor * dst);
 
+    // multi-buffer
     // buffer that contains a collection of buffers
-    GGML_CALL ggml_backend_buffer_t ggml_backend_multi_buffer_alloc_buffer(ggml_backend_buffer_t * buffers, size_t n_buffers);
-    GGML_CALL bool                  ggml_backend_buffer_is_multi_buffer(ggml_backend_buffer_t buffer);
-    GGML_CALL void                  ggml_backend_multi_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage);
+    GGML_API ggml_backend_buffer_t ggml_backend_multi_buffer_alloc_buffer(ggml_backend_buffer_t * buffers, size_t n_buffers);
+    GGML_API bool                  ggml_backend_buffer_is_multi_buffer(ggml_backend_buffer_t buffer);
+    GGML_API void                  ggml_backend_multi_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage);
 
     //
-    // Backend
+    // Backend (stream)
     //
 
-    typedef void * ggml_backend_context_t;
-
     struct ggml_backend_i {
-        const char * (*GGML_CALL get_name)(ggml_backend_t backend);
+        const char * (*get_name)(ggml_backend_t backend);
 
-        void (*GGML_CALL free)(ggml_backend_t backend);
-
-        // buffer allocation
-        ggml_backend_buffer_type_t (*GGML_CALL get_default_buffer_type)(ggml_backend_t backend);
+        void (*free)(ggml_backend_t backend);
 
         // (optional) asynchronous tensor data access
-        void (*GGML_CALL set_tensor_async)(ggml_backend_t backend,       struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
-        void (*GGML_CALL get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
-        bool (*GGML_CALL cpy_tensor_async)(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst);
+        void (*set_tensor_async)(ggml_backend_t backend,       struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+        void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
+        bool (*cpy_tensor_async)(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst);
 
-        // (optional) complete all pending operations
-        void (*GGML_CALL synchronize)(ggml_backend_t backend);
+        // (optional) complete all pending operations (required if the backend supports async operations)
+        void (*synchronize)(ggml_backend_t backend);
 
-        // compute graph with a plan (not used currently)
-        // create a new plan for a graph
-        ggml_backend_graph_plan_t (*GGML_CALL graph_plan_create) (ggml_backend_t backend, const struct ggml_cgraph * cgraph);
-        void                      (*GGML_CALL graph_plan_free)   (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
+        // (optional) graph plans (not used currently)
+        // compute graph with a plan
+        ggml_backend_graph_plan_t (*graph_plan_create) (ggml_backend_t backend, const struct ggml_cgraph * cgraph);
+        void                      (*graph_plan_free)   (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
         // update the plan with a new graph - this should be faster than creating a new plan when the graph has the same topology
-        void                      (*GGML_CALL graph_plan_update) (ggml_backend_t backend, ggml_backend_graph_plan_t plan, const struct ggml_cgraph * cgraph);
+        void                      (*graph_plan_update) (ggml_backend_t backend, ggml_backend_graph_plan_t plan, const struct ggml_cgraph * cgraph);
         // compute the graph with the plan
-        enum ggml_status          (*GGML_CALL graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
-
-        // compute graph without a plan (async)
-        enum ggml_status (*GGML_CALL graph_compute)     (ggml_backend_t backend, struct ggml_cgraph * cgraph);
-
-        // check if the backend can compute an operation
-        bool (*GGML_CALL supports_op)(ggml_backend_t backend, const struct ggml_tensor * op);
-
-        // check if the backend can use tensors allocated in a buffer type
-        bool (*GGML_CALL supports_buft)(ggml_backend_t backend, ggml_backend_buffer_type_t buft);
+        enum ggml_status          (*graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
 
-        // check if the backend wants to run an operation, even if the weights are allocated in a CPU buffer
-        // these should be expensive operations with large batch sizes that may benefit from running on this backend
-        // even if the weight has to be copied from the CPU temporarily
-        bool (*GGML_CALL offload_op)(ggml_backend_t backend, const struct ggml_tensor * op);
+        // compute graph (always async if supported by the backend)
+        enum ggml_status          (*graph_compute)     (ggml_backend_t backend, struct ggml_cgraph * cgraph);
 
         // (optional) event synchronization
-        // create a new event that can record events on this backend instance
-        ggml_backend_event_t (*GGML_CALL event_new)         (ggml_backend_t backend);
-        void                 (*GGML_CALL event_free)        (ggml_backend_event_t event);
-        // record an event on the backend instance that created it
-        void                 (*GGML_CALL event_record)      (ggml_backend_event_t event);
-        // wait for an event on a different backend instance
-        void                 (*GGML_CALL event_wait)        (ggml_backend_t backend, ggml_backend_event_t event);
-        // block until an event is recorded
-        void                 (*GGML_CALL event_synchronize) (ggml_backend_event_t event);
+        // record an event on this stream
+        void (*event_record)(ggml_backend_t backend, ggml_backend_event_t event);
+        // wait for an event on a different stream
+        void (*event_wait)  (ggml_backend_t backend, ggml_backend_event_t event);
     };
 
     struct ggml_backend {
         ggml_guid_t guid;
-
         struct ggml_backend_i iface;
-        ggml_backend_context_t context;
+        ggml_backend_dev_t device;
+        void * context;
     };
 
     struct ggml_backend_event {
-        ggml_backend_t backend;
+        struct ggml_backend_device * device;
         void * context;
     };
 
     //
-    // Backend registry
+    // Backend device
     //
 
-    typedef ggml_backend_t (*GGML_CALL ggml_backend_init_fn)(const char * params, void * user_data);
+    // Note: if additional properties are needed, we should add a struct with all of them
+    //       the current functions to obtain the properties can remain, since they are more convenient for often used properties
+    struct ggml_backend_device_i {
+        // device name: short identifier for this device, such as "CPU" or "CUDA0"
+        const char * (*get_name)(ggml_backend_dev_t dev);
+
+        // device description: short informative description of the device, could be the model name
+        const char * (*get_description)(ggml_backend_dev_t dev);
+
+        // device memory in bytes
+        void         (*get_memory)(ggml_backend_dev_t dev, size_t * free, size_t * total);
+
+        // device type
+        enum ggml_backend_dev_type (*get_type)(ggml_backend_dev_t dev);
+
+        // device properties
+        void (*get_props)(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props);
+
+        // backend (stream) initialization
+        ggml_backend_t (*init_backend)(ggml_backend_dev_t dev, const char * params);
+
+        // preferred buffer type
+        ggml_backend_buffer_type_t (*get_buffer_type)(ggml_backend_dev_t dev);
+
+        // (optional) host buffer type (in system memory, typically this is a pinned memory buffer for faster transfers between host and device)
+        ggml_backend_buffer_type_t (*get_host_buffer_type)(ggml_backend_dev_t dev);
 
-    GGML_CALL void ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data);
+        // (optional) buffer from pointer: create a buffer from a host pointer (useful for memory mapped models and importing data from other libraries)
+        ggml_backend_buffer_t (*buffer_from_host_ptr)(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size);
+
+        // check if the backend can compute an operation
+        bool (*supports_op)(ggml_backend_dev_t dev, const struct ggml_tensor * op);
+
+        // check if the backend can use tensors allocated in a buffer type
+        bool (*supports_buft)(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft);
+
+        // (optional) check if the backend wants to run an operation, even if the weights are allocated in an incompatible buffer
+        // these should be expensive operations that may benefit from running on this backend instead of the CPU backend
+        bool (*offload_op)(ggml_backend_dev_t dev, const struct ggml_tensor * op);
+
+        // (optional) event synchronization
+        ggml_backend_event_t (*event_new)         (ggml_backend_dev_t dev);
+        void                 (*event_free)        (ggml_backend_dev_t dev, ggml_backend_event_t event);
+        void                 (*event_synchronize) (ggml_backend_dev_t dev, ggml_backend_event_t event);
+    };
+
+    struct ggml_backend_device {
+        struct ggml_backend_device_i iface;
+        ggml_backend_reg_t reg;
+        void * context;
+    };
+
+    //
+    // Backend (reg)
+    //
+
+    struct ggml_backend_reg_i {
+        const char * (*get_name)(ggml_backend_reg_t reg);
+
+        // enumerate available devices
+        size_t             (*get_device_count)(ggml_backend_reg_t reg);
+        ggml_backend_dev_t (*get_device)(ggml_backend_reg_t reg, size_t index);
+
+        // (optional) get a pointer to a function in the backend
+        // backends can add custom functions that are not part of the standard ggml-backend interface
+        void * (*get_proc_address)(ggml_backend_reg_t reg, const char * name);
+    };
+
+    struct ggml_backend_reg {
+        int api_version; // initialize to GGML_BACKEND_API_VERSION
+        struct ggml_backend_reg_i iface;
+        void * context;
+    };
+
+    // Internal backend registry API
+    GGML_API void ggml_backend_register(ggml_backend_reg_t reg);
+    GGML_API void ggml_backend_device_register(ggml_backend_dev_t device);
+
+    // Add backend dynamic loading support to the backend
+
+    // Initialize the backend
+    typedef ggml_backend_reg_t (*ggml_backend_init_t)(void);
+    // Optional: obtain a score for the backend based on the system configuration
+    // Higher scores are preferred, 0 means the backend is not supported in the current system
+    typedef int                (*ggml_backend_score_t)(void);
+
+#ifdef GGML_BACKEND_DL
+#    ifdef __cplusplus
+#        define GGML_BACKEND_DL_IMPL(reg_fn)                             \
+            extern "C" {                                                 \
+            GGML_BACKEND_API ggml_backend_reg_t ggml_backend_init(void); \
+            }                                                            \
+            ggml_backend_reg_t ggml_backend_init(void) {                 \
+                return reg_fn();                                         \
+            }
+#        define GGML_BACKEND_DL_SCORE_IMPL(score_fn)       \
+            extern "C" {                                   \
+            GGML_BACKEND_API int ggml_backend_score(void); \
+            }                                              \
+            int ggml_backend_score(void) {                 \
+                return score_fn();                         \
+            }
+#    else
+#        define GGML_BACKEND_DL_IMPL(reg_fn)                              \
+            GGML_BACKEND_API ggml_backend_reg_t ggml_backend_init(void);  \
+            ggml_backend_reg_t                  ggml_backend_init(void) { \
+                return reg_fn();                                          \
+            }
+#        define GGML_BACKEND_DL_SCORE_IMPL(score_fn)        \
+            GGML_BACKEND_API int ggml_backend_score(void);  \
+            int                  ggml_backend_score(void) { \
+                return score_fn();                          \
+            }
+#    endif
+#else
+#    define GGML_BACKEND_DL_IMPL(reg_fn)
+#    define GGML_BACKEND_DL_SCORE_IMPL(score_fn)
+#endif
 
 #ifdef  __cplusplus
 }
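
Editor's note: a hedged sketch of how a dynamically loadable backend would use the macros above; my_backend_reg, my_reg_name, and my_backend_score are hypothetical names, and a real backend would also populate devices:

// Compiled with GGML_BACKEND_DL into [lib]ggml-mybackend.so/.dll.
#include "ggml-backend-impl.h"

static const char * my_reg_name(ggml_backend_reg_t) { return "MYBACKEND"; }
static size_t my_dev_count(ggml_backend_reg_t) { return 0; } // sketch: no devices yet
static ggml_backend_dev_t my_dev_get(ggml_backend_reg_t, size_t) { return nullptr; }

static ggml_backend_reg_t my_backend_reg(void) {
    static struct ggml_backend_reg reg = {
        /* .api_version = */ GGML_BACKEND_API_VERSION,
        /* .iface       = */ { my_reg_name, my_dev_count, my_dev_get, /*get_proc_address=*/nullptr },
        /* .context     = */ nullptr,
    };
    return &reg;
}

static int my_backend_score(void) {
    return 1; // > 0: usable on this system; 0 tells the loader to skip this library
}

// Export ggml_backend_init() and ggml_backend_score() for the registry loader.
GGML_BACKEND_DL_IMPL(my_backend_reg)
GGML_BACKEND_DL_SCORE_IMPL(my_backend_score)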

+ 545 - 0
llama/ggml-backend-reg.cpp

@@ -0,0 +1,545 @@
+/**
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "ggml-backend-impl.h"
+#include "ggml-backend.h"
+#include "ggml-impl.h"
+#include <algorithm>
+#include <codecvt>
+#include <cstring>
+#include <filesystem>
+#include <locale>
+#include <memory>
+#include <string>
+#include <type_traits>
+#include <vector>
+
+#ifdef _WIN32
+#    define WIN32_LEAN_AND_MEAN
+#    ifndef NOMINMAX
+#        define NOMINMAX
+#    endif
+#    include <windows.h>
+#elif defined(__APPLE__)
+#    include <mach-o/dyld.h>
+#    include <dlfcn.h>
+#else
+#    include <dlfcn.h>
+#    include <unistd.h>
+#endif
+
+// Backend registry
+#ifdef GGML_USE_CPU
+#include "ggml-cpu.h"
+#endif
+
+#ifdef GGML_USE_CUDA
+#include "ggml-cuda.h"
+#endif
+
+#ifdef GGML_USE_METAL
+#include "ggml-metal.h"
+#endif
+
+#ifdef GGML_USE_SYCL
+#include "ggml-sycl.h"
+#endif
+
+#ifdef GGML_USE_VULKAN
+#include "ggml-vulkan.h"
+#endif
+
+#ifdef GGML_USE_BLAS
+#include "ggml-blas.h"
+#endif
+
+#ifdef GGML_USE_RPC
+#include "ggml-rpc.h"
+#endif
+
+#ifdef GGML_USE_CANN
+#include "ggml-cann.h"
+#endif
+
+#ifdef GGML_USE_KOMPUTE
+#include "ggml-kompute.h"
+#endif
+
+#ifdef _WIN32
+
+using dl_handle = std::remove_pointer_t<HMODULE>;
+
+struct dl_handle_deleter {
+    void operator()(HMODULE handle) {
+        FreeLibrary(handle);
+    }
+};
+
+static dl_handle * dl_load_library(const std::wstring & path) {
+    // suppress error dialogs for missing DLLs
+    DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
+    SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
+
+    HMODULE handle = LoadLibraryW(path.c_str());
+
+    SetErrorMode(old_mode);
+
+    return handle;
+}
+
+static dl_handle * dl_load_library(const std::string & path) {
+    std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;
+    return dl_load_library(converter.from_bytes(path));
+}
+
+static void * dl_get_sym(dl_handle * handle, const char * name) {
+    DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
+    SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
+
+    void * p = (void *) GetProcAddress(handle, name);
+
+    SetErrorMode(old_mode);
+
+    return p;
+}
+
+#else
+
+using dl_handle = void;
+
+struct dl_handle_deleter {
+    void operator()(void * handle) {
+        dlclose(handle);
+    }
+};
+
+static void * dl_load_library(const std::string & path) {
+    dl_handle * handle = dlopen(path.c_str(), RTLD_NOW | RTLD_LOCAL);
+
+    return handle;
+}
+
+static void * dl_get_sym(dl_handle * handle, const char * name) {
+    return dlsym(handle, name);
+}
+
+#endif
+
+using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>;
+
+struct ggml_backend_reg_entry {
+    ggml_backend_reg_t reg;
+    dl_handle_ptr handle;
+};
+
+struct ggml_backend_registry {
+    std::vector<ggml_backend_reg_entry> backends;
+    std::vector<ggml_backend_dev_t> devices;
+
+    ggml_backend_registry() {
+#ifdef GGML_USE_CUDA
+        register_backend(ggml_backend_cuda_reg());
+#endif
+#ifdef GGML_USE_METAL
+        register_backend(ggml_backend_metal_reg());
+#endif
+#ifdef GGML_USE_SYCL
+        register_backend(ggml_backend_sycl_reg());
+#endif
+#ifdef GGML_USE_VULKAN
+        register_backend(ggml_backend_vk_reg());
+#endif
+#ifdef GGML_USE_CANN
+        register_backend(ggml_backend_cann_reg());
+#endif
+#ifdef GGML_USE_BLAS
+        register_backend(ggml_backend_blas_reg());
+#endif
+#ifdef GGML_USE_RPC
+        register_backend(ggml_backend_rpc_reg());
+#endif
+#ifdef GGML_USE_KOMPUTE
+        register_backend(ggml_backend_kompute_reg());
+#endif
+#ifdef GGML_USE_CPU
+        register_backend(ggml_backend_cpu_reg());
+#endif
+    }
+
+    ~ggml_backend_registry() {
+        // FIXME: backends cannot be safely unloaded without a function to destroy all the backend resources,
+        // since backend threads may still be running and accessing resources from the dynamic library
+        for (auto & entry : backends) {
+            if (entry.handle) {
+                entry.handle.release(); // NOLINT
+            }
+        }
+    }
+
+    void register_backend(ggml_backend_reg_t reg, dl_handle_ptr handle = nullptr) {
+        if (!reg) {
+            return;
+        }
+
+#ifndef NDEBUG
+        GGML_LOG_DEBUG("%s: registered backend %s (%zu devices)\n",
+            __func__, ggml_backend_reg_name(reg), ggml_backend_reg_dev_count(reg));
+#endif
+        backends.push_back({ reg, std::move(handle) });
+        for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); i++) {
+            register_device(ggml_backend_reg_dev_get(reg, i));
+        }
+    }
+
+    void register_device(ggml_backend_dev_t device) {
+#ifndef NDEBUG
+        GGML_LOG_DEBUG("%s: registered device %s (%s)\n", __func__, ggml_backend_dev_name(device), ggml_backend_dev_description(device));
+#endif
+        devices.push_back(device);
+    }
+
+    ggml_backend_reg_t load_backend(const char * path, bool silent) {
+        dl_handle_ptr handle { dl_load_library(path) };
+        if (!handle) {
+            if (!silent) {
+                GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path);
+            }
+            return nullptr;
+        }
+
+        auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
+        if (score_fn && score_fn() == 0) {
+            if (!silent) {
+                GGML_LOG_INFO("%s: backend %s is not supported on this system\n", __func__, path);
+            }
+            return nullptr;
+        }
+
+        auto backend_init_fn = (ggml_backend_init_t) dl_get_sym(handle.get(), "ggml_backend_init");
+        if (!backend_init_fn) {
+            if (!silent) {
+                GGML_LOG_ERROR("%s: failed to find ggml_backend_init in %s\n", __func__, path);
+            }
+            return nullptr;
+        }
+
+        ggml_backend_reg_t reg = backend_init_fn();
+        if (!reg || reg->api_version != GGML_BACKEND_API_VERSION) {
+            if (!silent) {
+                if (!reg) {
+                    GGML_LOG_ERROR("%s: failed to initialize backend from %s: ggml_backend_init returned NULL\n", __func__, path);
+                } else {
+                    GGML_LOG_ERROR("%s: failed to initialize backend from %s: incompatible API version (backend: %d, current: %d)\n",
+                        __func__, path, reg->api_version, GGML_BACKEND_API_VERSION);
+                }
+            }
+            return nullptr;
+        }
+
+        GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), path);
+
+        register_backend(reg, std::move(handle));
+
+        return reg;
+    }
+
+    void unload_backend(ggml_backend_reg_t reg, bool silent) {
+        auto it = std::find_if(backends.begin(), backends.end(),
+                               [reg](const ggml_backend_reg_entry & entry) { return entry.reg == reg; });
+
+        if (it == backends.end()) {
+            if (!silent) {
+                GGML_LOG_ERROR("%s: backend not found\n", __func__);
+            }
+            return;
+        }
+
+        if (!silent) {
+            GGML_LOG_DEBUG("%s: unloading %s backend\n", __func__, ggml_backend_reg_name(reg));
+        }
+
+        // remove devices
+        devices.erase(
+            std::remove_if(devices.begin(), devices.end(),
+                            [reg](ggml_backend_dev_t dev) { return ggml_backend_dev_backend_reg(dev) == reg; }),
+            devices.end());
+
+        // remove backend
+        backends.erase(it);
+    }
+};
+
+static ggml_backend_registry & get_reg() {
+    static ggml_backend_registry reg;
+    return reg;
+}
+
+// Internal API
+void ggml_backend_register(ggml_backend_reg_t reg) {
+    get_reg().register_backend(reg);
+}
+
+void ggml_backend_device_register(ggml_backend_dev_t device) {
+    get_reg().register_device(device);
+}
+
+// Backend (reg) enumeration
+static bool striequals(const char * a, const char * b) {
+    for (; *a && *b; a++, b++) {
+        if (std::tolower(*a) != std::tolower(*b)) {
+            return false;
+        }
+    }
+    return *a == *b;
+}
+
+size_t ggml_backend_reg_count() {
+    return get_reg().backends.size();
+}
+
+ggml_backend_reg_t ggml_backend_reg_get(size_t index) {
+    GGML_ASSERT(index < ggml_backend_reg_count());
+    return get_reg().backends[index].reg;
+}
+
+ggml_backend_reg_t ggml_backend_reg_by_name(const char * name) {
+    for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
+        ggml_backend_reg_t reg = ggml_backend_reg_get(i);
+        if (striequals(ggml_backend_reg_name(reg), name)) {
+            return reg;
+        }
+    }
+    return nullptr;
+}
+
+// Device enumeration
+size_t ggml_backend_dev_count() {
+    return get_reg().devices.size();
+}
+
+ggml_backend_dev_t ggml_backend_dev_get(size_t index) {
+    GGML_ASSERT(index < ggml_backend_dev_count());
+    return get_reg().devices[index];
+}
+
+ggml_backend_dev_t ggml_backend_dev_by_name(const char * name) {
+    for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
+        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+        if (striequals(ggml_backend_dev_name(dev), name)) {
+            return dev;
+        }
+    }
+    return nullptr;
+}
+
+ggml_backend_dev_t ggml_backend_dev_by_type(enum ggml_backend_dev_type type) {
+    for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
+        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+        if (ggml_backend_dev_type(dev) == type) {
+            return dev;
+        }
+    }
+    return nullptr;
+}
+
+// Convenience functions
+ggml_backend_t ggml_backend_init_by_name(const char * name, const char * params) {
+    ggml_backend_dev_t dev = ggml_backend_dev_by_name(name);
+    if (!dev) {
+        return nullptr;
+    }
+    return ggml_backend_dev_init(dev, params);
+}
+
+ggml_backend_t ggml_backend_init_by_type(enum ggml_backend_dev_type type, const char * params) {
+    ggml_backend_dev_t dev = ggml_backend_dev_by_type(type);
+    if (!dev) {
+        return nullptr;
+    }
+    return ggml_backend_dev_init(dev, params);
+}
+
+ggml_backend_t ggml_backend_init_best(void) {
+    ggml_backend_dev_t dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU);
+    if (!dev) {
+        dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+    }
+    if (!dev) {
+        return nullptr;
+    }
+    return ggml_backend_dev_init(dev, nullptr);
+}
+
+// Dynamic loading
+ggml_backend_reg_t ggml_backend_load(const char * path) {
+    return get_reg().load_backend(path, false);
+}
+
+void ggml_backend_unload(ggml_backend_reg_t reg) {
+    get_reg().unload_backend(reg, true);
+}
+
+static std::string get_executable_path() {
+#if defined(__APPLE__)
+    // get executable path
+    std::vector<char> path;
+    uint32_t size;
+    while (true) {
+        size = path.size();
+        if (_NSGetExecutablePath(path.data(), &size) == 0) {
+            break;
+        }
+        path.resize(size);
+    }
+    std::string base_path(path.data(), size);
+    // remove executable name
+    auto last_slash = base_path.find_last_of('/');
+    if (last_slash != std::string::npos) {
+        base_path = base_path.substr(0, last_slash);
+    }
+    return base_path + "/";
+#elif defined(__linux__)
+    std::string base_path = ".";
+    std::vector<char> path(1024);
+    while (true) {
+        // get executable path
+        ssize_t len = readlink("/proc/self/exe", path.data(), path.size());
+        if (len == -1) {
+            break;
+        }
+        if (len < (ssize_t) path.size()) {
+            base_path = std::string(path.data(), len);
+            // remove executable name
+            auto last_slash = base_path.find_last_of('/');
+            if (last_slash != std::string::npos) {
+                base_path = base_path.substr(0, last_slash);
+            }
+            break;
+        }
+        path.resize(path.size() * 2);
+    }
+
+    return base_path + "/";
+#elif defined(_WIN32)
+    std::vector<char> path(MAX_PATH);
+    DWORD len = GetModuleFileNameA(NULL, path.data(), path.size());
+    if (len == 0) {
+        return "";
+    }
+    std::string base_path(path.data(), len);
+    // remove executable name
+    auto last_slash = base_path.find_last_of('\\');
+    if (last_slash != std::string::npos) {
+        base_path = base_path.substr(0, last_slash);
+    }
+    return base_path + "\\";
+#endif
+}
+
+static std::string backend_filename_prefix() {
+#ifdef _WIN32
+    return "ggml-";
+#else
+    return "libggml-";
+#endif
+}
+
+static std::string backend_filename_suffix() {
+#ifdef _WIN32
+    return ".dll";
+#else
+    return ".so";
+#endif
+}
+
+static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent) {
+    // enumerate all the files that match [lib]ggml-name-*.[so|dll] in the search paths
+    // TODO: search system paths
+    std::vector<std::string> search_paths = { "./", get_executable_path() };
+    std::string file_prefix = backend_filename_prefix() + name + "-";
+
+    int best_score = 0;
+    std::string best_path;
+
+    namespace fs = std::filesystem;
+    for (const auto & search_path : search_paths) {
+        if (!fs::exists(search_path)) {
+            continue;
+        }
+        for (const auto & entry : fs::directory_iterator(search_path)) {
+            if (entry.is_regular_file()) {
+                std::string filename = entry.path().filename().string();
+                std::string ext = entry.path().extension().string();
+                if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
+                    dl_handle_ptr handle { dl_load_library(entry.path().c_str()) };
+                    if (!handle && !silent) {
+                        GGML_LOG_ERROR("%s: failed to load %s\n", __func__, entry.path().string().c_str());
+                    }
+                    if (handle) {
+                        auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
+                        if (score_fn) {
+                            int s = score_fn();
+#ifndef NDEBUG
+                            GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, entry.path().string().c_str(), s);
+#endif
+                            if (s > best_score) {
+                                best_score = s;
+                                best_path = entry.path().string();
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    if (best_score == 0) {
+        // try to load the base backend
+        for (const auto & search_path : search_paths) {
+            std::string path = search_path + backend_filename_prefix() + name + backend_filename_suffix();
+            if (fs::exists(path)) {
+                return get_reg().load_backend(path.c_str(), silent);
+            }
+        }
+        return nullptr;
+    }
+
+    return get_reg().load_backend(best_path.c_str(), silent);
+}
+
+void ggml_backend_load_all() {
+    ggml_backend_load_best("blas", true);
+    ggml_backend_load_best("cann", true);
+    ggml_backend_load_best("cuda", true);
+    ggml_backend_load_best("hip", true);
+    ggml_backend_load_best("kompute", true);
+    ggml_backend_load_best("metal", true);
+    ggml_backend_load_best("rpc", true);
+    ggml_backend_load_best("sycl", true);
+    ggml_backend_load_best("vulkan", true);
+    ggml_backend_load_best("musa", true);
+    ggml_backend_load_best("cpu", true);
+}
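
Editor's note: taken together, the registry gives applications runtime backend discovery; a minimal consumer sketch using only entry points defined in this file and ggml-backend.h:

#include "ggml-backend.h"
#include <cstdio>

int main() {
    ggml_backend_load_all(); // silently probes for [lib]ggml-<name>-*.so/.dll next to the executable

    for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
        printf("device %zu: %s (%s)\n", i,
               ggml_backend_dev_name(dev), ggml_backend_dev_description(dev));
    }

    ggml_backend_t backend = ggml_backend_init_best(); // GPU device if present, else CPU
    if (!backend) {
        return 1;
    }
    // ... build and compute graphs on `backend` ...
    ggml_backend_free(backend);
    return 0;
}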

File diff suppressed because it is too large
+ 155 - 607
llama/ggml-backend.cpp


+ 182 - 72
llama/ggml-backend.h

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *
@@ -29,6 +29,20 @@
 #include "ggml.h"
 #include "ggml-alloc.h"
 
+#ifdef GGML_BACKEND_SHARED
+#    if defined(_WIN32) && !defined(__MINGW32__)
+#        ifdef GGML_BACKEND_BUILD
+#            define GGML_BACKEND_API __declspec(dllexport) extern
+#        else
+#            define GGML_BACKEND_API __declspec(dllimport) extern
+#        endif
+#    else
+#        define GGML_BACKEND_API __attribute__ ((visibility ("default"))) extern
+#    endif
+#else
+#    define GGML_BACKEND_API extern
+#endif
+
 #ifdef  __cplusplus
 extern "C" {
 #endif
@@ -38,43 +52,52 @@ extern "C" {
     typedef struct ggml_backend_event * ggml_backend_event_t;
     typedef struct ggml_backend * ggml_backend_t;
     typedef void * ggml_backend_graph_plan_t;
+    typedef struct ggml_backend_reg * ggml_backend_reg_t;
+    typedef struct ggml_backend_device * ggml_backend_dev_t;
+
 
     //
-    // Backend buffer
+    // Backend buffer type
     //
 
-    // buffer type
-    GGML_API           const char *          ggml_backend_buft_name            (ggml_backend_buffer_type_t buft);
-    GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_buft_alloc_buffer    (ggml_backend_buffer_type_t buft, size_t size);
-    GGML_API           size_t                ggml_backend_buft_get_alignment   (ggml_backend_buffer_type_t buft);
-    GGML_API           size_t                ggml_backend_buft_get_max_size    (ggml_backend_buffer_type_t buft);
-    GGML_API GGML_CALL size_t                ggml_backend_buft_get_alloc_size  (ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor);
-    GGML_API           bool                  ggml_backend_buft_is_host         (ggml_backend_buffer_type_t buft);
+    GGML_API const char *          ggml_backend_buft_name          (ggml_backend_buffer_type_t buft);
+    GGML_API ggml_backend_buffer_t ggml_backend_buft_alloc_buffer  (ggml_backend_buffer_type_t buft, size_t size);
+    GGML_API size_t                ggml_backend_buft_get_alignment (ggml_backend_buffer_type_t buft);
+    GGML_API size_t                ggml_backend_buft_get_max_size  (ggml_backend_buffer_type_t buft);
+    GGML_API size_t                ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor);
+    GGML_API bool                  ggml_backend_buft_is_host       (ggml_backend_buffer_type_t buft);
+    GGML_API ggml_backend_dev_t    ggml_backend_buft_get_device    (ggml_backend_buffer_type_t buft);
+
+    //
+    // Backend buffer
+    //
 
-    // buffer
     enum ggml_backend_buffer_usage {
         GGML_BACKEND_BUFFER_USAGE_ANY = 0,
         GGML_BACKEND_BUFFER_USAGE_WEIGHTS = 1,
         GGML_BACKEND_BUFFER_USAGE_COMPUTE = 2,
     };
 
-    GGML_API           const char *                   ggml_backend_buffer_name          (ggml_backend_buffer_t buffer);
-    GGML_API           void                           ggml_backend_buffer_free          (ggml_backend_buffer_t buffer);
-    GGML_API           void *                         ggml_backend_buffer_get_base      (ggml_backend_buffer_t buffer);
-    GGML_API           size_t                         ggml_backend_buffer_get_size      (ggml_backend_buffer_t buffer);
-    GGML_API GGML_CALL void                           ggml_backend_buffer_init_tensor   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
-    GGML_API           size_t                         ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
-    GGML_API           size_t                         ggml_backend_buffer_get_max_size  (ggml_backend_buffer_t buffer);
-    GGML_API           size_t                         ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
-    GGML_API           void                           ggml_backend_buffer_clear         (ggml_backend_buffer_t buffer, uint8_t value);
-    GGML_API           bool                           ggml_backend_buffer_is_host       (ggml_backend_buffer_t buffer);
-    GGML_API           void                           ggml_backend_buffer_set_usage     (ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage);
-    GGML_API           enum ggml_backend_buffer_usage ggml_backend_buffer_get_usage     (ggml_backend_buffer_t buffer);
-    GGML_API           ggml_backend_buffer_type_t     ggml_backend_buffer_get_type      (ggml_backend_buffer_t buffer);
-    GGML_API           void                           ggml_backend_buffer_reset         (ggml_backend_buffer_t buffer);
+    GGML_API const char *                   ggml_backend_buffer_name          (ggml_backend_buffer_t buffer);
+    GGML_API void                           ggml_backend_buffer_free          (ggml_backend_buffer_t buffer);
+    GGML_API void *                         ggml_backend_buffer_get_base      (ggml_backend_buffer_t buffer);
+    GGML_API size_t                         ggml_backend_buffer_get_size      (ggml_backend_buffer_t buffer);
+    GGML_API void                           ggml_backend_buffer_init_tensor   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+    GGML_API size_t                         ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
+    GGML_API size_t                         ggml_backend_buffer_get_max_size  (ggml_backend_buffer_t buffer);
+    GGML_API size_t                         ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+    GGML_API void                           ggml_backend_buffer_clear         (ggml_backend_buffer_t buffer, uint8_t value);
+    GGML_API bool                           ggml_backend_buffer_is_host       (ggml_backend_buffer_t buffer);
+    GGML_API void                           ggml_backend_buffer_set_usage     (ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage);
+    GGML_API enum ggml_backend_buffer_usage ggml_backend_buffer_get_usage     (ggml_backend_buffer_t buffer);
+    GGML_API ggml_backend_buffer_type_t     ggml_backend_buffer_get_type      (ggml_backend_buffer_t buffer);
+    GGML_API void                           ggml_backend_buffer_reset         (ggml_backend_buffer_t buffer);
+
+    // tensor copy between different backends
+    GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst);
 
     //
-    // Backend
+    // Backend (stream)
     //
 
     GGML_API ggml_guid_t  ggml_backend_guid(ggml_backend_t backend);
@@ -89,10 +112,10 @@ extern "C" {
     GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend,       struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
     GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
 
-    // "offset" refers to the offset of the tensor data for setting/getting data
-    GGML_API GGML_CALL void ggml_backend_tensor_set(      struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
-    GGML_API GGML_CALL void ggml_backend_tensor_get(const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
-    GGML_API GGML_CALL void ggml_backend_tensor_memset(   struct ggml_tensor * tensor,     uint8_t value, size_t offset, size_t size);
+    // "offset" refers to the offset in tensor->data for setting/getting data
+    GGML_API void ggml_backend_tensor_set(      struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+    GGML_API void ggml_backend_tensor_get(const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
+    GGML_API void ggml_backend_tensor_memset(   struct ggml_tensor * tensor,     uint8_t value, size_t offset, size_t size);
 
     GGML_API void ggml_backend_synchronize(ggml_backend_t backend);
 
@@ -102,65 +125,141 @@ extern "C" {
     GGML_API enum ggml_status ggml_backend_graph_plan_compute (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
     GGML_API enum ggml_status ggml_backend_graph_compute      (ggml_backend_t backend, struct ggml_cgraph * cgraph);
     GGML_API enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph);
+
+    // NOTE: will be removed, use device version instead
     GGML_API bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op);
     GGML_API bool ggml_backend_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft);
     GGML_API bool ggml_backend_offload_op(ggml_backend_t backend, const struct ggml_tensor * op);
 
-    // tensor copy between different backends
-    GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst);
-
     // asynchronous copy
     // the copy is performed after all the currently queued operations in backend_src
     // backend_dst will wait for the copy to complete before performing other operations
     // automatic fallback to sync copy if async is not supported
     GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, struct ggml_tensor * src, struct ggml_tensor * dst);
 
-    // events
-    GGML_API ggml_backend_event_t   ggml_backend_event_new        (ggml_backend_t backend);
-    GGML_API void                   ggml_backend_event_free       (ggml_backend_event_t event);
-    GGML_API void                   ggml_backend_event_record     (ggml_backend_event_t event);
-    GGML_API void                   ggml_backend_event_synchronize(ggml_backend_event_t event);
-    GGML_API void                   ggml_backend_event_wait       (ggml_backend_t backend, ggml_backend_event_t event);
+    GGML_API ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend);
 
     //
-    // CPU backend
+    // Events
     //
 
-    GGML_API ggml_backend_t ggml_backend_cpu_init(void);
+    GGML_API ggml_backend_event_t ggml_backend_event_new(ggml_backend_dev_t device);
+    GGML_API void                 ggml_backend_event_free(ggml_backend_event_t event);
+    GGML_API void                 ggml_backend_event_record(ggml_backend_event_t event, ggml_backend_t backend);
+    GGML_API void                 ggml_backend_event_synchronize(ggml_backend_event_t event);
+    GGML_API void                 ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event);
 
-    GGML_API GGML_CALL bool ggml_backend_is_cpu                (ggml_backend_t backend);
-    GGML_API           void ggml_backend_cpu_set_n_threads     (ggml_backend_t backend_cpu, int n_threads);
-    GGML_API           void ggml_backend_cpu_set_threadpool    (ggml_backend_t backend_cpu, ggml_threadpool_t threadpool);
-    GGML_API           void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data);
+    //
+    // Backend device
+    //
 
-    // Create a backend buffer from an existing pointer
-    GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size);
+    enum ggml_backend_dev_type {
+        // CPU device using system memory
+        GGML_BACKEND_DEVICE_TYPE_CPU,
+        // GPU device using dedicated memory
+        GGML_BACKEND_DEVICE_TYPE_GPU,
+        // accelerator devices intended to be used together with the CPU backend (e.g. BLAS or AMX)
+        GGML_BACKEND_DEVICE_TYPE_ACCEL
+    };
 
-    GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void);
+    // functionality supported by the device
+    struct ggml_backend_dev_caps {
+        // asynchronous operations
+        bool async;
+        // pinned host buffer
+        bool host_buffer;
+        // creating buffers from host ptr
+        bool buffer_from_host_ptr;
+        // event synchronization
+        bool events;
+    };
 
-#ifdef GGML_USE_CPU_HBM
-    GGML_API ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void);
-#endif
+    // all the device properties
+    struct ggml_backend_dev_props {
+        const char * name;
+        const char * description;
+        size_t memory_free;
+        size_t memory_total;
+        enum ggml_backend_dev_type type;
+        struct ggml_backend_dev_caps caps;
+    };
+
+    GGML_API const char *                  ggml_backend_dev_name(ggml_backend_dev_t device);
+    GGML_API const char *                  ggml_backend_dev_description(ggml_backend_dev_t device);
+    GGML_API void                          ggml_backend_dev_memory(ggml_backend_dev_t device, size_t * free, size_t * total);
+    GGML_API enum ggml_backend_dev_type    ggml_backend_dev_type(ggml_backend_dev_t device);
+    GGML_API void                          ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_dev_props * props);
+    GGML_API ggml_backend_reg_t            ggml_backend_dev_backend_reg(ggml_backend_dev_t device);
+    GGML_API ggml_backend_t                ggml_backend_dev_init(ggml_backend_dev_t device, const char * params);
+    GGML_API ggml_backend_buffer_type_t    ggml_backend_dev_buffer_type(ggml_backend_dev_t device);
+    GGML_API ggml_backend_buffer_type_t    ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device);
+    GGML_API ggml_backend_buffer_t         ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size);
+
+    GGML_API bool                          ggml_backend_dev_supports_op(ggml_backend_dev_t device, const struct ggml_tensor * op);
+    GGML_API bool                          ggml_backend_dev_supports_buft(ggml_backend_dev_t device, ggml_backend_buffer_type_t buft);
+    GGML_API bool                          ggml_backend_dev_offload_op(ggml_backend_dev_t device, const struct ggml_tensor * op);
 
     //
-    // Backend registry
+    // Backend (reg)
     //
 
-    // The backend registry is a registry of all the available backends, and allows initializing backends in a generic way
+    GGML_API const char *       ggml_backend_reg_name(ggml_backend_reg_t reg);
+    GGML_API size_t             ggml_backend_reg_dev_count(ggml_backend_reg_t reg);
+    GGML_API ggml_backend_dev_t ggml_backend_reg_dev_get(ggml_backend_reg_t reg, size_t index);
+    GGML_API void *             ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * name);
+
+    // Common functions that may be obtained using ggml_backend_reg_get_proc_address
+
+    // Split buffer type for tensor parallelism
+    typedef ggml_backend_buffer_type_t   (*ggml_backend_split_buffer_type_t)(int main_device, const float * tensor_split);
+    // Set the number of threads for the backend
+    typedef void                         (*ggml_backend_set_n_threads_t)(ggml_backend_t backend, int n_threads);
+    // Get additional buffer types provided by the device (returns a NULL-terminated array)
+    typedef ggml_backend_buffer_type_t * (*ggml_backend_dev_get_extra_bufts_t)(ggml_backend_dev_t device);
+    // Set the abort callback for the backend
+    typedef void                         (*ggml_backend_set_abort_callback_t)(ggml_backend_t backend, ggml_abort_callback abort_callback, void * abort_callback_data);
+    // Get a list of feature flags supported by the backend (returns a NULL-terminated array)
+    struct ggml_backend_feature {
+        const char * name;
+        const char * value;
+    };
+    typedef struct ggml_backend_feature * (*ggml_backend_get_features_t)(ggml_backend_reg_t reg);
+
+    //
+    // Backend registry
+    //
 
-    GGML_API size_t                     ggml_backend_reg_get_count(void);
-    GGML_API size_t                     ggml_backend_reg_find_by_name(const char * name); // returns index of backend with name, or SIZE_MAX if not found
-    GGML_API ggml_backend_t             ggml_backend_reg_init_backend_from_str(const char * backend_str); // str is backend_name:params (params is optional)
-    GGML_API const char *               ggml_backend_reg_get_name(size_t i);
-    GGML_API ggml_backend_t             ggml_backend_reg_init_backend(size_t i, const char * params); // params is backend-specific
-    GGML_API ggml_backend_buffer_type_t ggml_backend_reg_get_default_buffer_type(size_t i);
-    GGML_API ggml_backend_buffer_t      ggml_backend_reg_alloc_buffer(size_t i, size_t size);
+    // Backend (reg) enumeration
+    GGML_API size_t             ggml_backend_reg_count(void);
+    GGML_API ggml_backend_reg_t ggml_backend_reg_get(size_t index);
+    GGML_API ggml_backend_reg_t ggml_backend_reg_by_name(const char * name);
+
+    // Device enumeration
+    GGML_API size_t             ggml_backend_dev_count(void);
+    GGML_API ggml_backend_dev_t ggml_backend_dev_get(size_t index);
+    GGML_API ggml_backend_dev_t ggml_backend_dev_by_name(const char * name);
+    GGML_API ggml_backend_dev_t ggml_backend_dev_by_type(enum ggml_backend_dev_type type);
+
+    // Direct backend (stream) initialization
+    // = ggml_backend_dev_init(ggml_backend_dev_by_name(name), params)
+    GGML_API ggml_backend_t ggml_backend_init_by_name(const char * name, const char * params);
+    // = ggml_backend_dev_init(ggml_backend_dev_by_type(type), params)
+    GGML_API ggml_backend_t ggml_backend_init_by_type(enum ggml_backend_dev_type type, const char * params);
+    // = ggml_backend_dev_init(ggml_backend_dev_by_type(GPU) OR ggml_backend_dev_by_type(CPU), NULL)
+    GGML_API ggml_backend_t ggml_backend_init_best(void);
+
+    // Load a backend from a dynamic library and register it
+    GGML_API ggml_backend_reg_t ggml_backend_load(const char * path);
+    // Unload a backend if loaded dynamically and unregister it
+    GGML_API void               ggml_backend_unload(ggml_backend_reg_t reg);
+    // Load all known backends from dynamic libraries
+    GGML_API void               ggml_backend_load_all(void);
 
     //
     // Backend scheduler
     //
 
-    // The backend scheduler allows for multiple backends to be used together
+    // The backend scheduler allows for multiple backend devices to be used together
     // Handles compute buffer allocation, assignment of tensors to backends, and copying of tensors between backends
     // The backends are selected based on:
     // - the backend that supports the operation
@@ -184,20 +283,26 @@ extern "C" {
         ggml_backend_sched_reserve(sched, reserve_graph);
 
         // compute
-        graph = build_graph(sched);
-        ggml_backend_sched_graph_compute(sched, graph);
+        graph = build_graph(sched); // the graph and its tensors are single-use in terms of allocation, multi-use in terms of computation
+        for (int i = 0; i < 10; ++i) {
+            ggml_backend_sched_graph_compute(sched, graph); // on the first iteration the graph is allocated automatically
+        }
 
         // if there are graph inputs:
-        ggml_backend_sched_reset(sched);
-        ggml_backend_sched_alloc_graph(sched, graph);
-        ggml_backend_tensor_set(input_tensor, ...);
-        ggml_backend_sched_graph_compute(sched, graph);
+        graph = build_graph(sched); // get a new graph that is not allocated (the metadata for the old graph is freed once ggml_free is called)
+        ggml_backend_sched_reset(sched); // clear the allocation of the previous graph
+        ggml_backend_sched_alloc_graph(sched, graph); // explicitly allocate the new graph but do not execute it
+        ggml_backend_tensor_set(input_tensor, ...); // copy data to the newly allocated graph tensors
+        ggml_backend_sched_graph_compute(sched, graph); // execute the graph
+
+        // as an alternative to the above it is also possible to assign the inputs to a dedicated context and
+        // allocate them statically via ggml_backend_alloc_ctx_tensors
     }
     */
 
-    struct ggml_backend_sched;
     typedef struct ggml_backend_sched * ggml_backend_sched_t;
 
+    // Evaluation callback for each node in the graph (set with ggml_backend_sched_set_eval_callback)
     // when ask == true, the scheduler wants to know if the user wants to observe this node
     // this allows the scheduler to batch nodes together in order to evaluate them in a single call
     //
@@ -206,12 +311,12 @@ extern "C" {
     //
     typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data);
 
-    // Initialize a backend scheduler
+    // Initialize a backend scheduler, backends with low index are given priority over backends with high index
     GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel);
     GGML_API void                 ggml_backend_sched_free(ggml_backend_sched_t sched);
 
     // Initialize backend buffers from a measure graph
-    GGML_API bool                 ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph);
+    GGML_API bool                 ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph); // returns success
 
     GGML_API int                  ggml_backend_sched_get_n_backends(ggml_backend_sched_t sched);
     GGML_API ggml_backend_t       ggml_backend_sched_get_backend(ggml_backend_sched_t sched, int i);
@@ -226,12 +331,14 @@ extern "C" {
     GGML_API ggml_backend_t       ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);
 
     // Allocate and compute graph on the backend scheduler
-    GGML_API bool                 ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
+    GGML_API bool                 ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph); // returns success
     GGML_API enum ggml_status     ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
     GGML_API enum ggml_status     ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
     GGML_API void                 ggml_backend_sched_synchronize(ggml_backend_sched_t sched);
 
-    // Reset all assignments and allocators - must be called before changing the node backends
+    // Reset all assignments and allocators - must be called before changing the node backends or allocating a new graph.
+    // This in effect deallocates all tensors that were previously allocated and leaves them with dangling pointers.
+    // The correct way to use this API is to discard the deallocated tensors and create new ones.
     GGML_API void                 ggml_backend_sched_reset(ggml_backend_sched_t sched);
 
     // Set a callback to be called for each resulting node during graph compute
@@ -252,7 +359,7 @@ extern "C" {
     GGML_API struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph);
     GGML_API void                           ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy);
 
-    typedef bool (*GGML_CALL ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data);
+    typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data);
 
     // Compare the output of two backends
     GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data);
@@ -261,6 +368,9 @@ extern "C" {
     GGML_API void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
     GGML_API void ggml_backend_view_init(struct ggml_tensor * tensor);
 
+    // CPU buffer types are always available
+    GGML_API ggml_backend_buffer_t      ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size);
+    GGML_API ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void);
 
 #ifdef  __cplusplus
 }
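
The registry rework above replaces the old index-based ggml_backend_reg_get_* entry points with explicit device enumeration. A minimal sketch of the new initialization flow, using only functions declared in this header (error handling omitted):

    #include "ggml-backend.h"
    #include <cstdio>

    int main() {
        ggml_backend_load_all(); // pick up dynamically loadable backends, if any are present

        // enumerate every device known to the registry
        for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
            ggml_backend_dev_t dev = ggml_backend_dev_get(i);
            ggml_backend_dev_props props;
            ggml_backend_dev_get_props(dev, &props);
            std::printf("%s: %s (%zu MiB free)\n", props.name, props.description,
                        props.memory_free / (1024 * 1024));
        }

        // prefer a GPU stream, fall back to the CPU
        ggml_backend_t backend = ggml_backend_init_best();
        // ... build and compute graphs ...
        ggml_backend_free(backend);
    }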

+ 221 - 72
llama/ggml-blas.cpp

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *
@@ -32,8 +32,9 @@
 
 #include <future>
 #include <vector>
+#include <cstring>
 
-#if defined(GGML_USE_ACCELERATE)
+#if defined(GGML_BLAS_USE_ACCELERATE)
 #   include <Accelerate/Accelerate.h>
 #elif defined(GGML_BLAS_USE_MKL)
 #   include <mkl.h>
@@ -54,30 +55,6 @@ struct ggml_backend_blas_context {
 #endif
 };
 
-// helper function to determine if it is better to use BLAS or not
-// for large matrices, BLAS is faster
-static bool ggml_backend_blas_use_blas(const struct ggml_tensor * dst) {
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    const int64_t ne10 = src1->ne[0];
-
-    const int64_t ne0 = dst->ne[0];
-    const int64_t ne1 = dst->ne[1];
-
-    // TODO: find the optimal values for these
-    if (ggml_is_contiguous(src0) &&
-        ggml_is_contiguous(src1) &&
-        src1->type == GGML_TYPE_F32 &&
-        (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
-
-        /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/
-        return true;
-    }
-
-    return false;
-}
-
 static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
@@ -116,8 +93,8 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct gg
 
     // convert src0 to float
     if (type != GGML_TYPE_F32) {
-        ggml_type_traits_t type_traits = ggml_internal_get_type_traits(type);
-        ggml_to_float_t const to_float = type_traits.to_float;
+        const auto * type_traits = ggml_get_type_traits(type);
+        ggml_to_float_t const to_float = type_traits->to_float;
 
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
@@ -263,25 +240,19 @@ static void ggml_backend_blas_out_prod(ggml_backend_blas_context * ctx, struct g
 
 // backend interface
 
-GGML_CALL static const char * ggml_backend_blas_name(ggml_backend_t backend) {
+static const char * ggml_backend_blas_get_name(ggml_backend_t backend) {
     return "BLAS";
 
     GGML_UNUSED(backend);
 }
 
-GGML_CALL static void ggml_backend_blas_free(ggml_backend_t backend) {
+static void ggml_backend_blas_free(ggml_backend_t backend) {
     ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context;
     delete ctx;
     delete backend;
 }
 
-GGML_CALL static ggml_backend_buffer_type_t ggml_backend_blas_get_default_buffer_type(ggml_backend_t backend) {
-    return ggml_backend_cpu_buffer_type();
-
-    GGML_UNUSED(backend);
-}
-
-GGML_CALL static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
     ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context;
 
     for (int i = 0; i < cgraph->n_nodes; i++) {
@@ -313,31 +284,9 @@ GGML_CALL static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t
     GGML_UNUSED(backend);
 }
 
-GGML_CALL static bool ggml_backend_blas_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
-    const struct ggml_tensor * src0 = op->src[0];
-    const struct ggml_tensor * src1 = op->src[1];
-
-    return (op->op == GGML_OP_MUL_MAT  && ggml_backend_blas_use_blas(op)) ||
-           (op->op == GGML_OP_OUT_PROD && op->src[0]->type == GGML_TYPE_F32 &&
-                                          op->src[1]->type == GGML_TYPE_F32 &&
-                                          ggml_is_matrix(src0) &&
-                                          ggml_is_matrix(src1) &&
-                                          ggml_is_contiguous(src0) &&
-                                          (ggml_is_contiguous(src1) || ggml_is_transposed(src1)));
-
-    GGML_UNUSED(backend);
-}
-
-GGML_CALL static bool ggml_backend_blas_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
-    return ggml_backend_buft_is_host(buft);
-
-    GGML_UNUSED(backend);
-}
-
 static struct ggml_backend_i blas_backend_i = {
-    /* .get_name                = */ ggml_backend_blas_name,
+    /* .get_name                = */ ggml_backend_blas_get_name,
     /* .free                    = */ ggml_backend_blas_free,
-    /* .get_default_buffer_type = */ ggml_backend_blas_get_default_buffer_type,
     /* .set_tensor_async        = */ NULL,
     /* .get_tensor_async        = */ NULL,
     /* .cpy_tensor_async        = */ NULL,
@@ -347,14 +296,8 @@ static struct ggml_backend_i blas_backend_i = {
     /* .graph_plan_update       = */ NULL,
     /* .graph_plan_compute      = */ NULL,
     /* .graph_compute           = */ ggml_backend_blas_graph_compute,
-    /* .supports_op             = */ ggml_backend_blas_supports_op,
-    /* .supports_buft           = */ ggml_backend_blas_supports_buft,
-    /* .offload_op              = */ NULL,
-    /* .event_new               = */ NULL,
-    /* .event_free              = */ NULL,
     /* .event_record            = */ NULL,
     /* .event_wait              = */ NULL,
-    /* .event_synchronize       = */ NULL,
 };
 
 static ggml_guid_t ggml_backend_blas_guid(void) {
@@ -368,23 +311,24 @@ ggml_backend_t ggml_backend_blas_init(void) {
     ggml_backend_t backend = new ggml_backend {
         /* .guid      = */ ggml_backend_blas_guid(),
         /* .interface = */ blas_backend_i,
+        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_blas_reg(), 0),
         /* .context   = */ ctx,
     };
 
-#if !defined(NDEBUG) && defined(OPENBLAS_VERSION) && defined(GGML_USE_OPENMP)
+#if defined(OPENBLAS_VERSION) && defined(GGML_USE_OPENMP)
     if (openblas_get_parallel() != OPENBLAS_OPENMP) {
-        fprintf(stderr, "%s: warning: ggml is using OpenMP, but OpenBLAS was compiled without OpenMP support\n", __func__);
+        GGML_LOG_DEBUG("%s: warning: ggml is using OpenMP, but OpenBLAS was compiled without OpenMP support\n", __func__);
     }
 #endif
 
-#if !defined(NDEBUG) && defined(BLIS_ENABLE_CBLAS) && defined(GGML_USE_OPENMP) && !defined(BLIS_ENABLE_OPENMP)
-    fprintf(stderr, "%s: warning: ggml is using OpenMP, but BLIS was compiled without OpenMP support\n", __func__);
+#if defined(BLIS_ENABLE_CBLAS) && defined(GGML_USE_OPENMP) && !defined(BLIS_ENABLE_OPENMP)
+    GGML_LOG_DEBUG("%s: warning: ggml is using OpenMP, but BLIS was compiled without OpenMP support\n", __func__);
 #endif
 
     return backend;
 }
 
-GGML_CALL bool ggml_backend_is_blas(ggml_backend_t backend) {
+bool ggml_backend_is_blas(ggml_backend_t backend) {
     return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_blas_guid());
 }
 
@@ -395,4 +339,209 @@ void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads)
     ctx->n_threads = n_threads;
 }
 
-#endif
+// device interface
+
+static const char * ggml_backend_blas_device_get_name(ggml_backend_dev_t dev) {
+    return "BLAS";
+
+    GGML_UNUSED(dev);
+}
+
+static const char * ggml_backend_blas_device_get_description(ggml_backend_dev_t dev) {
+    #if defined(GGML_BLAS_USE_ACCELERATE)
+        return "Accelerate";
+    #elif defined(GGML_BLAS_USE_MKL)
+        return "MKL";
+    #elif defined(GGML_BLAS_USE_BLIS)
+        return "BLIS";
+    #elif defined(GGML_BLAS_USE_NVPL)
+        return "NVPL";
+    #elif defined(OPENBLAS_VERSION)
+        return "OpenBLAS";
+    #else
+        return "BLAS";
+    #endif
+
+    GGML_UNUSED(dev);
+}
+
+static void ggml_backend_blas_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
+    // TODO
+    *free = 0;
+    *total = 0;
+
+    GGML_UNUSED(dev);
+}
+
+static enum ggml_backend_dev_type ggml_backend_blas_device_get_type(ggml_backend_dev_t dev) {
+    return GGML_BACKEND_DEVICE_TYPE_ACCEL;
+
+    GGML_UNUSED(dev);
+}
+
+static void ggml_backend_blas_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
+    props->name        = ggml_backend_blas_device_get_name(dev);
+    props->description = ggml_backend_blas_device_get_description(dev);
+    props->type        = ggml_backend_blas_device_get_type(dev);
+    ggml_backend_blas_device_get_memory(dev, &props->memory_free, &props->memory_total);
+    props->caps = {
+        /* .async                 = */ false,
+        /* .host_buffer           = */ false,
+        /* .buffer_from_host_ptr  = */ true,
+        /* .events                = */ false,
+    };
+}
+
+static ggml_backend_t ggml_backend_blas_device_init_backend(ggml_backend_dev_t dev, const char * params) {
+    return ggml_backend_blas_init();
+
+    GGML_UNUSED(dev);
+    GGML_UNUSED(params);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_blas_device_get_buffer_type(ggml_backend_dev_t dev) {
+    return ggml_backend_cpu_buffer_type();
+
+    GGML_UNUSED(dev);
+}
+
+static ggml_backend_buffer_t ggml_backend_blas_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
+    return ggml_backend_cpu_buffer_from_ptr(ptr, size);
+
+    GGML_UNUSED(dev);
+    GGML_UNUSED(max_tensor_size);
+}
+
+static bool ggml_backend_blas_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
+    const struct ggml_tensor * src0 = op->src[0];
+    const struct ggml_tensor * src1 = op->src[1];
+
+    switch (op->op) {
+        case GGML_OP_NONE:
+        case GGML_OP_RESHAPE:
+        case GGML_OP_VIEW:
+        case GGML_OP_PERMUTE:
+        case GGML_OP_TRANSPOSE:
+            return true;
+
+        case GGML_OP_MUL_MAT:
+        {
+            // BLAS usually is only faster for large matrices
+            const struct ggml_tensor * src0 = op->src[0];
+            const struct ggml_tensor * src1 = op->src[1];
+
+            const int64_t ne10 = src1->ne[0];
+
+            const int64_t ne0 = op->ne[0];
+            const int64_t ne1 = op->ne[1];
+
+            // TODO: find the optimal value
+            const int64_t min_batch = 32;
+
+            return ggml_is_contiguous(src0) &&
+                   ggml_is_contiguous(src1) &&
+                   src1->type == GGML_TYPE_F32 &&
+                   (ne0 >= min_batch && ne1 >= min_batch && ne10 >= min_batch) &&
+                   (src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL);
+        }
+
+        case GGML_OP_OUT_PROD:
+            return op->src[0]->type == GGML_TYPE_F32 &&
+                   op->src[1]->type == GGML_TYPE_F32 &&
+                   ggml_is_matrix(src0) &&
+                   ggml_is_matrix(src1) &&
+                   ggml_is_contiguous(src0) &&
+                   (ggml_is_contiguous(src1) || ggml_is_transposed(src1)) &&
+                   (src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL);
+
+        default:
+            return false;
+
+    }
+
+    GGML_UNUSED(dev);
+}
+
+static bool ggml_backend_blas_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
+    return ggml_backend_buft_is_host(buft);
+
+    GGML_UNUSED(dev);
+}
+
+static const struct ggml_backend_device_i ggml_backend_blas_device_i = {
+    /* .get_name             = */ ggml_backend_blas_device_get_name,
+    /* .get_description      = */ ggml_backend_blas_device_get_description,
+    /* .get_memory           = */ ggml_backend_blas_device_get_memory,
+    /* .get_type             = */ ggml_backend_blas_device_get_type,
+    /* .get_props            = */ ggml_backend_blas_device_get_props,
+    /* .init_backend         = */ ggml_backend_blas_device_init_backend,
+    /* .get_buffer_type      = */ ggml_backend_blas_device_get_buffer_type,
+    /* .get_host_buffer_type = */ NULL,
+    /* .buffer_from_host_ptr = */ ggml_backend_blas_device_buffer_from_host_ptr,
+    /* .supports_op          = */ ggml_backend_blas_device_supports_op,
+    /* .supports_buft        = */ ggml_backend_blas_device_supports_buft,
+    /* .offload_op           = */ NULL,
+    /* .event_new            = */ NULL,
+    /* .event_free           = */ NULL,
+    /* .event_synchronize    = */ NULL,
+};
+
+// backend reg interface
+
+static const char * ggml_backend_blas_reg_get_name(ggml_backend_reg_t reg) {
+    return "BLAS";
+
+    GGML_UNUSED(reg);
+}
+
+static size_t ggml_backend_blas_reg_get_device_count(ggml_backend_reg_t reg) {
+    return 1;
+
+    GGML_UNUSED(reg);
+}
+
+static ggml_backend_dev_t ggml_backend_blas_reg_get_device(ggml_backend_reg_t reg, size_t index) {
+    GGML_ASSERT(index == 0);
+
+    static ggml_backend_device ggml_backend_blas_device = {
+        /* .iface   = */ ggml_backend_blas_device_i,
+        /* .reg     = */ reg,
+        /* .context = */ nullptr,
+    };
+
+    return &ggml_backend_blas_device;
+
+    GGML_UNUSED(reg);
+    GGML_UNUSED(index);
+}
+
+static void * ggml_backend_blas_get_proc_address(ggml_backend_reg_t reg, const char * name) {
+    if (std::strcmp(name, "ggml_backend_set_n_threads") == 0) {
+        return (void *)ggml_backend_blas_set_n_threads;
+    }
+    return NULL;
+
+    GGML_UNUSED(reg);
+    GGML_UNUSED(name);
+}
+
+static const struct ggml_backend_reg_i ggml_backend_blas_reg_i = {
+    /* .get_name         = */ ggml_backend_blas_reg_get_name,
+    /* .get_device_count = */ ggml_backend_blas_reg_get_device_count,
+    /* .get_device       = */ ggml_backend_blas_reg_get_device,
+    /* .get_proc_address = */ ggml_backend_blas_get_proc_address,
+};
+
+ggml_backend_reg_t ggml_backend_blas_reg(void) {
+    static struct ggml_backend_reg ggml_backend_blas_reg = {
+        /* .api_version = */ GGML_BACKEND_API_VERSION,
+        /* .iface       = */ ggml_backend_blas_reg_i,
+        /* .context     = */ NULL,
+    };
+
+    return &ggml_backend_blas_reg;
+}
+
+GGML_BACKEND_DL_IMPL(ggml_backend_blas_reg)
+
+#endif // GGML_USE_BLAS
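
With the device and reg interfaces in place, the BLAS thread setter is also reachable generically through the registry instead of calling the backend-specific symbol directly. A short sketch, assuming the backend was compiled in (GGML_USE_BLAS):

    ggml_backend_reg_t reg     = ggml_backend_blas_reg();
    ggml_backend_t     backend = ggml_backend_dev_init(ggml_backend_reg_dev_get(reg, 0), /*params=*/nullptr);

    // resolve the optional setter; get_proc_address returns NULL for names the reg does not export
    auto set_n_threads = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
    if (set_n_threads != nullptr) {
        set_n_threads(backend, 8); // threads used for conversion to float (and for the BLAS ops with OpenBLAS/BLIS)
    }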

+ 6 - 4
llama/ggml-blas.h

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *
@@ -35,13 +35,15 @@ extern "C" {
 #endif
 
 // backend API
-GGML_API GGML_CALL ggml_backend_t ggml_backend_blas_init(void);
+GGML_BACKEND_API ggml_backend_t ggml_backend_blas_init(void);
 
-GGML_API GGML_CALL bool ggml_backend_is_blas(ggml_backend_t backend);
+GGML_BACKEND_API bool ggml_backend_is_blas(ggml_backend_t backend);
 
 // number of threads used for conversion to float
 // for openblas and blis, this will also set the number of threads used for blas operations
-GGML_API GGML_CALL void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads);
+GGML_BACKEND_API void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads);
+
+GGML_BACKEND_API ggml_backend_reg_t ggml_backend_blas_reg(void);
 
 
 #ifdef  __cplusplus

+ 7 - 1
llama/ggml-common.h

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *
@@ -444,6 +444,12 @@ typedef struct {
 } block_iq4_xs;
 static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");
 
+typedef struct {
+    ggml_half d[4];        // deltas for 4 iq4_nl blocks
+    uint8_t qs[QK4_NL * 2]; // nibbles / quants for 4 iq4_nl blocks
+} block_iq4_nlx4;
+static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(ggml_half) + QK4_NL * 2, "wrong iq4_nlx4 block size/padding");
+
 #endif // GGML_COMMON_DECL
 #endif // GGML_COMMON_DECL
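
For scale: assuming the usual QK4_NL = 32, one block_iq4_nlx4 is 4 × 2-byte ggml_half deltas plus 32 × 2 bytes of nibbles, 72 bytes in total, the same footprint as four separate 18-byte iq4_nl blocks but laid out for interleaved access.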
 

+ 64 - 0
llama/ggml-cpp.h

@@ -0,0 +1,64 @@
+/**
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#ifndef __cplusplus
+#error "This header is for C++ only"
+#endif
+
+#include "ggml.h"
+#include "ggml-alloc.h"
+#include "ggml-backend.h"
+#include <memory>
+
+// Smart pointers for ggml types
+
+// ggml
+
+struct ggml_context_deleter { void operator()(ggml_context * ctx) { ggml_free(ctx); } };
+struct gguf_context_deleter { void operator()(gguf_context * ctx) { gguf_free(ctx); } };
+
+typedef std::unique_ptr<ggml_context, ggml_context_deleter> ggml_context_ptr;
+typedef std::unique_ptr<gguf_context, gguf_context_deleter> gguf_context_ptr;
+
+// ggml-alloc
+
+struct ggml_gallocr_deleter { void operator()(ggml_gallocr_t galloc) { ggml_gallocr_free(galloc); } };
+
+typedef std::unique_ptr<ggml_gallocr_t, ggml_gallocr_deleter> ggml_gallocr_ptr;
+
+// ggml-backend
+
+struct ggml_backend_deleter        { void operator()(ggml_backend_t backend)       { ggml_backend_free(backend); } };
+struct ggml_backend_buffer_deleter { void operator()(ggml_backend_buffer_t buffer) { ggml_backend_buffer_free(buffer); } };
+struct ggml_backend_event_deleter  { void operator()(ggml_backend_event_t event)   { ggml_backend_event_free(event); } };
+struct ggml_backend_sched_deleter  { void operator()(ggml_backend_sched_t sched)   { ggml_backend_sched_free(sched); } };
+
+typedef std::unique_ptr<ggml_backend,        ggml_backend_deleter>        ggml_backend_ptr;
+typedef std::unique_ptr<ggml_backend_buffer, ggml_backend_buffer_deleter> ggml_backend_buffer_ptr;
+typedef std::unique_ptr<ggml_backend_event,  ggml_backend_event_deleter>  ggml_backend_event_ptr;
+typedef std::unique_ptr<ggml_backend_sched,  ggml_backend_sched_deleter>  ggml_backend_sched_ptr;
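
These wrappers tie backend lifetimes to scope. A minimal sketch combining them with the new registry API (error handling omitted):

    #include "ggml-cpp.h"

    void run_once() {
        // the deleter calls ggml_backend_free when the pointer leaves scope
        ggml_backend_ptr backend(ggml_backend_init_best());
        if (!backend) {
            return;
        }
        // ... allocate buffers, compute graphs ...
        ggml_backend_synchronize(backend.get());
    }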

+ 3849 - 0
llama/ggml-cpu-aarch64.c

@@ -0,0 +1,3849 @@
+/**
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#define GGML_COMMON_IMPL_C
+#include "ggml-common.h"
+
+#include "ggml-quants.h"
+#include "ggml-impl.h"
+#include "ggml-cpu.h"
+#include "ggml-cpu-impl.h"
+
+#include <math.h>
+#include <string.h>
+#include <assert.h>
+#include <float.h>
+#include <stdlib.h> // for qsort
+#include <stdio.h>  // for GGML_ASSERT
+
+#include "ggml-cpu-aarch64.h"
+
+#if defined(__GNUC__)
+#pragma GCC diagnostic ignored "-Woverlength-strings"
+#elif defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
+#define UNUSED GGML_UNUSED
+
+// Functions to create the interleaved data layout formats
+
+// interleave 4 block_q4_0s in blocks of blck_size_interleave
+// returns an interleaved block_q4_0x4
+// in the interleaved block_q4_0x4, place deltas for 4 block_q4_0 blocks
+// first, then interleave quants from 4 block_q4_0s in blocks of blck_size_interleave
+//
+// - in                  : an array of block_q4_0 pointers
+// - blck_size_interleave : the block_q4_0 quant bytes are interleaved in chunks of
+//                          blck_size_interleave bytes
+// - xor_mask             : the mask used to convert the nibbles in the block_q4_0 quant
+//                          bytes from bias-offset form to pure sign form (this saves
+//                          subtract operations during unpacking)
+//
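+// e.g. with blck_size_interleave = 4 the interleaved quants end up laid out as
+//   blk0[0..3] blk1[0..3] blk2[0..3] blk3[0..3] blk0[4..7] blk1[4..7] ...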
+#if defined(__AVX__)
+#if defined(__F16C__)
+#if defined(__AVX512F__)
+#define GGML_F32Cx8x2_LOAD(x, y)     _mm512_cvtph_ps(_mm256_set_m128i(_mm_loadu_si128((const __m128i *)(y)), _mm_loadu_si128((const __m128i *)(x))))
+#define GGML_F32Cx16_REPEAT_LOAD(x)  _mm512_cvtph_ps(_mm256_set_m128i(x, x))
+#endif
+// the  _mm256_cvt intrinsics require F16C
+#define GGML_F32Cx8_LOAD(x)     _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
+#define GGML_F32Cx8_REPEAT_LOAD(x, loadMask)     _mm256_cvtph_ps(_mm_shuffle_epi32(_mm_maskload_epi32((int const*)(x), loadMask), 68))
+#define GGML_F32Cx8_REARRANGE_LOAD(x, arrangeMask)     _mm256_cvtph_ps(_mm_shuffle_epi8(_mm_loadu_si128((const __m128i *) x), arrangeMask))
+#else
+#if defined(__AVX512F__)
+static inline __m512 __avx512_f32cx8x2_load(ggml_fp16_t *x, ggml_fp16_t *y) {
+    float tmp[16];
+
+    for (int i = 0; i < 8; i++) {
+        tmp[i] = GGML_FP16_TO_FP32(x[i]);
+    }
+
+    for (int i = 0; i < 8; i++) {
+        tmp[i + 8] = GGML_FP16_TO_FP32(y[i]);
+    }
+
+    return _mm512_loadu_ps(tmp);
+}
+static inline __m512 __avx512_repeat_f32cx16_load(__m128i x) {
+    float tmp[16];
+    uint16_t tmphalf[8];
+    _mm_storeu_si128((__m128i*)tmphalf, x);
+
+    for (int i = 0; i < 4; i++) {
+        tmp[i] = GGML_FP16_TO_FP32(tmphalf[i]);
+        tmp[i + 4] = GGML_FP16_TO_FP32(tmphalf[i]);
+        tmp[i + 8] = GGML_FP16_TO_FP32(tmphalf[i]);
+        tmp[i + 12] = GGML_FP16_TO_FP32(tmphalf[i]);
+    }
+
+    return _mm512_loadu_ps(tmp);
+}
+#endif
+static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
+    float tmp[8];
+
+    for (int i = 0; i < 8; i++) {
+        tmp[i] = GGML_FP16_TO_FP32(x[i]);
+    }
+
+    return _mm256_loadu_ps(tmp);
+}
+static inline __m256 __avx_repeat_f32cx8_load(ggml_fp16_t *x) {
+    float tmp[8];
+
+    for (int i = 0; i < 4; i++) {
+        tmp[i] = GGML_FP16_TO_FP32(x[i]);
+        tmp[i + 4] = GGML_FP16_TO_FP32(x[i]);
+    }
+
+    return _mm256_loadu_ps(tmp);
+}
+static inline __m256 __avx_rearranged_f32cx8_load(ggml_fp16_t *x, __m128i arrangeMask) {
+    uint16_t tmphalf[8];
+    float tmp[8];
+
+    _mm_storeu_si128((__m128i*)tmphalf, _mm_shuffle_epi8(_mm_loadu_si128((const __m128i *) x), arrangeMask));
+    for (int i = 0; i < 8; i++) {
+        tmp[i] = GGML_FP16_TO_FP32(tmphalf[i]);
+    }
+
+    return _mm256_loadu_ps(tmp);
+}
+
+#define GGML_F32Cx8_LOAD(x)     __avx_f32cx8_load(x)
+#define GGML_F32Cx8_REPEAT_LOAD(x, loadMask)     __avx_repeat_f32cx8_load(x)
+#define GGML_F32Cx8_REARRANGE_LOAD(x, arrangeMask)     __avx_rearranged_f32cx8_load(x, arrangeMask)
+#if defined(__AVX512F__)
+#define GGML_F32Cx8x2_LOAD(x, y)     __avx512_f32cx8x2_load(x, y)
+#define GGML_F32Cx16_REPEAT_LOAD(x)  __avx512_repeat_f32cx16_load(x)
+#endif
+#endif
+#endif
+
+
+#if defined(__AVX2__) || defined(__AVX512F__)
+#if defined(__AVX512F__)
+// add int16_t pairwise and return as 512 bit int vector
+static inline __m512i sum_i16_pairs_int_32x16(const __m512i x) {
+    const __m512i ones = _mm512_set1_epi16(1);
+    return _mm512_madd_epi16(ones, x);
+}
+
+static inline __m512i mul_sum_us8_pairs_int32x16(const __m512i ax, const __m512i sy) {
+#if defined(__AVX512VNNI__)
+    const __m512i zero = _mm512_setzero_si512();
+    return _mm512_dpbusd_epi32(zero, ax, sy);
+#else
+    // Perform multiplication and create 16-bit values
+    const __m512i dot = _mm512_maddubs_epi16(ax, sy);
+    return sum_i16_pairs_int_32x16(dot);
+#endif
+}
+
+// multiply int8_t, add results pairwise twice and return as 512 bit int vector
+static inline __m512i mul_sum_i8_pairs_int32x16(const __m512i x, const __m512i y) {
+    const __m512i zero = _mm512_setzero_si512();
+    // Get absolute values of x vectors
+    const __m512i ax = _mm512_abs_epi8(x);
+    // Sign the values of the y vectors
+    __mmask64 blt0 = _mm512_movepi8_mask(x);
+    const __m512i sy = _mm512_mask_sub_epi8(y, blt0, zero, y);
+    return mul_sum_us8_pairs_int32x16(ax, sy);
+}
+#endif
+
+// add int16_t pairwise and return as 256 bit int vector
+static inline __m256i sum_i16_pairs_int32x8(const __m256i x) {
+    const __m256i ones = _mm256_set1_epi16(1);
+    return _mm256_madd_epi16(ones, x);
+}
+
+static inline __m256i mul_sum_us8_pairs_int32x8(const __m256i ax, const __m256i sy) {
+#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
+    const __m256i zero = _mm256_setzero_si256();
+    return _mm256_dpbusd_epi32(zero, ax, sy);
+#else
+    // Perform multiplication and create 16-bit values
+    const __m256i dot = _mm256_maddubs_epi16(ax, sy);
+    return sum_i16_pairs_int32x8(dot);
+#endif
+}
+
+// Integer variant of the function defined in ggml-quants.c
+// multiply int8_t, add results pairwise twice and return as 256 bit int vector
+static inline __m256i mul_sum_i8_pairs_int32x8(const __m256i x, const __m256i y) {
+#if __AVXVNNIINT8__
+    const __m256i zero = _mm256_setzero_si256();
+    return _mm256_dpbssd_epi32(zero, x, y);
+#else
+    // Get absolute values of x vectors
+    const __m256i ax = _mm256_sign_epi8(x, x);
+    // Sign the values of the y vectors
+    const __m256i sy = _mm256_sign_epi8(y, x);
+    return mul_sum_us8_pairs_int32x8(ax, sy);
+#endif
+}
+#endif
+
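+// lookup table of the 16 non-linear iq4_nl quant levels (index = 4-bit quant, value = dequantized int8 level)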
+static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
+
+static void quantize_q8_0_4x4(const float * restrict x, void * restrict vy, int64_t k) {
+    assert(QK8_0 == 32);
+    assert(k % QK8_0 == 0);
+    const int nb = k / QK8_0;
+
+    block_q8_0x4 * restrict y = (block_q8_0x4 *) vy;
+
+#if defined(__ARM_NEON)
+    float32x4_t srcv[4][8];
+    float id[4];
+
+    for (int i = 0; i < nb; i++) {
+        float32x4_t asrcv[8];
+        float32x4_t amaxv[8];
+
+        for (int row_iter = 0; row_iter < 4; row_iter++) {
+            for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j);
+            for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]);
+
+            for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]);
+            for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]);
+            for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]);
+
+            const float amax = vmaxvq_f32(amaxv[0]);
+
+            const float d = amax / ((1 << 7) - 1);
+            id[row_iter] = d ? 1.0f / d : 0.0f;
+
+            y[i].d[row_iter] = GGML_FP32_TO_FP16(d);
+        }
+
+        for (int j = 0; j < 8; j++) {
+            float32x4_t v = vmulq_n_f32(srcv[0][j], id[0]);
+            int32x4_t vi = vcvtnq_s32_f32(v);
+            y[i].qs[16 * j + 0] = vgetq_lane_s32(vi, 0);
+            y[i].qs[16 * j + 1] = vgetq_lane_s32(vi, 1);
+            y[i].qs[16 * j + 2] = vgetq_lane_s32(vi, 2);
+            y[i].qs[16 * j + 3] = vgetq_lane_s32(vi, 3);
+
+            v = vmulq_n_f32(srcv[1][j], id[1]);
+            vi = vcvtnq_s32_f32(v);
+            y[i].qs[16 * j + 4] = vgetq_lane_s32(vi, 0);
+            y[i].qs[16 * j + 5] = vgetq_lane_s32(vi, 1);
+            y[i].qs[16 * j + 6] = vgetq_lane_s32(vi, 2);
+            y[i].qs[16 * j + 7] = vgetq_lane_s32(vi, 3);
+
+            v = vmulq_n_f32(srcv[2][j], id[2]);
+            vi = vcvtnq_s32_f32(v);
+            y[i].qs[16 * j + 8] = vgetq_lane_s32(vi, 0);
+            y[i].qs[16 * j + 9] = vgetq_lane_s32(vi, 1);
+            y[i].qs[16 * j + 10] = vgetq_lane_s32(vi, 2);
+            y[i].qs[16 * j + 11] = vgetq_lane_s32(vi, 3);
+
+            v = vmulq_n_f32(srcv[3][j], id[3]);
+            vi = vcvtnq_s32_f32(v);
+            y[i].qs[16 * j + 12] = vgetq_lane_s32(vi, 0);
+            y[i].qs[16 * j + 13] = vgetq_lane_s32(vi, 1);
+            y[i].qs[16 * j + 14] = vgetq_lane_s32(vi, 2);
+            y[i].qs[16 * j + 15] = vgetq_lane_s32(vi, 3);
+        }
+    }
+#else
+    // scalar
+    const int blck_size_interleave = 4;
+    float srcv[4][QK8_0];
+    float id[4];
+
+    for (int i = 0; i < nb; i++) {
+        for (int row_iter = 0; row_iter < 4; row_iter++) {
+            float amax = 0.0f; // absolute max
+
+            for (int j = 0; j < QK8_0; j++) {
+                srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j];
+                amax = MAX(amax, fabsf(srcv[row_iter][j]));
+            }
+
+            const float d = amax / ((1 << 7) - 1);
+            id[row_iter] = d ? 1.0f / d : 0.0f;
+
+            y[i].d[row_iter] = GGML_FP32_TO_FP16(d);
+        }
+
+        for (int j = 0; j < QK8_0 * 4; j++) {
+            int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
+            int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
+            src_offset += (j % blck_size_interleave);
+
+            float x0 = srcv[src_id][src_offset] * id[src_id];
+            y[i].qs[j] = roundf(x0);
+        }
+    }
+#endif
+}
+
+static void quantize_q8_0_4x8(const float * restrict x, void * restrict vy, int64_t k) {
+    assert(QK8_0 == 32);
+    assert(k % QK8_0 == 0);
+    const int nb = k / QK8_0;
+
+    block_q8_0x4 * restrict y = (block_q8_0x4 *) vy;
+
+#if defined(__ARM_NEON)
+    float32x4_t srcv[4][8];
+    float id[4];
+
+    for (int i = 0; i < nb; i++) {
+        float32x4_t asrcv[8];
+        float32x4_t amaxv[8];
+
+        for (int row_iter = 0; row_iter < 4; row_iter++) {
+            for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j);
+            for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]);
+
+            for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]);
+            for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]);
+            for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]);
+
+            const float amax = vmaxvq_f32(amaxv[0]);
+
+            const float d = amax / ((1 << 7) - 1);
+            id[row_iter] = d ? 1.0f / d : 0.0f;
+
+            y[i].d[row_iter] = GGML_FP32_TO_FP16(d);
+        }
+
+        for (int j = 0; j < 4; j++) {
+            float32x4_t v = vmulq_n_f32(srcv[0][2 * j], id[0]);
+            int32x4_t vi = vcvtnq_s32_f32(v);
+            y[i].qs[32 * j + 0] = vgetq_lane_s32(vi, 0);
+            y[i].qs[32 * j + 1] = vgetq_lane_s32(vi, 1);
+            y[i].qs[32 * j + 2] = vgetq_lane_s32(vi, 2);
+            y[i].qs[32 * j + 3] = vgetq_lane_s32(vi, 3);
+            v = vmulq_n_f32(srcv[0][2 * j + 1], id[0]);
+            vi = vcvtnq_s32_f32(v);
+            y[i].qs[32 * j + 4] = vgetq_lane_s32(vi, 0);
+            y[i].qs[32 * j + 5] = vgetq_lane_s32(vi, 1);
+            y[i].qs[32 * j + 6] = vgetq_lane_s32(vi, 2);
+            y[i].qs[32 * j + 7] = vgetq_lane_s32(vi, 3);
+
+            v = vmulq_n_f32(srcv[1][2 * j], id[1]);
+            vi = vcvtnq_s32_f32(v);
+            y[i].qs[32 * j + 8] = vgetq_lane_s32(vi, 0);
+            y[i].qs[32 * j + 9] = vgetq_lane_s32(vi, 1);
+            y[i].qs[32 * j + 10] = vgetq_lane_s32(vi, 2);
+            y[i].qs[32 * j + 11] = vgetq_lane_s32(vi, 3);
+            v = vmulq_n_f32(srcv[1][2 * j + 1], id[1]);
+            vi = vcvtnq_s32_f32(v);
+            y[i].qs[32 * j + 12] = vgetq_lane_s32(vi, 0);
+            y[i].qs[32 * j + 13] = vgetq_lane_s32(vi, 1);
+            y[i].qs[32 * j + 14] = vgetq_lane_s32(vi, 2);
+            y[i].qs[32 * j + 15] = vgetq_lane_s32(vi, 3);
+
+            v = vmulq_n_f32(srcv[2][2 * j], id[2]);
+            vi = vcvtnq_s32_f32(v);
+            y[i].qs[32 * j + 16] = vgetq_lane_s32(vi, 0);
+            y[i].qs[32 * j + 17] = vgetq_lane_s32(vi, 1);
+            y[i].qs[32 * j + 18] = vgetq_lane_s32(vi, 2);
+            y[i].qs[32 * j + 19] = vgetq_lane_s32(vi, 3);
+            v = vmulq_n_f32(srcv[2][2 * j + 1], id[2]);
+            vi = vcvtnq_s32_f32(v);
+            y[i].qs[32 * j + 20] = vgetq_lane_s32(vi, 0);
+            y[i].qs[32 * j + 21] = vgetq_lane_s32(vi, 1);
+            y[i].qs[32 * j + 22] = vgetq_lane_s32(vi, 2);
+            y[i].qs[32 * j + 23] = vgetq_lane_s32(vi, 3);
+
+            v = vmulq_n_f32(srcv[3][2 * j], id[3]);
+            vi = vcvtnq_s32_f32(v);
+            y[i].qs[32 * j + 24] = vgetq_lane_s32(vi, 0);
+            y[i].qs[32 * j + 25] = vgetq_lane_s32(vi, 1);
+            y[i].qs[32 * j + 26] = vgetq_lane_s32(vi, 2);
+            y[i].qs[32 * j + 27] = vgetq_lane_s32(vi, 3);
+            v = vmulq_n_f32(srcv[3][2 * j + 1], id[3]);
+            vi = vcvtnq_s32_f32(v);
+            y[i].qs[32 * j + 28] = vgetq_lane_s32(vi, 0);
+            y[i].qs[32 * j + 29] = vgetq_lane_s32(vi, 1);
+            y[i].qs[32 * j + 30] = vgetq_lane_s32(vi, 2);
+            y[i].qs[32 * j + 31] = vgetq_lane_s32(vi, 3);
+        }
+    }
+#elif defined(__AVX2__) || defined(__AVX__)
+    float id[4];
+    __m256 srcv[4][4];
+    __m256 idvec[4];
+
+    for (int i = 0; i < nb; i++) {
+        for (int row_iter = 0; row_iter < 4; row_iter++) {
+            // Load elements into 4 AVX vectors
+            __m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 32 );
+            __m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 32 + 8 );
+            __m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 32 + 16 );
+            __m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 32 + 24 );
+
+            // Compute max(abs(e)) for the block
+            const __m256 signBit = _mm256_set1_ps( -0.0f );
+            __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
+            maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
+            maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
+            maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+
+            __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+            max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+            max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+            const float maxScalar = _mm_cvtss_f32( max4 );
+
+            // Divided by 127.f to mirror results in quantize_row_q8_0
+            const float d = maxScalar  / 127.f;
+            id[row_iter] = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; //d ? 1.0f / d : 0.0f;
+
+            // Store the scale for the individual block
+            y[i].d[row_iter] = GGML_FP32_TO_FP16(d);
+
+            // Store the values in blocks of eight values - Aim is to use these later for block interleaving
+            srcv[row_iter][0] = v0;
+            srcv[row_iter][1] = v1;
+            srcv[row_iter][2] = v2;
+            srcv[row_iter][3] = v3;
+            idvec[row_iter] = _mm256_set1_ps(id[row_iter]);
+        }
+
+        // The loop iterates four times; the aim is to get 4 corresponding chunks of eight bytes from the original weight blocks, interleaved
+        for (int j = 0; j < 4; j++) {
+            // Apply the multiplier
+            __m256 v0 = _mm256_mul_ps(srcv[0][j], idvec[0]);
+            __m256 v1 = _mm256_mul_ps(srcv[1][j], idvec[1]);
+            __m256 v2 = _mm256_mul_ps(srcv[2][j], idvec[2]);
+            __m256 v3 = _mm256_mul_ps(srcv[3][j], idvec[3]);
+
+            // Round to nearest integer
+            v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+            v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+            v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+            v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+            // Convert floats to integers
+            __m256i i0 = _mm256_cvtps_epi32( v0 );
+            __m256i i1 = _mm256_cvtps_epi32( v1 );
+            __m256i i2 = _mm256_cvtps_epi32( v2 );
+            __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+#if defined(__AVX2__)
+            // Convert int32 to int16
+            i0 = _mm256_packs_epi32( i0, i1 );
+            i2 = _mm256_packs_epi32( i2, i3 );
+            // Convert int16 to int8
+            i0 = _mm256_packs_epi16( i0, i2 );
+
+            //  Permute and store the quantized weights in the required order after the pack instruction
+            const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+            i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+            _mm256_storeu_si256((__m256i *)(y[i].qs + 32 * j), i0);
+#else
+            // AVX lacks some of the integer intrinsics we need,
+            // so we split the 256-bit registers in half and use the SSE analogs of the AVX2 ops
+            __m128i ni0 = _mm256_castsi256_si128( i0 );
+            __m128i ni1 = _mm256_extractf128_si256( i0, 1);
+            __m128i ni2 = _mm256_castsi256_si128( i1 );
+            __m128i ni3 = _mm256_extractf128_si256( i1, 1);
+            __m128i ni4 = _mm256_castsi256_si128( i2 );
+            __m128i ni5 = _mm256_extractf128_si256( i2, 1);
+            __m128i ni6 = _mm256_castsi256_si128( i3 );
+            __m128i ni7 = _mm256_extractf128_si256( i3, 1);
+
+            // Convert int32 to int16
+            ni0 = _mm_packs_epi32( ni0, ni1 );
+            ni2 = _mm_packs_epi32( ni2, ni3 );
+            ni4 = _mm_packs_epi32( ni4, ni5 );
+            ni6 = _mm_packs_epi32( ni6, ni7 );
+            // Convert int16 to int8
+            ni0 = _mm_packs_epi16( ni0, ni2 );
+            ni4 = _mm_packs_epi16( ni4, ni6 );
+            _mm_storeu_si128((__m128i *)(y[i].qs + 32 * j), ni0);
+            _mm_storeu_si128((__m128i *)(y[i].qs + 32 * j + 16), ni4);
+#endif
+        }
+    }
+#else
+    // scalar
+    const int blck_size_interleave = 8;
+    float srcv[4][QK8_0];
+    float id[4];
+
+    for (int i = 0; i < nb; i++) {
+        for (int row_iter = 0; row_iter < 4; row_iter++) {
+            float amax = 0.0f; // absolute max
+
+            for (int j = 0; j < QK8_0; j++) {
+                srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j];
+                amax = MAX(amax, fabsf(srcv[row_iter][j]));
+            }
+
+            const float d = amax / ((1 << 7) - 1);
+            id[row_iter] = d ? 1.0f / d : 0.0f;
+
+            y[i].d[row_iter] = GGML_FP32_TO_FP16(d);
+        }
+
+        for (int j = 0; j < QK8_0 * 4; j++) {
+            int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
+            int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
+            src_offset += (j % blck_size_interleave);
+
+            float x0 = srcv[src_id][src_offset] * id[src_id];
+            y[i].qs[j] = roundf(x0);
+        }
+    }
+#endif
+}
+
+void quantize_mat_q8_0(const float * restrict x, void * restrict vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) {
+    assert(nrow == 4);
+    UNUSED(nrow);
+    if (blck_size_interleave == 4) {
+        quantize_q8_0_4x4(x, vy, n_per_row);
+    } else if (blck_size_interleave == 8) {
+        quantize_q8_0_4x8(x, vy, n_per_row);
+    } else {
+        assert(false);
+    }
+}
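+
+// Minimal usage sketch of quantize_mat_q8_0 (editor's illustration, kept inside
+// "#if 0" so it is not compiled; the helper name and buffers are hypothetical,
+// not part of the vendored source):
+#if 0
+static void example_quantize_mat_q8_0(void) {
+    float src[4 * 64];                // exactly 4 rows, 64 floats per row
+    block_q8_0x4 dst[64 / QK8_0];     // one interleaved block per QK8_0 columns
+    for (int i = 0; i < 4 * 64; i++) {
+        src[i] = (float) i / 64.0f;   // arbitrary test data
+    }
+    // blck_size_interleave == 8 selects quantize_q8_0_4x8 above
+    quantize_mat_q8_0(src, dst, 4, 64, 8);
+}
+#endif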
+
+void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert (n % qk == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
+        const block_q4_0x4 * b_ptr = (const block_q4_0x4 *)vx;
+
+        for (int c = 0; c < nc; c += ncols_interleaved) {
+            const block_q8_0 * a_ptr = (const block_q8_0 *)vy;
+            float32x4_t acc = vdupq_n_f32(0);
+            for (int b = 0; b < nb; b++) {
+                int8x16_t b0 = vld1q_s8((const int8_t *)b_ptr->qs);
+                int8x16_t b1 = vld1q_s8((const int8_t *)b_ptr->qs + 16);
+                int8x16_t b2 = vld1q_s8((const int8_t *)b_ptr->qs + 32);
+                int8x16_t b3 = vld1q_s8((const int8_t *)b_ptr->qs + 48);
+                float16x4_t bd = vld1_f16((const __fp16 *)b_ptr->d);
+
+                int8x16_t a0 = vld1q_s8(a_ptr->qs);
+                int8x16_t a1 = vld1q_s8(a_ptr->qs + qk/2);
+                float16x4_t ad = vld1_dup_f16((const __fp16 *)&a_ptr->d);
+
+                int32x4_t ret = vdupq_n_s32(0);
+
+                ret = vdotq_laneq_s32(ret, b0 << 4, a0, 0);
+                ret = vdotq_laneq_s32(ret, b1 << 4, a0, 1);
+                ret = vdotq_laneq_s32(ret, b2 << 4, a0, 2);
+                ret = vdotq_laneq_s32(ret, b3 << 4, a0, 3);
+
+                ret = vdotq_laneq_s32(ret, b0 & 0xf0U, a1, 0);
+                ret = vdotq_laneq_s32(ret, b1 & 0xf0U, a1, 1);
+                ret = vdotq_laneq_s32(ret, b2 & 0xf0U, a1, 2);
+                ret = vdotq_laneq_s32(ret, b3 & 0xf0U, a1, 3);
+
+                acc = vfmaq_f32(acc, vcvtq_n_f32_s32(ret, 4),
+                                vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
+                a_ptr++;
+                b_ptr++;
+            }
+            vst1q_f32(s, acc);
+            s += ncols_interleaved;
+        }
+        return;
+    }
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    float sumf[4];
+    int sumi;
+
+    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
+
+        for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
+        for (int l = 0; l < nb; l++) {
+            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumi = 0;
+                    for (int i = 0; i < blocklen; ++i) {
+                        const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
+                        const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
+                        sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
+                    }
+                    sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d);
+                }
+            }
+        }
+        for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
+    }
+}
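+
+// Explanatory note (editor's addition, not part of the vendored source): in the
+// kernels above each q4_0 byte holds two two's-complement signed nibbles that
+// are consumed at 16x scale: (int8_t)(b << 4) equals 16 * (signed low nibble)
+// and (int8_t)(b & 0xF0) equals 16 * (signed high nibble). For example, for a
+// byte 0x2D: (int8_t)0xD0 == -48 == 16 * -3, and (int8_t)0x20 == 32 == 16 * 2.
+// The trailing ">> 4" in the scalar loop (and vcvtq_n_f32_s32(ret, 4) in the
+// NEON path) divides that common factor of 16 back out.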
+
+void ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 8;
+
+    assert (n % qk == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+    if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
+        const void * b_ptr = vx;
+        const void * a_ptr = vy;
+        float * res_ptr = s;
+
+        __asm__ __volatile__(
+            "movi v2.16b, #0x4\n"
+            "movi v1.16b, #0xf0\n"
+            "add %x[b_ptr], %x[b_ptr], #0x8\n"
+            "1:"  // Column loop
+            "add x23, %x[a_ptr], #0x2\n"
+            "movi v0.16b, #0x0\n"
+            "mov x22, %x[nb]\n"
+            "2:"  // Block loop
+            "ldr q31, [%x[b_ptr], #0x0]\n"
+            "ldr q30, [%x[b_ptr], #0x10]\n"
+            "mov x21, x23\n"
+            "movi v29.4s, #0x0\n"
+            "ldr q28, [%x[b_ptr], #0x20]\n"
+            "ldr q27, [%x[b_ptr], #0x30]\n"
+            "movi v26.4s, #0x0\n"
+            "sub x20, x23, #0x2\n"
+            "ld1r { v25.8h }, [x20]\n"
+            "ldr q24, [%x[b_ptr], #-0x8]\n"
+            "sub x22, x22, #0x1\n"
+            "add x23, x23, #0x22\n"
+            "ld1r { v23.2d }, [x21], #0x8\n"
+            "sshl v22.16b, v31.16b, v2.16b\n"
+            "sshl v16.16b, v30.16b, v2.16b\n"
+            "add %x[b_ptr], %x[b_ptr], #0x48\n"
+            "ld1r { v21.2d }, [x21], #0x8\n"
+            "sshl v20.16b, v28.16b, v2.16b\n"
+            "sshl v19.16b, v27.16b, v2.16b\n"
+            "ld1r { v18.2d }, [x21], #0x8\n"
+            "ld1r { v17.2d }, [x21], #0x8\n"
+            "and v31.16b, v31.16b, v1.16b\n"
+            "and v30.16b, v30.16b, v1.16b\n"
+            ".inst 0x4e9796dd  // sdot v29.4s, v22.16b, v23.16b\n"
+            ".inst 0x4e97961a  // sdot v26.4s, v16.16b, v23.16b\n"
+            "and v28.16b, v28.16b, v1.16b\n"
+            "and v27.16b, v27.16b, v1.16b\n"
+            "fcvtl v25.4s, v25.4h\n"
+            "fcvtl v16.4s, v24.4h\n"
+            ".inst 0x4e95969d  // sdot v29.4s, v20.16b, v21.16b\n"
+            ".inst 0x4e95967a  // sdot v26.4s, v19.16b, v21.16b\n"
+            "fmul v16.4s, v16.4s, v25.4s\n"
+            ".inst 0x4e9297fd  // sdot v29.4s, v31.16b, v18.16b\n"
+            ".inst 0x4e9297da  // sdot v26.4s, v30.16b, v18.16b\n"
+            ".inst 0x4e91979d  // sdot v29.4s, v28.16b, v17.16b\n"
+            ".inst 0x4e91977a  // sdot v26.4s, v27.16b, v17.16b\n"
+            "addp v29.4s, v29.4s, v26.4s\n"
+            "scvtf v29.4s, v29.4s, #0x4\n"
+            "fmla v0.4s, v29.4s, v16.4s\n"
+            "cbnz x22, 2b\n"
+            "sub %x[nc], %x[nc], #0x4\n"
+            "str q0, [%x[res_ptr], #0x0]\n"
+            "add %x[res_ptr], %x[res_ptr], #0x10\n"
+            "cbnz %x[nc], 1b\n"
+            : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc)
+            : [a_ptr] "r" (a_ptr), [nb] "r" (nb)
+            : "memory", "v0", "v1", "v2", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23"
+        );
+        return;
+    }
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+    float sumf[4];
+    int sumi;
+
+    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
+
+        for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
+        for (int l = 0; l < nb; l++) {
+            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumi = 0;
+                    for (int i = 0; i < blocklen; ++i) {
+                        const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
+                        const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
+                        sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
+                    }
+                    sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d);
+                }
+            }
+        }
+        for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
+    }
+}
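+
+// Illustrative note (editor's reading of the kernel, not part of the vendored
+// source): shape-wise this computes a single output row, s[0..nc-1], four
+// columns (one block_q4_0x4) at a time; with n == 256 there are
+// nb == 256 / QK8_0 == 8 blocks per column group. In the assembly path,
+// "scvtf v29.4s, v29.4s, #0x4" is the fixed-point counterpart of the scalar
+// ">> 4", and "addp v29.4s, v29.4s, v26.4s" folds the two partial accumulators
+// into the four per-column sums.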
+
+void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 8;
+    const int blocklen = 8;
+
+    assert (n % qk == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
+#if defined(__ARM_FEATURE_SVE)
+    if (ggml_cpu_has_sve() && ggml_cpu_get_sve_cnt() == QK8_0) {
+        const void * b_ptr = vx;
+        const void * a_ptr = vy;
+        float * res_ptr = s;
+
+        __asm__ __volatile__(
+            "ptrue p0.b\n"
+            "add %x[b_ptr], %x[b_ptr], #0x10\n"
+            "1:"  // Column loop
+            "add x22, %x[a_ptr], #0x2\n"
+            "mov z31.b, #0x0\n"
+            "mov x21, %x[nb]\n"
+            "2:"  // Block loop
+            "ld1b { z30.b }, p0/Z, [%x[b_ptr]]\n"
+            "ld1b { z29.b }, p0/Z, [%x[b_ptr], #1, MUL VL]\n"
+            "mov z28.s, #0x0\n"
+            "mov z27.s, #0x0\n"
+            "ld1rd { z26.d }, p0/Z, [x22]\n"
+            "ld1b { z25.b }, p0/Z, [%x[b_ptr], #2, MUL VL]\n"
+            "sub x20, x22, #0x2\n"
+            "sub x21, x21, #0x1\n"
+            "ld1b { z24.b }, p0/Z, [%x[b_ptr], #3, MUL VL]\n"
+            "ld1rd { z23.d }, p0/Z, [x22, #8]\n"
+            "lsl z22.b, z30.b, #0x4\n"
+            "lsl z16.b, z29.b, #0x4\n"
+            "and z30.b, z30.b, #0xf0\n"
+            "and z29.b, z29.b, #0xf0\n"
+            "ld1rd { z21.d }, p0/Z, [x22, #16]\n"
+            "ld1rd { z20.d }, p0/Z, [x22, #24]\n"
+            "lsl z19.b, z25.b, #0x4\n"
+            "and z25.b, z25.b, #0xf0\n"
+            "ld1rh { z17.h }, p0/Z, [x20]\n"
+            "ld1h { z18.s }, p0/Z, [%x[b_ptr], #-1, MUL VL]\n"
+            "sdot z28.s, z22.b, z26.b\n"
+            "sdot z27.s, z16.b, z26.b\n"
+            "lsl z16.b, z24.b, #0x4\n"
+            "add x22, x22, #0x22\n"
+            "and z24.b, z24.b, #0xf0\n"
+            "add %x[b_ptr], %x[b_ptr], #0x90\n"
+            "fcvt z17.s, p0/m, z17.h\n"
+            "fcvt z18.s, p0/m, z18.h\n"
+            "sdot z28.s, z19.b, z23.b\n"
+            "sdot z27.s, z16.b, z23.b\n"
+            "fmul z18.s, z18.s, z17.s\n"
+            "sdot z28.s, z30.b, z21.b\n"
+            "sdot z27.s, z29.b, z21.b\n"
+            "sdot z28.s, z25.b, z20.b\n"
+            "sdot z27.s, z24.b, z20.b\n"
+            "uzp1 z17.s, z28.s, z27.s\n"
+            "uzp2 z16.s, z28.s, z27.s\n"
+            "add z17.s, z17.s, z16.s\n"
+            "asr z17.s, z17.s, #0x4\n"
+            "scvtf z17.s, p0/m, z17.s\n"
+            "fmla z31.s, p0/M, z17.s, z18.s\n"
+            "cbnz x21, 2b\n"
+            "sub %x[nc], %x[nc], #0x8\n"
+            "st1w { z31.s }, p0, [%x[res_ptr]]\n"
+            "add %x[res_ptr], %x[res_ptr], #0x20\n"
+            "cbnz %x[nc], 1b\n"
+            : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc)
+            : [a_ptr] "r" (a_ptr), [nb] "r" (nb)
+            : "memory", "p0", "x20", "x21", "x22", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"
+        );
+        return;
+    }
+#endif // #if defined(__ARM_FEATURE_SVE)
+#elif defined(__AVX2__)
+    // Lookup table to convert signed nibbles to signed bytes
+    __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));
+    signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
+    // Shuffle mask to rearrange the interleaved FP16 scales when loading them
+    __m128i changemask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0);
+    // Permute mask used to restore the required output order after accumulation
+    __m256i finalpermutemask = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0);
+
+    // Mask to extract the low 4 bits (nibble) of each byte
+    const __m256i m4b = _mm256_set1_epi8(0x0F);
+
+    int64_t b_nb = n / QK4_0;
+
+    const block_q4_0x8 * b_ptr_start = (const block_q4_0x8 *)vx;
+    const block_q8_0 * a_ptr_start = (const block_q8_0 *)vy;
+
+    // Process the LHS rows (each a sequence of block_q8_0) one by one
+    for (int64_t y = 0; y < nr; y++) {
+
+        // Pointers to LHS blocks of block_q8_0 format
+        const block_q8_0 * a_ptr = a_ptr_start + (y * nb);
+
+        // Take one block_q4_0x8 (a group of eight interleaved q4_0 columns) at each pass of the loop and perform the dot product
+        for (int64_t x = 0; x < nc / 8; x++) {
+
+            // Pointers to RHS blocks
+            const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb);
+
+            // Master FP accumulator
+            __m256 acc_row = _mm256_setzero_ps();
+
+            for (int64_t b = 0; b < nb; b++) {
+                // Load 8 blocks of Q4_0 interleaved as 8 bytes (B0 - B7)
+                const __m256i rhs_raw_vec_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs));
+                const __m256i rhs_raw_vec_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs) + 1);
+                const __m256i rhs_raw_vec_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs) + 2);
+                const __m256i rhs_raw_vec_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs) + 3);
+
+                // 4-bit -> 8-bit; sign is preserved
+                const __m256i rhs_vec_0123_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_vec_0123_0, m4b)); // B0(0-7) B1(0-7) B2(0-7) B3(0-7)
+                const __m256i rhs_vec_4567_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_vec_4567_0, m4b)); // B4(0-7) B5(0-7) B6(0-7) B7(0-7)
+                const __m256i rhs_vec_0123_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_vec_0123_1, m4b)); // B0(8-15) B1(8-15) B2(8-15) B3(8-15)
+                const __m256i rhs_vec_4567_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_vec_4567_1, m4b)); // B4(8-15) B5(8-15) B6(8-15) B7(8-15)
+
+                const __m256i rhs_vec_0123_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_0, 4), m4b)); // B0(16-23) B1(16-23) B2(16-23) B3(16-23)
+                const __m256i rhs_vec_4567_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_0, 4), m4b)); // B4(16-23) B5(16-23) B6(16-23) B7(16-23)
+                const __m256i rhs_vec_0123_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 4), m4b)); // B0(24-31) B1(24-31) B2(24-31) B3(24-31)
+                const __m256i rhs_vec_4567_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 4), m4b)); // B4(24-31) B5(24-31) B6(24-31) B7(24-31)
+
+                // Load the scale values for the 8 blocks interleaved in block_q4_0x8
+                const __m256 col_scale_f32 = GGML_F32Cx8_REARRANGE_LOAD(b_ptr[b].d, changemask);
+
+                // Load and convert to FP32 scale from block_q8_0
+                const __m256 row_scale_f32 = _mm256_set1_ps(GGML_FP16_TO_FP32(a_ptr[b].d));
+
+                // Load the block_q8_0 values in batches of 16 bytes and replicate each batch across the 256-bit vector
+                __m256i lhs_vec_0 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)a_ptr[b].qs));
+                __m256i lhs_vec_1 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 16)));
+
+                lhs_vec_0 = _mm256_permute2f128_si256(lhs_vec_0, lhs_vec_0, 0); // A0(0-15) A0(0-15)
+                lhs_vec_1 = _mm256_permute2f128_si256(lhs_vec_1, lhs_vec_1, 0); // A0(16-31) A0(16-31)
+
+                __m256i iacc = _mm256_setzero_si256();
+
+                // Dot product done within 32 bit lanes and accumulated in the same vector
+                // B0(0-3) B4(0-3) B1(0-3) B5(0-3) B2(0-3) B6(0-3) B3(0-3) B7(0-3) with A0(0-3)
+                // B0(4-7) B4(4-7) B1(4-7) B5(4-7) B2(4-7) B6(4-7) B3(4-7) B7(4-7) with A0(4-7)
+                // ...........................................................................
+                // B0(28-31) B4(28-31) B1(28-31) B5(28-31) B2(28-31) B6(28-31) B3(28-31) B7(28-31) with A0(28-31)
+
+                iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_0 ,_mm256_shuffle_epi32(rhs_vec_4567_0, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 0)));
+                iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_0, 177) ,rhs_vec_4567_0, 170), _mm256_shuffle_epi32(lhs_vec_0, 85)));
+
+                iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_1 ,_mm256_shuffle_epi32(rhs_vec_4567_1, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 170)));
+                iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_1, 177) ,rhs_vec_4567_1, 170), _mm256_shuffle_epi32(lhs_vec_0, 255)));
+
+                iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_2 ,_mm256_shuffle_epi32(rhs_vec_4567_2, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 0)));
+                iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_2, 177) ,rhs_vec_4567_2, 170), _mm256_shuffle_epi32(lhs_vec_1, 85)));
+
+                iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_3 ,_mm256_shuffle_epi32(rhs_vec_4567_3, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 170)));
+                iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_3, 177) ,rhs_vec_4567_3, 170), _mm256_shuffle_epi32(lhs_vec_1, 255)));
+
+                // Accumulated values multiplied with the appropriate scales
+                acc_row = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc), _mm256_mul_ps(col_scale_f32, row_scale_f32), acc_row);
+            }
+
+            // Permute the accumulated output values so they are stored in the required order
+            acc_row = _mm256_permutevar8x32_ps(acc_row, finalpermutemask);
+            _mm256_storeu_ps(s + (y * nr + x * 8), acc_row);
+        }
+    }
+    return;
+#elif defined(__riscv_v_intrinsic)
+    if (__riscv_vlenb() >= QK4_0) {
+        const size_t vl = QK4_0;
+
+        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
+
+            vfloat32m1_t sumf = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
+            for (int l = 0; l < nb; l++) {
+                const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[0];
+                const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[8];
+                const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[16];
+                const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[24];
+                __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
+                const vint8m2_t lhs_0_8 = __riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a0, vl / 4));
+                const vint8m2_t lhs_1_8 = __riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a1, vl / 4));
+                const vint8m2_t lhs_2_8 = __riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a2, vl / 4));
+                const vint8m2_t lhs_3_8 = __riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a3, vl / 4));
+
+                const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4);
+                const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4);
+                const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4);
+                const vint8m2_t rhs_vec_lo_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 0);
+                const vint8m2_t rhs_vec_lo_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 1);
+                const vint8m2_t rhs_vec_hi_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 0);
+                const vint8m2_t rhs_vec_hi_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 1);
+
+                const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
+                const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
+                const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
+                const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
+
+                const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_hi_m));
+                const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
+                const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
+                const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
+                const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
+                const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
+                const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
+                const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
+                const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
+                const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
+                const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
+                const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
+                const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
+
+                // convert the FP16 scales on the scalar side; a vectorized conversion would need the Zvfhmin extension
+                const float a_scale = GGML_FP16_TO_FP32(a_ptr[l].d);
+                const float b_scales[8] = {
+                    GGML_FP16_TO_FP32(b_ptr[l].d[0]),
+                    GGML_FP16_TO_FP32(b_ptr[l].d[1]),
+                    GGML_FP16_TO_FP32(b_ptr[l].d[2]),
+                    GGML_FP16_TO_FP32(b_ptr[l].d[3]),
+                    GGML_FP16_TO_FP32(b_ptr[l].d[4]),
+                    GGML_FP16_TO_FP32(b_ptr[l].d[5]),
+                    GGML_FP16_TO_FP32(b_ptr[l].d[6]),
+                    GGML_FP16_TO_FP32(b_ptr[l].d[7])
+                };
+                const vfloat32m1_t b_scales_vec = __riscv_vle32_v_f32m1(b_scales, vl / 4);
+                const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scale, vl / 4);
+                sumf = __riscv_vfmacc_vv_f32m1(sumf, tmp1, b_scales_vec, vl / 4);
+            }
+            __riscv_vse32_v_f32m1(s + x * ncols_interleaved, sumf, vl / 4);
+        }
+        return;
+    }
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
+    {
+        float sumf[8];
+        int sumi;
+
+        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
+
+            for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
+            for (int l = 0; l < nb; l++) {
+                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                    for (int j = 0; j < ncols_interleaved; j++) {
+                        sumi = 0;
+                        for (int i = 0; i < blocklen; ++i) {
+                            const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
+                            const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
+                            sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
+                        }
+                        sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d);
+                    }
+                }
+            }
+            for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
+        }
+    }
+}
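+
+// Summary note (editor's addition, not part of the vendored source): dispatch
+// order in ggml_gemv_q4_0_8x8_q8_0 is a 256-bit SVE path (requires
+// ggml_cpu_get_sve_cnt() == QK8_0, i.e. a 32-byte vector length), an AVX2 path,
+// an RVV path gated on __riscv_vlenb() >= QK4_0, and finally the generic scalar
+// loop. All paths compute the same quantity per block:
+// s[x * 8 + j] += sumi * d_b[j] * d_a, where sumi is the signed-nibble-by-int8
+// dot product.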
+
+void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert (n % qk == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
+        const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl);
+        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+        float * res_ptr = s;
+
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
+
+            float32x4_t sumf = vdupq_n_f32(0);
+            for (int l = 0; l < nb; l++) {
+                uint8x16_t b_0 = vld1q_u8(b_ptr[l].qs + 0);
+                uint8x16_t b_1 = vld1q_u8(b_ptr[l].qs + 16);
+                uint8x16_t b_2 = vld1q_u8(b_ptr[l].qs + 32);
+                uint8x16_t b_3 = vld1q_u8(b_ptr[l].qs + 48);
+
+                int8x16_t b_0_hi = vqtbl1q_s8(kvalues, b_0 >> 4);
+                int8x16_t b_0_lo = vqtbl1q_s8(kvalues, b_0 & 0x0F);
+                int8x16_t b_1_hi = vqtbl1q_s8(kvalues, b_1 >> 4);
+                int8x16_t b_1_lo = vqtbl1q_s8(kvalues, b_1 & 0x0F);
+                int8x16_t b_2_hi = vqtbl1q_s8(kvalues, b_2 >> 4);
+                int8x16_t b_2_lo = vqtbl1q_s8(kvalues, b_2 & 0x0F);
+                int8x16_t b_3_hi = vqtbl1q_s8(kvalues, b_3 >> 4);
+                int8x16_t b_3_lo = vqtbl1q_s8(kvalues, b_3 & 0x0F);
+
+                int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 0);
+                int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16);
+
+                int32x4_t sumi = vdupq_n_s32(0);
+                sumi = vdotq_laneq_s32(sumi, b_0_lo, a_0, 0);
+                sumi = vdotq_laneq_s32(sumi, b_0_hi, a_1, 0);
+                sumi = vdotq_laneq_s32(sumi, b_1_lo, a_0, 1);
+                sumi = vdotq_laneq_s32(sumi, b_1_hi, a_1, 1);
+                sumi = vdotq_laneq_s32(sumi, b_2_lo, a_0, 2);
+                sumi = vdotq_laneq_s32(sumi, b_2_hi, a_1, 2);
+                sumi = vdotq_laneq_s32(sumi, b_3_lo, a_0, 3);
+                sumi = vdotq_laneq_s32(sumi, b_3_hi, a_1, 3);
+
+                float32x4_t a_d = vcvt_f32_f16(vld1_dup_f16((const float16_t *)&a_ptr[l].d));
+                float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
+                float32x4_t d = a_d * b_d;
+
+                sumf = vmlaq_f32(sumf, d, vcvtq_f32_s32(sumi));
+            }
+
+            vst1q_f32(res_ptr + x * 4, sumf);
+        }
+        return;
+    }
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
+    {
+        float sumf[4];
+        int sumi;
+
+        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
+
+            for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
+            for (int l = 0; l < nb; l++) {
+                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                    for (int j = 0; j < ncols_interleaved; j++) {
+                        sumi = 0;
+                        for (int i = 0; i < blocklen; ++i) {
+                            const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
+                            const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
+                            sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2]));
+                        }
+                        sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d);
+                    }
+                }
+            }
+            for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
+        }
+    }
+}
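+
+// Explanatory note (editor's addition, not part of the vendored source):
+// iq4_nl replaces the linear q4_0 mapping with a 16-entry non-linear codebook.
+// Each nibble indexes kvalues_iq4nl to produce an int8 value, which is why the
+// scalar loop above has no ">> 4" scale compensation and the NEON path uses
+// vqtbl1q_s8 as the table lookup. For example, a stored byte 0x1F contributes
+// kvalues_iq4nl[0xF] (low nibble) and kvalues_iq4nl[0x1] (high nibble).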
+
+void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert (n % qk == 0);
+    assert (nr % 4 == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
+    if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
+        const void * b_ptr = vx;
+        const void * a_ptr = vy;
+        float * res_ptr = s;
+        size_t res_stride = bs * sizeof(float);
+
+        __asm__ __volatile__(
+            "mov x10, %x[nr]\n"
+            "mov x9, #0x88\n"
+            "cmp x10, #0x10\n"
+            "mul x9, %x[nb], x9\n"
+            "blt 4f\n"
+            "1:"  // Row loop
+            "add x28, %x[b_ptr], #0x8\n"
+            "mov x27, %x[nc]\n"
+            "add x26, %x[res_ptr], %x[res_stride], LSL #4\n"
+            "2:"  // Column loop
+            "add x25, %x[a_ptr], #0x8\n"
+            "movi v15.16b, #0x0\n"
+            "movi v19.16b, #0x0\n"
+            "mov x24, %x[nb]\n"
+            "add x23, x25, x9\n"
+            "movi v18.16b, #0x0\n"
+            "movi v14.16b, #0x0\n"
+            "add x22, x23, x9\n"
+            "movi v11.16b, #0x0\n"
+            "movi v13.16b, #0x0\n"
+            "add x21, x22, x9\n"
+            "movi v23.16b, #0x0\n"
+            "movi v16.16b, #0x0\n"
+            "movi v25.16b, #0x0\n"
+            "movi v7.16b, #0x0\n"
+            "movi v0.16b, #0x0\n"
+            "movi v4.16b, #0x0\n"
+            "movi v5.16b, #0x0\n"
+            "movi v21.16b, #0x0\n"
+            "movi v8.16b, #0x0\n"
+            "movi v1.16b, #0x0\n"
+            "3:"  // Block loop
+            "ldr q3, [x28, #0x0]\n"
+            "ldr q31, [x25, #0x0]\n"
+            "movi v28.16b, #0x4\n"
+            "movi v10.4s, #0x0\n"
+            "ldr q22, [x28, #0x10]\n"
+            "ldr q6, [x25, #0x10]\n"
+            "movi v29.4s, #0x0\n"
+            "movi v9.4s, #0x0\n"
+            "ldr q27, [x28, #0x20]\n"
+            "ldr q30, [x28, #0x30]\n"
+            "movi v20.4s, #0x0\n"
+            "movi v24.16b, #0xf0\n"
+            "ldr d2, [x25, #-0x8]\n"
+            "ldr d26, [x23, #-0x8]\n"
+            "sshl v12.16b, v3.16b, v28.16b\n"
+            "sub x20, x28, #0x8\n"
+            "ldr d17, [x20, #0x0]\n"
+            "and v3.16b, v3.16b, v24.16b\n"
+            "subs x24, x24, #0x1\n"
+            "add x28, x28, #0x48\n"
+            ".inst 0x4f9fe18a  // sdot v10.4s, v12.16b, v31.4b[0]\n"
+            ".inst 0x4fbfe19d  // sdot v29.4s, v12.16b, v31.4b[1]\n"
+            ".inst 0x4f9fe989  // sdot v9.4s, v12.16b, v31.4b[2]\n"
+            ".inst 0x4fbfe994  // sdot v20.4s, v12.16b, v31.4b[3]\n"
+            "sshl v31.16b, v22.16b, v28.16b\n"
+            "and v22.16b, v22.16b, v24.16b\n"
+            "fcvtl v17.4s, v17.4h\n"
+            "fcvtl v2.4s, v2.4h\n"
+            "fcvtl v26.4s, v26.4h\n"
+            ".inst 0x4f86e3ea  // sdot v10.4s, v31.16b, v6.4b[0]\n"
+            ".inst 0x4fa6e3fd  // sdot v29.4s, v31.16b, v6.4b[1]\n"
+            ".inst 0x4f86ebe9  // sdot v9.4s, v31.16b, v6.4b[2]\n"
+            ".inst 0x4fa6ebf4  // sdot v20.4s, v31.16b, v6.4b[3]\n"
+            "sshl v6.16b, v27.16b, v28.16b\n"
+            "sshl v28.16b, v30.16b, v28.16b\n"
+            "and v27.16b, v27.16b, v24.16b\n"
+            "and v30.16b, v30.16b, v24.16b\n"
+            "ldr q24, [x25, #0x20]\n"
+            ".inst 0x4f98e0ca  // sdot v10.4s, v6.16b, v24.4b[0]\n"
+            ".inst 0x4fb8e0dd  // sdot v29.4s, v6.16b, v24.4b[1]\n"
+            ".inst 0x4f98e8c9  // sdot v9.4s, v6.16b, v24.4b[2]\n"
+            ".inst 0x4fb8e8d4  // sdot v20.4s, v6.16b, v24.4b[3]\n"
+            "ldr q24, [x25, #0x30]\n"
+            ".inst 0x4f98e38a  // sdot v10.4s, v28.16b, v24.4b[0]\n"
+            ".inst 0x4fb8e39d  // sdot v29.4s, v28.16b, v24.4b[1]\n"
+            ".inst 0x4f98eb89  // sdot v9.4s, v28.16b, v24.4b[2]\n"
+            ".inst 0x4fb8eb94  // sdot v20.4s, v28.16b, v24.4b[3]\n"
+            "ldr q24, [x25, #0x40]\n"
+            ".inst 0x4f98e06a  // sdot v10.4s, v3.16b, v24.4b[0]\n"
+            ".inst 0x4fb8e07d  // sdot v29.4s, v3.16b, v24.4b[1]\n"
+            ".inst 0x4f98e869  // sdot v9.4s, v3.16b, v24.4b[2]\n"
+            ".inst 0x4fb8e874  // sdot v20.4s, v3.16b, v24.4b[3]\n"
+            "ldr q24, [x25, #0x50]\n"
+            ".inst 0x4f98e2ca  // sdot v10.4s, v22.16b, v24.4b[0]\n"
+            ".inst 0x4fb8e2dd  // sdot v29.4s, v22.16b, v24.4b[1]\n"
+            ".inst 0x4f98eac9  // sdot v9.4s, v22.16b, v24.4b[2]\n"
+            ".inst 0x4fb8ead4  // sdot v20.4s, v22.16b, v24.4b[3]\n"
+            "ldr q24, [x25, #0x60]\n"
+            ".inst 0x4f98e36a  // sdot v10.4s, v27.16b, v24.4b[0]\n"
+            ".inst 0x4fb8e37d  // sdot v29.4s, v27.16b, v24.4b[1]\n"
+            ".inst 0x4f98eb69  // sdot v9.4s, v27.16b, v24.4b[2]\n"
+            ".inst 0x4fb8eb74  // sdot v20.4s, v27.16b, v24.4b[3]\n"
+            "ldr q24, [x25, #0x70]\n"
+            "add x25, x25, #0x88\n"
+            ".inst 0x4f98e3ca  // sdot v10.4s, v30.16b, v24.4b[0]\n"
+            ".inst 0x4fb8e3dd  // sdot v29.4s, v30.16b, v24.4b[1]\n"
+            ".inst 0x4f98ebc9  // sdot v9.4s, v30.16b, v24.4b[2]\n"
+            ".inst 0x4fb8ebd4  // sdot v20.4s, v30.16b, v24.4b[3]\n"
+            "fmul v24.4s, v17.4s, v2.s[0]\n"
+            "scvtf v10.4s, v10.4s, #0x4\n"
+            "scvtf v29.4s, v29.4s, #0x4\n"
+            "scvtf v9.4s, v9.4s, #0x4\n"
+            "scvtf v20.4s, v20.4s, #0x4\n"
+            "fmla v15.4s, v10.4s, v24.4s\n"
+            "ldr q24, [x23, #0x0]\n"
+            "fmul v10.4s, v17.4s, v2.s[1]\n"
+            "fmla v19.4s, v29.4s, v10.4s\n"
+            "ldr q10, [x23, #0x10]\n"
+            "fmul v29.4s, v17.4s, v2.s[2]\n"
+            "fmul v2.4s, v17.4s, v2.s[3]\n"
+            "fmla v18.4s, v9.4s, v29.4s\n"
+            "movi v9.4s, #0x0\n"
+            "movi v29.4s, #0x0\n"
+            ".inst 0x4f98e189  // sdot v9.4s, v12.16b, v24.4b[0]\n"
+            ".inst 0x4fb8e19d  // sdot v29.4s, v12.16b, v24.4b[1]\n"
+            "fmla v14.4s, v20.4s, v2.4s\n"
+            "movi v20.4s, #0x0\n"
+            "movi v2.4s, #0x0\n"
+            ".inst 0x4f98e994  // sdot v20.4s, v12.16b, v24.4b[2]\n"
+            ".inst 0x4fb8e982  // sdot v2.4s, v12.16b, v24.4b[3]\n"
+            "ldr q24, [x23, #0x20]\n"
+            ".inst 0x4f8ae3e9  // sdot v9.4s, v31.16b, v10.4b[0]\n"
+            ".inst 0x4faae3fd  // sdot v29.4s, v31.16b, v10.4b[1]\n"
+            ".inst 0x4f8aebf4  // sdot v20.4s, v31.16b, v10.4b[2]\n"
+            ".inst 0x4faaebe2  // sdot v2.4s, v31.16b, v10.4b[3]\n"
+            "ldr q10, [x23, #0x30]\n"
+            ".inst 0x4f98e0c9  // sdot v9.4s, v6.16b, v24.4b[0]\n"
+            ".inst 0x4fb8e0dd  // sdot v29.4s, v6.16b, v24.4b[1]\n"
+            ".inst 0x4f98e8d4  // sdot v20.4s, v6.16b, v24.4b[2]\n"
+            ".inst 0x4fb8e8c2  // sdot v2.4s, v6.16b, v24.4b[3]\n"
+            "ldr q24, [x23, #0x40]\n"
+            ".inst 0x4f8ae389  // sdot v9.4s, v28.16b, v10.4b[0]\n"
+            ".inst 0x4faae39d  // sdot v29.4s, v28.16b, v10.4b[1]\n"
+            ".inst 0x4f8aeb94  // sdot v20.4s, v28.16b, v10.4b[2]\n"
+            ".inst 0x4faaeb82  // sdot v2.4s, v28.16b, v10.4b[3]\n"
+            "ldr q10, [x23, #0x50]\n"
+            ".inst 0x4f98e069  // sdot v9.4s, v3.16b, v24.4b[0]\n"
+            ".inst 0x4fb8e07d  // sdot v29.4s, v3.16b, v24.4b[1]\n"
+            ".inst 0x4f98e874  // sdot v20.4s, v3.16b, v24.4b[2]\n"
+            ".inst 0x4fb8e862  // sdot v2.4s, v3.16b, v24.4b[3]\n"
+            "ldr q24, [x23, #0x60]\n"
+            ".inst 0x4f8ae2c9  // sdot v9.4s, v22.16b, v10.4b[0]\n"
+            ".inst 0x4faae2dd  // sdot v29.4s, v22.16b, v10.4b[1]\n"
+            ".inst 0x4f8aead4  // sdot v20.4s, v22.16b, v10.4b[2]\n"
+            ".inst 0x4faaeac2  // sdot v2.4s, v22.16b, v10.4b[3]\n"
+            "ldr q10, [x23, #0x70]\n"
+            "add x23, x23, #0x88\n"
+            ".inst 0x4f98e369  // sdot v9.4s, v27.16b, v24.4b[0]\n"
+            ".inst 0x4fb8e37d  // sdot v29.4s, v27.16b, v24.4b[1]\n"
+            ".inst 0x4f98eb74  // sdot v20.4s, v27.16b, v24.4b[2]\n"
+            ".inst 0x4fb8eb62  // sdot v2.4s, v27.16b, v24.4b[3]\n"
+            "ldr q24, [x22, #0x0]\n"
+            ".inst 0x4f8ae3c9  // sdot v9.4s, v30.16b, v10.4b[0]\n"
+            ".inst 0x4faae3dd  // sdot v29.4s, v30.16b, v10.4b[1]\n"
+            ".inst 0x4f8aebd4  // sdot v20.4s, v30.16b, v10.4b[2]\n"
+            ".inst 0x4faaebc2  // sdot v2.4s, v30.16b, v10.4b[3]\n"
+            "fmul v10.4s, v17.4s, v26.s[0]\n"
+            "scvtf v9.4s, v9.4s, #0x4\n"
+            "scvtf v29.4s, v29.4s, #0x4\n"
+            "scvtf v20.4s, v20.4s, #0x4\n"
+            "scvtf v2.4s, v2.4s, #0x4\n"
+            "fmla v11.4s, v9.4s, v10.4s\n"
+            "ldr q9, [x22, #0x10]\n"
+            "fmul v10.4s, v17.4s, v26.s[1]\n"
+            "fmla v13.4s, v29.4s, v10.4s\n"
+            "ldr d29, [x22, #-0x8]\n"
+            "fmul v10.4s, v17.4s, v26.s[2]\n"
+            "fmul v26.4s, v17.4s, v26.s[3]\n"
+            "fcvtl v29.4s, v29.4h\n"
+            "fmla v23.4s, v20.4s, v10.4s\n"
+            "movi v20.4s, #0x0\n"
+            "movi v10.4s, #0x0\n"
+            "fmla v16.4s, v2.4s, v26.4s\n"
+            "movi v26.4s, #0x0\n"
+            "movi v2.4s, #0x0\n"
+            ".inst 0x4f98e194  // sdot v20.4s, v12.16b, v24.4b[0]\n"
+            ".inst 0x4fb8e18a  // sdot v10.4s, v12.16b, v24.4b[1]\n"
+            ".inst 0x4f98e99a  // sdot v26.4s, v12.16b, v24.4b[2]\n"
+            ".inst 0x4fb8e982  // sdot v2.4s, v12.16b, v24.4b[3]\n"
+            "ldr q24, [x22, #0x20]\n"
+            ".inst 0x4f89e3f4  // sdot v20.4s, v31.16b, v9.4b[0]\n"
+            ".inst 0x4fa9e3ea  // sdot v10.4s, v31.16b, v9.4b[1]\n"
+            ".inst 0x4f89ebfa  // sdot v26.4s, v31.16b, v9.4b[2]\n"
+            ".inst 0x4fa9ebe2  // sdot v2.4s, v31.16b, v9.4b[3]\n"
+            "ldr q9, [x22, #0x30]\n"
+            ".inst 0x4f98e0d4  // sdot v20.4s, v6.16b, v24.4b[0]\n"
+            ".inst 0x4fb8e0ca  // sdot v10.4s, v6.16b, v24.4b[1]\n"
+            ".inst 0x4f98e8da  // sdot v26.4s, v6.16b, v24.4b[2]\n"
+            ".inst 0x4fb8e8c2  // sdot v2.4s, v6.16b, v24.4b[3]\n"
+            "ldr q24, [x22, #0x40]\n"
+            ".inst 0x4f89e394  // sdot v20.4s, v28.16b, v9.4b[0]\n"
+            ".inst 0x4fa9e38a  // sdot v10.4s, v28.16b, v9.4b[1]\n"
+            ".inst 0x4f89eb9a  // sdot v26.4s, v28.16b, v9.4b[2]\n"
+            ".inst 0x4fa9eb82  // sdot v2.4s, v28.16b, v9.4b[3]\n"
+            "ldr q9, [x22, #0x50]\n"
+            ".inst 0x4f98e074  // sdot v20.4s, v3.16b, v24.4b[0]\n"
+            ".inst 0x4fb8e06a  // sdot v10.4s, v3.16b, v24.4b[1]\n"
+            ".inst 0x4f98e87a  // sdot v26.4s, v3.16b, v24.4b[2]\n"
+            ".inst 0x4fb8e862  // sdot v2.4s, v3.16b, v24.4b[3]\n"
+            "ldr q24, [x22, #0x60]\n"
+            ".inst 0x4f89e2d4  // sdot v20.4s, v22.16b, v9.4b[0]\n"
+            ".inst 0x4fa9e2ca  // sdot v10.4s, v22.16b, v9.4b[1]\n"
+            ".inst 0x4f89eada  // sdot v26.4s, v22.16b, v9.4b[2]\n"
+            ".inst 0x4fa9eac2  // sdot v2.4s, v22.16b, v9.4b[3]\n"
+            "ldr q9, [x22, #0x70]\n"
+            "add x22, x22, #0x88\n"
+            ".inst 0x4f98e374  // sdot v20.4s, v27.16b, v24.4b[0]\n"
+            ".inst 0x4fb8e36a  // sdot v10.4s, v27.16b, v24.4b[1]\n"
+            ".inst 0x4f98eb7a  // sdot v26.4s, v27.16b, v24.4b[2]\n"
+            ".inst 0x4fb8eb62  // sdot v2.4s, v27.16b, v24.4b[3]\n"
+            "ldr q24, [x21, #0x0]\n"
+            ".inst 0x4f89e3d4  // sdot v20.4s, v30.16b, v9.4b[0]\n"
+            ".inst 0x4fa9e3ca  // sdot v10.4s, v30.16b, v9.4b[1]\n"
+            ".inst 0x4f89ebda  // sdot v26.4s, v30.16b, v9.4b[2]\n"
+            ".inst 0x4fa9ebc2  // sdot v2.4s, v30.16b, v9.4b[3]\n"
+            "fmul v9.4s, v17.4s, v29.s[0]\n"
+            "scvtf v20.4s, v20.4s, #0x4\n"
+            "scvtf v10.4s, v10.4s, #0x4\n"
+            "scvtf v26.4s, v26.4s, #0x4\n"
+            "scvtf v2.4s, v2.4s, #0x4\n"
+            "fmla v25.4s, v20.4s, v9.4s\n"
+            "ldr q9, [x21, #0x10]\n"
+            "fmul v20.4s, v17.4s, v29.s[1]\n"
+            "fmla v7.4s, v10.4s, v20.4s\n"
+            "ldr d20, [x21, #-0x8]\n"
+            "fmul v10.4s, v17.4s, v29.s[2]\n"
+            "fmul v29.4s, v17.4s, v29.s[3]\n"
+            "fcvtl v20.4s, v20.4h\n"
+            "fmla v0.4s, v26.4s, v10.4s\n"
+            "movi v26.4s, #0x0\n"
+            "movi v10.4s, #0x0\n"
+            "fmla v4.4s, v2.4s, v29.4s\n"
+            "movi v2.4s, #0x0\n"
+            "movi v29.4s, #0x0\n"
+            ".inst 0x4f98e19a  // sdot v26.4s, v12.16b, v24.4b[0]\n"
+            ".inst 0x4fb8e18a  // sdot v10.4s, v12.16b, v24.4b[1]\n"
+            ".inst 0x4f98e982  // sdot v2.4s, v12.16b, v24.4b[2]\n"
+            ".inst 0x4fb8e99d  // sdot v29.4s, v12.16b, v24.4b[3]\n"
+            "ldr q12, [x21, #0x20]\n"
+            "fmul v24.4s, v17.4s, v20.s[0]\n"
+            ".inst 0x4f89e3fa  // sdot v26.4s, v31.16b, v9.4b[0]\n"
+            ".inst 0x4fa9e3ea  // sdot v10.4s, v31.16b, v9.4b[1]\n"
+            ".inst 0x4f89ebe2  // sdot v2.4s, v31.16b, v9.4b[2]\n"
+            ".inst 0x4fa9ebfd  // sdot v29.4s, v31.16b, v9.4b[3]\n"
+            "ldr q9, [x21, #0x30]\n"
+            "fmul v31.4s, v17.4s, v20.s[1]\n"
+            ".inst 0x4f8ce0da  // sdot v26.4s, v6.16b, v12.4b[0]\n"
+            ".inst 0x4face0ca  // sdot v10.4s, v6.16b, v12.4b[1]\n"
+            ".inst 0x4f8ce8c2  // sdot v2.4s, v6.16b, v12.4b[2]\n"
+            ".inst 0x4face8dd  // sdot v29.4s, v6.16b, v12.4b[3]\n"
+            "ldr q12, [x21, #0x40]\n"
+            "fmul v6.4s, v17.4s, v20.s[2]\n"
+            "fmul v20.4s, v17.4s, v20.s[3]\n"
+            ".inst 0x4f89e39a  // sdot v26.4s, v28.16b, v9.4b[0]\n"
+            ".inst 0x4fa9e38a  // sdot v10.4s, v28.16b, v9.4b[1]\n"
+            ".inst 0x4f89eb82  // sdot v2.4s, v28.16b, v9.4b[2]\n"
+            ".inst 0x4fa9eb9d  // sdot v29.4s, v28.16b, v9.4b[3]\n"
+            "ldr q9, [x21, #0x50]\n"
+            ".inst 0x4f8ce07a  // sdot v26.4s, v3.16b, v12.4b[0]\n"
+            ".inst 0x4face06a  // sdot v10.4s, v3.16b, v12.4b[1]\n"
+            ".inst 0x4f8ce862  // sdot v2.4s, v3.16b, v12.4b[2]\n"
+            ".inst 0x4face87d  // sdot v29.4s, v3.16b, v12.4b[3]\n"
+            "ldr q12, [x21, #0x60]\n"
+            ".inst 0x4f89e2da  // sdot v26.4s, v22.16b, v9.4b[0]\n"
+            ".inst 0x4fa9e2ca  // sdot v10.4s, v22.16b, v9.4b[1]\n"
+            ".inst 0x4f89eac2  // sdot v2.4s, v22.16b, v9.4b[2]\n"
+            ".inst 0x4fa9eadd  // sdot v29.4s, v22.16b, v9.4b[3]\n"
+            "ldr q17, [x21, #0x70]\n"
+            "add x21, x21, #0x88\n"
+            ".inst 0x4f8ce37a  // sdot v26.4s, v27.16b, v12.4b[0]\n"
+            ".inst 0x4face36a  // sdot v10.4s, v27.16b, v12.4b[1]\n"
+            ".inst 0x4f8ceb62  // sdot v2.4s, v27.16b, v12.4b[2]\n"
+            ".inst 0x4faceb7d  // sdot v29.4s, v27.16b, v12.4b[3]\n"
+            ".inst 0x4f91e3da  // sdot v26.4s, v30.16b, v17.4b[0]\n"
+            ".inst 0x4fb1e3ca  // sdot v10.4s, v30.16b, v17.4b[1]\n"
+            ".inst 0x4f91ebc2  // sdot v2.4s, v30.16b, v17.4b[2]\n"
+            ".inst 0x4fb1ebdd  // sdot v29.4s, v30.16b, v17.4b[3]\n"
+            "scvtf v26.4s, v26.4s, #0x4\n"
+            "scvtf v10.4s, v10.4s, #0x4\n"
+            "fmla v5.4s, v26.4s, v24.4s\n"
+            "scvtf v2.4s, v2.4s, #0x4\n"
+            "scvtf v29.4s, v29.4s, #0x4\n"
+            "fmla v21.4s, v10.4s, v31.4s\n"
+            "fmla v8.4s, v2.4s, v6.4s\n"
+            "fmla v1.4s, v29.4s, v20.4s\n"
+            "bgt 3b\n"
+            "mov x20, %x[res_ptr]\n"
+            "subs x27, x27, #0x4\n"
+            "add %x[res_ptr], %x[res_ptr], #0x10\n"
+            "str q15, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q19, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q18, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q14, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q11, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q13, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q23, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q16, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q25, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q7, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q0, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q4, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q5, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q21, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q8, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q1, [x20, #0x0]\n"
+            "bne 2b\n"
+            "mov x20, #0x4\n"
+            "sub x10, x10, #0x10\n"
+            "cmp x10, #0x10\n"
+            "mov %x[res_ptr], x26\n"
+            "madd %x[a_ptr], x20, x9, %x[a_ptr]\n"
+            "bge 1b\n"
+            "4:"  // Row loop skip
+            "cbz x10, 9f\n"
+            "5:"  // Row tail: Row loop
+            "add x24, %x[b_ptr], #0x8\n"
+            "mov x23, %x[nc]\n"
+            "add x22, %x[res_ptr], %x[res_stride], LSL #2\n"
+            "6:"  // Row tail: Column loop
+            "movi v15.16b, #0x0\n"
+            "movi v19.16b, #0x0\n"
+            "add x25, %x[a_ptr], #0x8\n"
+            "mov x21, %x[nb]\n"
+            "movi v18.16b, #0x0\n"
+            "movi v14.16b, #0x0\n"
+            "7:"  // Row tail: Block loop
+            "ldr q7, [x24, #0x0]\n"
+            "ldr q5, [x25, #0x0]\n"
+            "movi v9.16b, #0x4\n"
+            "movi v4.4s, #0x0\n"
+            "ldr q3, [x24, #0x10]\n"
+            "ldr q2, [x25, #0x10]\n"
+            "movi v1.4s, #0x0\n"
+            "movi v0.4s, #0x0\n"
+            "ldr q13, [x24, #0x20]\n"
+            "ldr q31, [x25, #0x20]\n"
+            "movi v30.4s, #0x0\n"
+            "movi v29.16b, #0xf0\n"
+            "ldr q28, [x24, #0x30]\n"
+            "ldr q27, [x25, #0x30]\n"
+            "sshl v20.16b, v7.16b, v9.16b\n"
+            "sub x20, x24, #0x8\n"
+            "ldr q26, [x25, #0x40]\n"
+            "ldr q25, [x25, #0x50]\n"
+            "sshl v17.16b, v3.16b, v9.16b\n"
+            "and v7.16b, v7.16b, v29.16b\n"
+            "ldr q24, [x25, #0x60]\n"
+            "ldr q16, [x25, #0x70]\n"
+            "sshl v22.16b, v13.16b, v9.16b\n"
+            "and v3.16b, v3.16b, v29.16b\n"
+            "ldr d21, [x20, #0x0]\n"
+            "ldr d12, [x25, #-0x8]\n"
+            ".inst 0x4f85e284  // sdot v4.4s, v20.16b, v5.4b[0]\n"
+            ".inst 0x4fa5e281  // sdot v1.4s, v20.16b, v5.4b[1]\n"
+            ".inst 0x4f85ea80  // sdot v0.4s, v20.16b, v5.4b[2]\n"
+            ".inst 0x4fa5ea9e  // sdot v30.4s, v20.16b, v5.4b[3]\n"
+            "sshl v9.16b, v28.16b, v9.16b\n"
+            "subs x21, x21, #0x1\n"
+            "and v13.16b, v13.16b, v29.16b\n"
+            "and v28.16b, v28.16b, v29.16b\n"
+            "add x25, x25, #0x88\n"
+            "add x24, x24, #0x48\n"
+            "fcvtl v21.4s, v21.4h\n"
+            "fcvtl v12.4s, v12.4h\n"
+            ".inst 0x4f82e224  // sdot v4.4s, v17.16b, v2.4b[0]\n"
+            ".inst 0x4fa2e221  // sdot v1.4s, v17.16b, v2.4b[1]\n"
+            ".inst 0x4f82ea20  // sdot v0.4s, v17.16b, v2.4b[2]\n"
+            ".inst 0x4fa2ea3e  // sdot v30.4s, v17.16b, v2.4b[3]\n"
+            "fmul v11.4s, v21.4s, v12.s[0]\n"
+            "fmul v23.4s, v21.4s, v12.s[1]\n"
+            "fmul v17.4s, v21.4s, v12.s[2]\n"
+            ".inst 0x4f9fe2c4  // sdot v4.4s, v22.16b, v31.4b[0]\n"
+            "fmul v6.4s, v21.4s, v12.s[3]\n"
+            ".inst 0x4fbfe2c1  // sdot v1.4s, v22.16b, v31.4b[1]\n"
+            ".inst 0x4f9feac0  // sdot v0.4s, v22.16b, v31.4b[2]\n"
+            ".inst 0x4fbfeade  // sdot v30.4s, v22.16b, v31.4b[3]\n"
+            ".inst 0x4f9be124  // sdot v4.4s, v9.16b, v27.4b[0]\n"
+            ".inst 0x4fbbe121  // sdot v1.4s, v9.16b, v27.4b[1]\n"
+            ".inst 0x4f9be920  // sdot v0.4s, v9.16b, v27.4b[2]\n"
+            ".inst 0x4fbbe93e  // sdot v30.4s, v9.16b, v27.4b[3]\n"
+            ".inst 0x4f9ae0e4  // sdot v4.4s, v7.16b, v26.4b[0]\n"
+            ".inst 0x4fbae0e1  // sdot v1.4s, v7.16b, v26.4b[1]\n"
+            ".inst 0x4f9ae8e0  // sdot v0.4s, v7.16b, v26.4b[2]\n"
+            ".inst 0x4fbae8fe  // sdot v30.4s, v7.16b, v26.4b[3]\n"
+            ".inst 0x4f99e064  // sdot v4.4s, v3.16b, v25.4b[0]\n"
+            ".inst 0x4fb9e061  // sdot v1.4s, v3.16b, v25.4b[1]\n"
+            ".inst 0x4f99e860  // sdot v0.4s, v3.16b, v25.4b[2]\n"
+            ".inst 0x4fb9e87e  // sdot v30.4s, v3.16b, v25.4b[3]\n"
+            ".inst 0x4f98e1a4  // sdot v4.4s, v13.16b, v24.4b[0]\n"
+            ".inst 0x4fb8e1a1  // sdot v1.4s, v13.16b, v24.4b[1]\n"
+            ".inst 0x4f98e9a0  // sdot v0.4s, v13.16b, v24.4b[2]\n"
+            ".inst 0x4fb8e9be  // sdot v30.4s, v13.16b, v24.4b[3]\n"
+            ".inst 0x4f90e384  // sdot v4.4s, v28.16b, v16.4b[0]\n"
+            ".inst 0x4fb0e381  // sdot v1.4s, v28.16b, v16.4b[1]\n"
+            ".inst 0x4f90eb80  // sdot v0.4s, v28.16b, v16.4b[2]\n"
+            ".inst 0x4fb0eb9e  // sdot v30.4s, v28.16b, v16.4b[3]\n"
+            "scvtf v4.4s, v4.4s, #0x4\n"
+            "scvtf v1.4s, v1.4s, #0x4\n"
+            "scvtf v0.4s, v0.4s, #0x4\n"
+            "fmla v15.4s, v4.4s, v11.4s\n"
+            "scvtf v30.4s, v30.4s, #0x4\n"
+            "fmla v19.4s, v1.4s, v23.4s\n"
+            "fmla v18.4s, v0.4s, v17.4s\n"
+            "fmla v14.4s, v30.4s, v6.4s\n"
+            "bgt 7b\n"
+            "mov x20, %x[res_ptr]\n"
+            "cmp x10, #0x1\n"
+            "str q15, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "ble 8f\n"
+            "cmp x10, #0x2\n"
+            "str q19, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "ble 8f\n"
+            "cmp x10, #0x3\n"
+            "str q18, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "ble 8f\n"
+            "str q14, [x20, #0x0]\n"
+            "8:"  // Row tail: Accumulator store skip
+            "subs x23, x23, #0x4\n"
+            "add %x[res_ptr], %x[res_ptr], #0x10\n"
+            "bne 6b\n"
+            "subs x10, x10, #0x4\n"
+            "add %x[a_ptr], %x[a_ptr], x9\n"
+            "mov %x[res_ptr], x22\n"
+            "bgt 5b\n"
+            "9:"  // Row tail: Row loop skip
+            : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr)
+            : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc)
+            : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
+        );
+        return;
+    }
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
+    {
+        float sumf[4][4];
+        int sumi;
+
+        for (int y = 0; y < nr / 4; y++) {
+            const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+            for (int x = 0; x < nc / ncols_interleaved; x++) {
+                const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
+                for (int m = 0; m < 4; m++) {
+                    for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
+                }
+                for (int l = 0; l < nb; l++) {
+                    for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                        for (int m = 0; m < 4; m++) {
+                            for (int j = 0; j < ncols_interleaved; j++) {
+                                sumi = 0;
+                                for (int i = 0; i < blocklen; ++i) {
+                                    const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
+                                    const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
+                                    sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
+                                            (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
+                                }
+                                sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]);
+                            }
+                        }
+                    }
+                }
+                for (int m = 0; m < 4; m++) {
+                    for (int j = 0; j < ncols_interleaved; j++)
+                        s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+                }
+            }
+        }
+    }
+}
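+
+// Illustrative note (editor's addition, not part of the vendored source): the
+// scalar fallback above produces one 4x4 output tile per (y, x) pair. The four
+// rows come from one block_q8_0x4 (interleaved LHS rows), the four columns from
+// one block_q4_0x4 (interleaved RHS columns), and the result is stored at
+// s[(y * 4 + m) * bs + x * 4 + j]; bs is the row stride of s in floats, which
+// is why the assembly path derives res_stride = bs * sizeof(float) as a byte
+// stride.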
+
+void ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 8;
+
+    assert (n % qk == 0);
+    assert (nr % 4 == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+    if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
+        const void * b_ptr = vx;
+        const void * a_ptr = vy;
+        float * res_ptr = s;
+        size_t res_stride = bs * sizeof(float);
+
+        __asm__ __volatile__(
+            "mov x10, %x[nr]\n"
+            "mov x9, #0x88\n"
+            "cmp x10, #0x10\n"
+            "mul x9, %x[nb], x9\n"
+            "blt 4f\n"
+            "1:"  // Row loop
+            "add x28, %x[b_ptr], #0x8\n"
+            "mov x27, %x[nc]\n"
+            "add x26, %x[res_ptr], %x[res_stride], LSL #4\n"
+            "2:"  // Column loop
+            "add x25, %x[a_ptr], #0x8\n"
+            "movi v2.16b, #0x0\n"
+            "movi v10.16b, #0x0\n"
+            "mov x24, %x[nb]\n"
+            "add x23, x25, x9\n"
+            "movi v12.16b, #0x0\n"
+            "movi v28.16b, #0x0\n"
+            "add x22, x23, x9\n"
+            "movi v11.16b, #0x0\n"
+            "movi v13.16b, #0x0\n"
+            "add x21, x22, x9\n"
+            "movi v22.16b, #0x0\n"
+            "movi v23.16b, #0x0\n"
+            "movi v25.16b, #0x0\n"
+            "movi v5.16b, #0x0\n"
+            "movi v7.16b, #0x0\n"
+            "movi v4.16b, #0x0\n"
+            "movi v6.16b, #0x0\n"
+            "movi v30.16b, #0x0\n"
+            "movi v24.16b, #0x0\n"
+            "movi v14.16b, #0x0\n"
+            "3:"  // Block loop
+            "ldr q21, [x28, #0x0]\n"
+            "ldr q16, [x28, #0x10]\n"
+            "movi v1.16b, #0x4\n"
+            "movi v19.4s, #0x0\n"
+            "ldr q27, [x25, #0x0]\n"
+            "ldr q15, [x25, #0x10]\n"
+            "movi v26.4s, #0x0\n"
+            "movi v18.4s, #0x0\n"
+            "ldr q29, [x28, #0x20]\n"
+            "ldr q3, [x28, #0x30]\n"
+            "movi v17.4s, #0x0\n"
+            "movi v0.16b, #0xf0\n"
+            "ldr d20, [x25, #-0x8]\n"
+            "ldr d9, [x23, #-0x8]\n"
+            "sshl v8.16b, v21.16b, v1.16b\n"
+            "sshl v31.16b, v16.16b, v1.16b\n"
+            "and v21.16b, v21.16b, v0.16b\n"
+            "and v16.16b, v16.16b, v0.16b\n"
+            "sub x20, x28, #0x8\n"
+            "subs x24, x24, #0x1\n"
+            "add x28, x28, #0x48\n"
+            ".inst 0x4e88a773  // smmla v19.4s, v27.16b, v8.16b\n"
+            ".inst 0x4e9fa77a  // smmla v26.4s, v27.16b, v31.16b\n"
+            "ldr q27, [x25, #0x20]\n"
+            ".inst 0x4e88a5f2  // smmla v18.4s, v15.16b, v8.16b\n"
+            ".inst 0x4e9fa5f1  // smmla v17.4s, v15.16b, v31.16b\n"
+            "sshl v15.16b, v29.16b, v1.16b\n"
+            "sshl v1.16b, v3.16b, v1.16b\n"
+            "and v29.16b, v29.16b, v0.16b\n"
+            "and v3.16b, v3.16b, v0.16b\n"
+            "ldr q0, [x25, #0x30]\n"
+            "fcvtl v20.4s, v20.4h\n"
+            ".inst 0x4e8fa773  // smmla v19.4s, v27.16b, v15.16b\n"
+            "fcvtl v9.4s, v9.4h\n"
+            ".inst 0x4e81a77a  // smmla v26.4s, v27.16b, v1.16b\n"
+            "ldr q27, [x25, #0x40]\n"
+            ".inst 0x4e8fa412  // smmla v18.4s, v0.16b, v15.16b\n"
+            ".inst 0x4e81a411  // smmla v17.4s, v0.16b, v1.16b\n"
+            "ldr q0, [x25, #0x50]\n"
+            ".inst 0x4e95a773  // smmla v19.4s, v27.16b, v21.16b\n"
+            ".inst 0x4e90a77a  // smmla v26.4s, v27.16b, v16.16b\n"
+            "ldr q27, [x25, #0x60]\n"
+            ".inst 0x4e95a412  // smmla v18.4s, v0.16b, v21.16b\n"
+            ".inst 0x4e90a411  // smmla v17.4s, v0.16b, v16.16b\n"
+            "ldr q0, [x25, #0x70]\n"
+            "add x25, x25, #0x88\n"
+            ".inst 0x4e9da773  // smmla v19.4s, v27.16b, v29.16b\n"
+            ".inst 0x4e83a77a  // smmla v26.4s, v27.16b, v3.16b\n"
+            "ldr d27, [x20, #0x0]\n"
+            ".inst 0x4e9da412  // smmla v18.4s, v0.16b, v29.16b\n"
+            ".inst 0x4e83a411  // smmla v17.4s, v0.16b, v3.16b\n"
+            "fcvtl v27.4s, v27.4h\n"
+            "uzp1 v0.2d, v19.2d, v26.2d\n"
+            "uzp2 v26.2d, v19.2d, v26.2d\n"
+            "fmul v19.4s, v27.4s, v20.s[0]\n"
+            "scvtf v0.4s, v0.4s, #0x4\n"
+            "scvtf v26.4s, v26.4s, #0x4\n"
+            "fmla v2.4s, v0.4s, v19.4s\n"
+            "ldr q19, [x23, #0x0]\n"
+            "uzp1 v0.2d, v18.2d, v17.2d\n"
+            "uzp2 v18.2d, v18.2d, v17.2d\n"
+            "fmul v17.4s, v27.4s, v20.s[1]\n"
+            "scvtf v0.4s, v0.4s, #0x4\n"
+            "scvtf v18.4s, v18.4s, #0x4\n"
+            "fmla v10.4s, v26.4s, v17.4s\n"
+            "ldr q17, [x23, #0x10]\n"
+            "fmul v26.4s, v27.4s, v20.s[2]\n"
+            "fmul v20.4s, v27.4s, v20.s[3]\n"
+            "fmla v12.4s, v0.4s, v26.4s\n"
+            "ldr d0, [x22, #-0x8]\n"
+            "ldr d26, [x21, #-0x8]\n"
+            "fcvtl v0.4s, v0.4h\n"
+            "fmla v28.4s, v18.4s, v20.4s\n"
+            "movi v20.4s, #0x0\n"
+            "movi v18.4s, #0x0\n"
+            ".inst 0x4e88a674  // smmla v20.4s, v19.16b, v8.16b\n"
+            ".inst 0x4e9fa672  // smmla v18.4s, v19.16b, v31.16b\n"
+            "ldr q19, [x23, #0x20]\n"
+            "fcvtl v26.4s, v26.4h\n"
+            ".inst 0x4e8fa674  // smmla v20.4s, v19.16b, v15.16b\n"
+            ".inst 0x4e81a672  // smmla v18.4s, v19.16b, v1.16b\n"
+            "ldr q19, [x23, #0x40]\n"
+            ".inst 0x4e95a674  // smmla v20.4s, v19.16b, v21.16b\n"
+            ".inst 0x4e90a672  // smmla v18.4s, v19.16b, v16.16b\n"
+            "ldr q19, [x23, #0x60]\n"
+            ".inst 0x4e9da674  // smmla v20.4s, v19.16b, v29.16b\n"
+            ".inst 0x4e83a672  // smmla v18.4s, v19.16b, v3.16b\n"
+            "uzp1 v19.2d, v20.2d, v18.2d\n"
+            "scvtf v19.4s, v19.4s, #0x4\n"
+            "uzp2 v20.2d, v20.2d, v18.2d\n"
+            "fmul v18.4s, v27.4s, v9.s[0]\n"
+            "scvtf v20.4s, v20.4s, #0x4\n"
+            "fmla v11.4s, v19.4s, v18.4s\n"
+            "ldr q18, [x22, #0x0]\n"
+            "fmul v19.4s, v27.4s, v9.s[1]\n"
+            "fmla v13.4s, v20.4s, v19.4s\n"
+            "movi v19.4s, #0x0\n"
+            "movi v20.4s, #0x0\n"
+            ".inst 0x4e88a633  // smmla v19.4s, v17.16b, v8.16b\n"
+            ".inst 0x4e9fa634  // smmla v20.4s, v17.16b, v31.16b\n"
+            "ldr q17, [x23, #0x30]\n"
+            ".inst 0x4e8fa633  // smmla v19.4s, v17.16b, v15.16b\n"
+            ".inst 0x4e81a634  // smmla v20.4s, v17.16b, v1.16b\n"
+            "ldr q17, [x23, #0x50]\n"
+            ".inst 0x4e95a633  // smmla v19.4s, v17.16b, v21.16b\n"
+            ".inst 0x4e90a634  // smmla v20.4s, v17.16b, v16.16b\n"
+            "ldr q17, [x23, #0x70]\n"
+            "add x23, x23, #0x88\n"
+            ".inst 0x4e9da633  // smmla v19.4s, v17.16b, v29.16b\n"
+            ".inst 0x4e83a634  // smmla v20.4s, v17.16b, v3.16b\n"
+            "uzp1 v17.2d, v19.2d, v20.2d\n"
+            "scvtf v17.4s, v17.4s, #0x4\n"
+            "uzp2 v20.2d, v19.2d, v20.2d\n"
+            "fmul v19.4s, v27.4s, v9.s[2]\n"
+            "fmul v9.4s, v27.4s, v9.s[3]\n"
+            "scvtf v20.4s, v20.4s, #0x4\n"
+            "fmla v22.4s, v17.4s, v19.4s\n"
+            "ldr q17, [x22, #0x10]\n"
+            "movi v19.4s, #0x0\n"
+            ".inst 0x4e88a653  // smmla v19.4s, v18.16b, v8.16b\n"
+            "fmla v23.4s, v20.4s, v9.4s\n"
+            "movi v20.4s, #0x0\n"
+            "movi v9.4s, #0x0\n"
+            ".inst 0x4e9fa654  // smmla v20.4s, v18.16b, v31.16b\n"
+            "ldr q18, [x22, #0x20]\n"
+            ".inst 0x4e88a629  // smmla v9.4s, v17.16b, v8.16b\n"
+            ".inst 0x4e8fa653  // smmla v19.4s, v18.16b, v15.16b\n"
+            ".inst 0x4e81a654  // smmla v20.4s, v18.16b, v1.16b\n"
+            "ldr q18, [x22, #0x40]\n"
+            ".inst 0x4e95a653  // smmla v19.4s, v18.16b, v21.16b\n"
+            ".inst 0x4e90a654  // smmla v20.4s, v18.16b, v16.16b\n"
+            "ldr q18, [x22, #0x60]\n"
+            ".inst 0x4e9da653  // smmla v19.4s, v18.16b, v29.16b\n"
+            ".inst 0x4e83a654  // smmla v20.4s, v18.16b, v3.16b\n"
+            "movi v18.4s, #0x0\n"
+            ".inst 0x4e9fa632  // smmla v18.4s, v17.16b, v31.16b\n"
+            "ldr q17, [x22, #0x30]\n"
+            ".inst 0x4e8fa629  // smmla v9.4s, v17.16b, v15.16b\n"
+            ".inst 0x4e81a632  // smmla v18.4s, v17.16b, v1.16b\n"
+            "ldr q17, [x22, #0x50]\n"
+            ".inst 0x4e95a629  // smmla v9.4s, v17.16b, v21.16b\n"
+            ".inst 0x4e90a632  // smmla v18.4s, v17.16b, v16.16b\n"
+            "ldr q17, [x22, #0x70]\n"
+            "add x22, x22, #0x88\n"
+            ".inst 0x4e9da629  // smmla v9.4s, v17.16b, v29.16b\n"
+            ".inst 0x4e83a632  // smmla v18.4s, v17.16b, v3.16b\n"
+            "uzp1 v17.2d, v19.2d, v20.2d\n"
+            "uzp2 v20.2d, v19.2d, v20.2d\n"
+            "fmul v19.4s, v27.4s, v0.s[0]\n"
+            "scvtf v17.4s, v17.4s, #0x4\n"
+            "scvtf v20.4s, v20.4s, #0x4\n"
+            "fmla v25.4s, v17.4s, v19.4s\n"
+            "ldr q19, [x21, #0x0]\n"
+            "fmul v17.4s, v27.4s, v0.s[1]\n"
+            "fmla v5.4s, v20.4s, v17.4s\n"
+            "ldr q17, [x21, #0x10]\n"
+            "uzp1 v20.2d, v9.2d, v18.2d\n"
+            "uzp2 v9.2d, v9.2d, v18.2d\n"
+            "fmul v18.4s, v27.4s, v0.s[2]\n"
+            "fmul v0.4s, v27.4s, v0.s[3]\n"
+            "scvtf v20.4s, v20.4s, #0x4\n"
+            "scvtf v9.4s, v9.4s, #0x4\n"
+            "fmla v7.4s, v20.4s, v18.4s\n"
+            "movi v20.4s, #0x0\n"
+            "movi v18.4s, #0x0\n"
+            ".inst 0x4e88a674  // smmla v20.4s, v19.16b, v8.16b\n"
+            ".inst 0x4e9fa672  // smmla v18.4s, v19.16b, v31.16b\n"
+            "ldr q19, [x21, #0x20]\n"
+            "fmla v4.4s, v9.4s, v0.4s\n"
+            "movi v9.4s, #0x0\n"
+            "movi v0.4s, #0x0\n"
+            ".inst 0x4e88a629  // smmla v9.4s, v17.16b, v8.16b\n"
+            "fmul v8.4s, v27.4s, v26.s[0]\n"
+            ".inst 0x4e9fa620  // smmla v0.4s, v17.16b, v31.16b\n"
+            "ldr q17, [x21, #0x30]\n"
+            ".inst 0x4e8fa674  // smmla v20.4s, v19.16b, v15.16b\n"
+            "fmul v31.4s, v27.4s, v26.s[1]\n"
+            ".inst 0x4e81a672  // smmla v18.4s, v19.16b, v1.16b\n"
+            "ldr q19, [x21, #0x40]\n"
+            ".inst 0x4e8fa629  // smmla v9.4s, v17.16b, v15.16b\n"
+            "fmul v15.4s, v27.4s, v26.s[2]\n"
+            "fmul v27.4s, v27.4s, v26.s[3]\n"
+            ".inst 0x4e81a620  // smmla v0.4s, v17.16b, v1.16b\n"
+            "ldr q1, [x21, #0x50]\n"
+            ".inst 0x4e95a674  // smmla v20.4s, v19.16b, v21.16b\n"
+            ".inst 0x4e90a672  // smmla v18.4s, v19.16b, v16.16b\n"
+            "ldr q26, [x21, #0x60]\n"
+            ".inst 0x4e95a429  // smmla v9.4s, v1.16b, v21.16b\n"
+            ".inst 0x4e90a420  // smmla v0.4s, v1.16b, v16.16b\n"
+            "ldr q21, [x21, #0x70]\n"
+            "add x21, x21, #0x88\n"
+            ".inst 0x4e9da754  // smmla v20.4s, v26.16b, v29.16b\n"
+            ".inst 0x4e83a752  // smmla v18.4s, v26.16b, v3.16b\n"
+            ".inst 0x4e9da6a9  // smmla v9.4s, v21.16b, v29.16b\n"
+            ".inst 0x4e83a6a0  // smmla v0.4s, v21.16b, v3.16b\n"
+            "uzp1 v29.2d, v20.2d, v18.2d\n"
+            "uzp2 v21.2d, v20.2d, v18.2d\n"
+            "scvtf v29.4s, v29.4s, #0x4\n"
+            "uzp1 v18.2d, v9.2d, v0.2d\n"
+            "uzp2 v16.2d, v9.2d, v0.2d\n"
+            "scvtf v21.4s, v21.4s, #0x4\n"
+            "fmla v6.4s, v29.4s, v8.4s\n"
+            "scvtf v18.4s, v18.4s, #0x4\n"
+            "scvtf v16.4s, v16.4s, #0x4\n"
+            "fmla v30.4s, v21.4s, v31.4s\n"
+            "fmla v24.4s, v18.4s, v15.4s\n"
+            "fmla v14.4s, v16.4s, v27.4s\n"
+            "bgt 3b\n"
+            "mov x20, %x[res_ptr]\n"
+            "subs x27, x27, #0x4\n"
+            "add %x[res_ptr], %x[res_ptr], #0x10\n"
+            "str q2, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q10, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q12, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q28, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q11, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q13, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q22, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q23, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q25, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q5, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q7, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q4, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q6, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q30, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q24, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q14, [x20, #0x0]\n"
+            "bne 2b\n"
+            "mov x20, #0x4\n"
+            "sub x10, x10, #0x10\n"
+            "cmp x10, #0x10\n"
+            "mov %x[res_ptr], x26\n"
+            "madd %x[a_ptr], x20, x9, %x[a_ptr]\n"
+            "bge 1b\n"
+            "4:"  // Row loop skip
+            "cbz x10, 9f\n"
+            "5:"  // Row tail: Row loop
+            "add x24, %x[b_ptr], #0x8\n"
+            "mov x23, %x[nc]\n"
+            "add x22, %x[res_ptr], %x[res_stride], LSL #2\n"
+            "6:"  // Row tail: Column loop
+            "movi v2.16b, #0x0\n"
+            "movi v10.16b, #0x0\n"
+            "add x25, %x[a_ptr], #0x8\n"
+            "mov x21, %x[nb]\n"
+            "movi v12.16b, #0x0\n"
+            "movi v28.16b, #0x0\n"
+            "7:"  // Row tail: Block loop
+            "ldr q6, [x24, #0x0]\n"
+            "ldr q5, [x24, #0x10]\n"
+            "movi v17.16b, #0x4\n"
+            "movi v8.4s, #0x0\n"
+            "ldr q4, [x25, #0x0]\n"
+            "ldr q13, [x25, #0x10]\n"
+            "movi v27.4s, #0x0\n"
+            "movi v0.4s, #0x0\n"
+            "ldr q31, [x24, #0x20]\n"
+            "ldr q14, [x24, #0x30]\n"
+            "movi v29.4s, #0x0\n"
+            "movi v22.16b, #0xf0\n"
+            "ldr q11, [x25, #0x20]\n"
+            "ldr q23, [x25, #0x30]\n"
+            "sshl v21.16b, v6.16b, v17.16b\n"
+            "sshl v16.16b, v5.16b, v17.16b\n"
+            "ldr q20, [x25, #0x40]\n"
+            "ldr q26, [x25, #0x50]\n"
+            "and v6.16b, v6.16b, v22.16b\n"
+            "and v5.16b, v5.16b, v22.16b\n"
+            "ldr q25, [x25, #0x60]\n"
+            "ldr q3, [x25, #0x70]\n"
+            "sshl v19.16b, v31.16b, v17.16b\n"
+            "sshl v18.16b, v14.16b, v17.16b\n"
+            "ldr d17, [x25, #-0x8]\n"
+            ".inst 0x4e95a488  // smmla v8.4s, v4.16b, v21.16b\n"
+            ".inst 0x4e90a49b  // smmla v27.4s, v4.16b, v16.16b\n"
+            "and v31.16b, v31.16b, v22.16b\n"
+            ".inst 0x4e95a5a0  // smmla v0.4s, v13.16b, v21.16b\n"
+            ".inst 0x4e90a5bd  // smmla v29.4s, v13.16b, v16.16b\n"
+            "and v14.16b, v14.16b, v22.16b\n"
+            "sub x20, x24, #0x8\n"
+            "ldr d16, [x20, #0x0]\n"
+            "subs x21, x21, #0x1\n"
+            "add x25, x25, #0x88\n"
+            "fcvtl v17.4s, v17.4h\n"
+            "add x24, x24, #0x48\n"
+            ".inst 0x4e93a568  // smmla v8.4s, v11.16b, v19.16b\n"
+            ".inst 0x4e92a57b  // smmla v27.4s, v11.16b, v18.16b\n"
+            ".inst 0x4e93a6e0  // smmla v0.4s, v23.16b, v19.16b\n"
+            ".inst 0x4e92a6fd  // smmla v29.4s, v23.16b, v18.16b\n"
+            "fcvtl v16.4s, v16.4h\n"
+            ".inst 0x4e86a688  // smmla v8.4s, v20.16b, v6.16b\n"
+            ".inst 0x4e85a69b  // smmla v27.4s, v20.16b, v5.16b\n"
+            "fmul v23.4s, v16.4s, v17.s[0]\n"
+            "fmul v21.4s, v16.4s, v17.s[1]\n"
+            "fmul v1.4s, v16.4s, v17.s[2]\n"
+            "fmul v20.4s, v16.4s, v17.s[3]\n"
+            ".inst 0x4e86a740  // smmla v0.4s, v26.16b, v6.16b\n"
+            ".inst 0x4e85a75d  // smmla v29.4s, v26.16b, v5.16b\n"
+            ".inst 0x4e9fa728  // smmla v8.4s, v25.16b, v31.16b\n"
+            ".inst 0x4e8ea73b  // smmla v27.4s, v25.16b, v14.16b\n"
+            ".inst 0x4e9fa460  // smmla v0.4s, v3.16b, v31.16b\n"
+            ".inst 0x4e8ea47d  // smmla v29.4s, v3.16b, v14.16b\n"
+            "uzp1 v19.2d, v8.2d, v27.2d\n"
+            "uzp2 v18.2d, v8.2d, v27.2d\n"
+            "scvtf v19.4s, v19.4s, #0x4\n"
+            "uzp1 v17.2d, v0.2d, v29.2d\n"
+            "uzp2 v16.2d, v0.2d, v29.2d\n"
+            "scvtf v18.4s, v18.4s, #0x4\n"
+            "fmla v2.4s, v19.4s, v23.4s\n"
+            "scvtf v17.4s, v17.4s, #0x4\n"
+            "scvtf v16.4s, v16.4s, #0x4\n"
+            "fmla v10.4s, v18.4s, v21.4s\n"
+            "fmla v12.4s, v17.4s, v1.4s\n"
+            "fmla v28.4s, v16.4s, v20.4s\n"
+            "bgt 7b\n"
+            "mov x20, %x[res_ptr]\n"
+            "cmp x10, #0x1\n"
+            "str q2, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "ble 8f\n"
+            "cmp x10, #0x2\n"
+            "str q10, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "ble 8f\n"
+            "cmp x10, #0x3\n"
+            "str q12, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "ble 8f\n"
+            "str q28, [x20, #0x0]\n"
+            "8:"  // Row tail: Accumulator store skip
+            "subs x23, x23, #0x4\n"
+            "add %x[res_ptr], %x[res_ptr], #0x10\n"
+            "bne 6b\n"
+            "subs x10, x10, #0x4\n"
+            "add %x[a_ptr], %x[a_ptr], x9\n"
+            "mov %x[res_ptr], x22\n"
+            "bgt 5b\n"
+            "9:"  // Row tail: Row loop skip
+            : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr)
+            : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc)
+            : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
+        );
+        return;
+    }
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
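+    // Portable scalar reference path, taken when the NEON+i8mm block above is
+    // compiled out or its runtime feature checks fail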
+    float sumf[4][4];
+    int sumi;
+
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
+            }
+            for (int l = 0; l < nb; l++) {
+                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                    for (int m = 0; m < 4; m++) {
+                        for (int j = 0; j < ncols_interleaved; j++) {
+                            sumi = 0;
+                            for (int i = 0; i < blocklen; ++i) {
+                                const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
+                                const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
+                                sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
+                                        (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
+                            }
+                            sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]);
+                        }
+                    }
+                }
+            }
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++)
+                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+            }
+        }
+    }
+}
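+// Editor's note - illustrative call shape for the GEMM kernels in this file
+// (a sketch, not vendored code; `packed_B` and `packed_A` are hypothetical
+// buffers already repacked into the interleaved block_q4_0xN / block_q8_0x4
+// layouts). The destination is row-major with `bs` floats per row, e.g.
+// bs == nc for a dense result:
+//
+//     float dst[nr * nc];   // nr % 4 == 0, nc % 4 == 0, n % QK8_0 == 0
+//     ggml_gemm_q4_0_4x8_q8_0(n, dst, /*bs=*/nc, packed_B, packed_A, nr, nc);
+//
+// as implied by the store `s[(y*4 + m) * bs + x*ncols_interleaved + j]` in the
+// scalar path above.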
+
+void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 8;
+    const int blocklen = 8;
+
+    assert (n % qk == 0);
+    assert (nr % 4 == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
+#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
+    if (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0) {
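+        // Runtime guard: this block is written for a 256-bit SVE implementation;
+        // ggml_cpu_get_sve_cnt() appears to report the vector length in bytes,
+        // so comparing it to QK8_0 (32) selects 256-bit vectors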
+        const void * b_ptr = vx;
+        const void * a_ptr = vy;
+        float * res_ptr = s;
+        size_t res_stride = bs * sizeof(float);
+
+        __asm__ __volatile__(
+            "mov x20, #0x4\n"
+            "mov x13, %x[nr]\n"
+            "mov z28.s, #-0x4\n"
+            "mov x12, #0x88\n"
+            "ptrue p1.b\n"
+            "whilelt p0.s, XZR, x20\n"
+            "cmp x13, #0x10\n"
+            "mul x12, %x[nb], x12\n"
+            "blt 4f\n"
+            "1:"  // Row loop
+            "add x11, %x[b_ptr], #0x10\n"
+            "mov x10, %x[nc]\n"
+            "add x9, %x[res_ptr], %x[res_stride], LSL #4\n"
+            "2:"  // Column loop
+            "add x28, %x[a_ptr], #0x8\n"
+            "mov z24.b, #0x0\n"
+            "mov z15.b, #0x0\n"
+            "mov x27, %x[nb]\n"
+            "add x26, x28, x12\n"
+            "mov z12.b, #0x0\n"
+            "mov z0.b, #0x0\n"
+            "add x25, x26, x12\n"
+            "mov z13.b, #0x0\n"
+            "mov z1.b, #0x0\n"
+            "add x24, x25, x12\n"
+            "mov z20.b, #0x0\n"
+            "mov z25.b, #0x0\n"
+            "mov z11.b, #0x0\n"
+            "mov z16.b, #0x0\n"
+            "mov z19.b, #0x0\n"
+            "mov z26.b, #0x0\n"
+            "mov z8.b, #0x0\n"
+            "mov z29.b, #0x0\n"
+            "mov z27.b, #0x0\n"
+            "mov z10.b, #0x0\n"
+            "3:"  // Block loop
+            "ld1b { z30.b }, p1/Z, [x11]\n"
+            "ld1b { z21.b }, p1/Z, [x11, #1, MUL VL]\n"
+            "mov z18.s, #0x0\n"
+            "mov z7.s, #0x0\n"
+            "ld1rqb { z3.b }, p1/Z, [x28]\n"
+            "ld1rqb { z5.b }, p1/Z, [x28, #16]\n"
+            "mov z9.s, #0x0\n"
+            "mov z22.s, #0x0\n"
+            "ld1b { z4.b }, p1/Z, [x11, #2, MUL VL]\n"
+            "ld1b { z17.b }, p1/Z, [x11, #3, MUL VL]\n"
+            "sub x20, x11, #0x10\n"
+            "sub x23, x28, #0x8\n"
+            "lsl z31.b, z30.b, #0x4\n"
+            "lsl z6.b, z21.b, #0x4\n"
+            "ld1h { z23.s }, p1/Z, [x20]\n"
+            "sub x22, x26, #0x8\n"
+            "and z30.b, z30.b, #0xf0\n"
+            "and z21.b, z21.b, #0xf0\n"
+            "sub x21, x25, #0x8\n"
+            "sub x20, x24, #0x8\n"
+            "lsl z14.b, z4.b, #0x4\n"
+            "lsl z2.b, z17.b, #0x4\n"
+            "subs x27, x27, #0x1\n"
+            "add x11, x11, #0x90\n"
+            ".inst 0x451f9872  // smmla z18.s, z3.b, z31.b\n"
+            ".inst 0x45069867  // smmla z7.s, z3.b, z6.b\n"
+            "ld1rqb { z3.b }, p1/Z, [x28, #32]\n"
+            "and z4.b, z4.b, #0xf0\n"
+            ".inst 0x451f98a9  // smmla z9.s, z5.b, z31.b\n"
+            ".inst 0x450698b6  // smmla z22.s, z5.b, z6.b\n"
+            "ld1rqb { z5.b }, p1/Z, [x28, #48]\n"
+            "and z17.b, z17.b, #0xf0\n"
+            "fcvt z23.s, p1/m, z23.h\n"
+            ".inst 0x450e9872  // smmla z18.s, z3.b, z14.b\n"
+            ".inst 0x45029867  // smmla z7.s, z3.b, z2.b\n"
+            "ld1rqb { z3.b }, p1/Z, [x28, #64]\n"
+            ".inst 0x450e98a9  // smmla z9.s, z5.b, z14.b\n"
+            ".inst 0x450298b6  // smmla z22.s, z5.b, z2.b\n"
+            "ld1rqb { z5.b }, p1/Z, [x28, #80]\n"
+            "fscale z23.s, p1/m, z23.s, z28.s\n"
+            ".inst 0x451e9872  // smmla z18.s, z3.b, z30.b\n"
+            ".inst 0x45159867  // smmla z7.s, z3.b, z21.b\n"
+            "ld1rqb { z3.b }, p1/Z, [x28, #96]\n"
+            ".inst 0x451e98a9  // smmla z9.s, z5.b, z30.b\n"
+            ".inst 0x451598b6  // smmla z22.s, z5.b, z21.b\n"
+            "ld1rqb { z5.b }, p1/Z, [x28, #112]\n"
+            "add x28, x28, #0x88\n"
+            ".inst 0x45049872  // smmla z18.s, z3.b, z4.b\n"
+            ".inst 0x45119867  // smmla z7.s, z3.b, z17.b\n"
+            "ld1h { z3.s }, p0/Z, [x23]\n"
+            ".inst 0x450498a9  // smmla z9.s, z5.b, z4.b\n"
+            ".inst 0x451198b6  // smmla z22.s, z5.b, z17.b\n"
+            "fcvt z3.s, p1/m, z3.h\n"
+            "uzp1 z5.d, z18.d, z7.d\n"
+            "uzp2 z18.d, z18.d, z7.d\n"
+            "mov z3.q, z3.q[0]\n"
+            "uzp1 z7.d, z9.d, z22.d\n"
+            "uzp2 z22.d, z9.d, z22.d\n"
+            "fmul z9.s, z23.s, z3.s[0]\n"
+            "scvtf z5.s, p1/m, z5.s\n"
+            "scvtf z18.s, p1/m, z18.s\n"
+            "scvtf z7.s, p1/m, z7.s\n"
+            "scvtf z22.s, p1/m, z22.s\n"
+            "fmla z24.s, p1/M, z5.s, z9.s\n"
+            "ld1rqb { z5.b }, p1/Z, [x26]\n"
+            "fmul z9.s, z23.s, z3.s[1]\n"
+            "fmla z15.s, p1/M, z18.s, z9.s\n"
+            "ld1rqb { z18.b }, p1/Z, [x26, #16]\n"
+            "fmul z9.s, z23.s, z3.s[2]\n"
+            "fmul z3.s, z23.s, z3.s[3]\n"
+            "fmla z12.s, p1/M, z7.s, z9.s\n"
+            "mov z9.s, #0x0\n"
+            "ld1h { z7.s }, p0/Z, [x22]\n"
+            ".inst 0x451f98a9  // smmla z9.s, z5.b, z31.b\n"
+            "fmla z0.s, p1/M, z22.s, z3.s\n"
+            "mov z22.s, #0x0\n"
+            "ld1h { z3.s }, p0/Z, [x21]\n"
+            ".inst 0x450698b6  // smmla z22.s, z5.b, z6.b\n"
+            "ld1rqb { z5.b }, p1/Z, [x26, #32]\n"
+            "fcvt z7.s, p1/m, z7.h\n"
+            "fcvt z3.s, p1/m, z3.h\n"
+            ".inst 0x450e98a9  // smmla z9.s, z5.b, z14.b\n"
+            ".inst 0x450298b6  // smmla z22.s, z5.b, z2.b\n"
+            "ld1rqb { z5.b }, p1/Z, [x26, #64]\n"
+            "mov z7.q, z7.q[0]\n"
+            "mov z3.q, z3.q[0]\n"
+            ".inst 0x451e98a9  // smmla z9.s, z5.b, z30.b\n"
+            ".inst 0x451598b6  // smmla z22.s, z5.b, z21.b\n"
+            "ld1rqb { z5.b }, p1/Z, [x26, #96]\n"
+            ".inst 0x450498a9  // smmla z9.s, z5.b, z4.b\n"
+            ".inst 0x451198b6  // smmla z22.s, z5.b, z17.b\n"
+            "uzp1 z5.d, z9.d, z22.d\n"
+            "scvtf z5.s, p1/m, z5.s\n"
+            "uzp2 z22.d, z9.d, z22.d\n"
+            "fmul z9.s, z23.s, z7.s[0]\n"
+            "scvtf z22.s, p1/m, z22.s\n"
+            "fmla z13.s, p1/M, z5.s, z9.s\n"
+            "ld1rqb { z9.b }, p1/Z, [x25]\n"
+            "fmul z5.s, z23.s, z7.s[1]\n"
+            "fmla z1.s, p1/M, z22.s, z5.s\n"
+            "mov z5.s, #0x0\n"
+            "mov z22.s, #0x0\n"
+            ".inst 0x451f9a45  // smmla z5.s, z18.b, z31.b\n"
+            ".inst 0x45069a56  // smmla z22.s, z18.b, z6.b\n"
+            "ld1rqb { z18.b }, p1/Z, [x26, #48]\n"
+            ".inst 0x450e9a45  // smmla z5.s, z18.b, z14.b\n"
+            ".inst 0x45029a56  // smmla z22.s, z18.b, z2.b\n"
+            "ld1rqb { z18.b }, p1/Z, [x26, #80]\n"
+            ".inst 0x451e9a45  // smmla z5.s, z18.b, z30.b\n"
+            ".inst 0x45159a56  // smmla z22.s, z18.b, z21.b\n"
+            "ld1rqb { z18.b }, p1/Z, [x26, #112]\n"
+            "add x26, x26, #0x88\n"
+            ".inst 0x45049a45  // smmla z5.s, z18.b, z4.b\n"
+            ".inst 0x45119a56  // smmla z22.s, z18.b, z17.b\n"
+            "uzp1 z18.d, z5.d, z22.d\n"
+            "scvtf z18.s, p1/m, z18.s\n"
+            "uzp2 z22.d, z5.d, z22.d\n"
+            "fmul z5.s, z23.s, z7.s[2]\n"
+            "fmul z7.s, z23.s, z7.s[3]\n"
+            "scvtf z22.s, p1/m, z22.s\n"
+            "fmla z20.s, p1/M, z18.s, z5.s\n"
+            "ld1rqb { z18.b }, p1/Z, [x25, #16]\n"
+            "ld1h { z5.s }, p0/Z, [x20]\n"
+            "fcvt z5.s, p1/m, z5.h\n"
+            "fmla z25.s, p1/M, z22.s, z7.s\n"
+            "mov z22.s, #0x0\n"
+            "mov z7.s, #0x0\n"
+            ".inst 0x451f9936  // smmla z22.s, z9.b, z31.b\n"
+            ".inst 0x45069927  // smmla z7.s, z9.b, z6.b\n"
+            "ld1rqb { z9.b }, p1/Z, [x25, #32]\n"
+            "mov z5.q, z5.q[0]\n"
+            ".inst 0x450e9936  // smmla z22.s, z9.b, z14.b\n"
+            ".inst 0x45029927  // smmla z7.s, z9.b, z2.b\n"
+            "ld1rqb { z9.b }, p1/Z, [x25, #64]\n"
+            ".inst 0x451e9936  // smmla z22.s, z9.b, z30.b\n"
+            ".inst 0x45159927  // smmla z7.s, z9.b, z21.b\n"
+            "ld1rqb { z9.b }, p1/Z, [x25, #96]\n"
+            ".inst 0x45049936  // smmla z22.s, z9.b, z4.b\n"
+            ".inst 0x45119927  // smmla z7.s, z9.b, z17.b\n"
+            "uzp1 z9.d, z22.d, z7.d\n"
+            "scvtf z9.s, p1/m, z9.s\n"
+            "uzp2 z22.d, z22.d, z7.d\n"
+            "fmul z7.s, z23.s, z3.s[0]\n"
+            "scvtf z22.s, p1/m, z22.s\n"
+            "fmla z11.s, p1/M, z9.s, z7.s\n"
+            "ld1rqb { z9.b }, p1/Z, [x24]\n"
+            "fmul z7.s, z23.s, z3.s[1]\n"
+            "fmla z16.s, p1/M, z22.s, z7.s\n"
+            "mov z22.s, #0x0\n"
+            "mov z7.s, #0x0\n"
+            ".inst 0x451f9a56  // smmla z22.s, z18.b, z31.b\n"
+            ".inst 0x45069a47  // smmla z7.s, z18.b, z6.b\n"
+            "ld1rqb { z18.b }, p1/Z, [x25, #48]\n"
+            ".inst 0x450e9a56  // smmla z22.s, z18.b, z14.b\n"
+            ".inst 0x45029a47  // smmla z7.s, z18.b, z2.b\n"
+            "ld1rqb { z18.b }, p1/Z, [x25, #80]\n"
+            ".inst 0x451e9a56  // smmla z22.s, z18.b, z30.b\n"
+            ".inst 0x45159a47  // smmla z7.s, z18.b, z21.b\n"
+            "ld1rqb { z18.b }, p1/Z, [x25, #112]\n"
+            "add x25, x25, #0x88\n"
+            ".inst 0x45049a56  // smmla z22.s, z18.b, z4.b\n"
+            ".inst 0x45119a47  // smmla z7.s, z18.b, z17.b\n"
+            "uzp1 z18.d, z22.d, z7.d\n"
+            "scvtf z18.s, p1/m, z18.s\n"
+            "uzp2 z7.d, z22.d, z7.d\n"
+            "fmul z22.s, z23.s, z3.s[2]\n"
+            "fmul z3.s, z23.s, z3.s[3]\n"
+            "scvtf z7.s, p1/m, z7.s\n"
+            "fmla z19.s, p1/M, z18.s, z22.s\n"
+            "ld1rqb { z18.b }, p1/Z, [x24, #16]\n"
+            "fmul z22.s, z23.s, z5.s[0]\n"
+            "fmla z26.s, p1/M, z7.s, z3.s\n"
+            "mov z3.s, #0x0\n"
+            "mov z7.s, #0x0\n"
+            ".inst 0x451f9923  // smmla z3.s, z9.b, z31.b\n"
+            ".inst 0x45069927  // smmla z7.s, z9.b, z6.b\n"
+            "ld1rqb { z9.b }, p1/Z, [x24, #32]\n"
+            ".inst 0x450e9923  // smmla z3.s, z9.b, z14.b\n"
+            ".inst 0x45029927  // smmla z7.s, z9.b, z2.b\n"
+            "mov z9.s, #0x0\n"
+            ".inst 0x451f9a49  // smmla z9.s, z18.b, z31.b\n"
+            "mov z31.s, #0x0\n"
+            ".inst 0x45069a5f  // smmla z31.s, z18.b, z6.b\n"
+            "ld1rqb { z6.b }, p1/Z, [x24, #48]\n"
+            "ld1rqb { z18.b }, p1/Z, [x24, #64]\n"
+            ".inst 0x450e98c9  // smmla z9.s, z6.b, z14.b\n"
+            "fmul z14.s, z23.s, z5.s[1]\n"
+            ".inst 0x450298df  // smmla z31.s, z6.b, z2.b\n"
+            "ld1rqb { z6.b }, p1/Z, [x24, #80]\n"
+            "fmul z2.s, z23.s, z5.s[2]\n"
+            "fmul z23.s, z23.s, z5.s[3]\n"
+            ".inst 0x451e9a43  // smmla z3.s, z18.b, z30.b\n"
+            ".inst 0x45159a47  // smmla z7.s, z18.b, z21.b\n"
+            "ld1rqb { z5.b }, p1/Z, [x24, #96]\n"
+            ".inst 0x451e98c9  // smmla z9.s, z6.b, z30.b\n"
+            ".inst 0x451598df  // smmla z31.s, z6.b, z21.b\n"
+            "ld1rqb { z18.b }, p1/Z, [x24, #112]\n"
+            "add x24, x24, #0x88\n"
+            ".inst 0x450498a3  // smmla z3.s, z5.b, z4.b\n"
+            ".inst 0x451198a7  // smmla z7.s, z5.b, z17.b\n"
+            ".inst 0x45049a49  // smmla z9.s, z18.b, z4.b\n"
+            ".inst 0x45119a5f  // smmla z31.s, z18.b, z17.b\n"
+            "uzp1 z18.d, z3.d, z7.d\n"
+            "uzp2 z5.d, z3.d, z7.d\n"
+            "scvtf z18.s, p1/m, z18.s\n"
+            "uzp1 z6.d, z9.d, z31.d\n"
+            "uzp2 z9.d, z9.d, z31.d\n"
+            "scvtf z5.s, p1/m, z5.s\n"
+            "fmla z8.s, p1/M, z18.s, z22.s\n"
+            "scvtf z6.s, p1/m, z6.s\n"
+            "scvtf z9.s, p1/m, z9.s\n"
+            "fmla z29.s, p1/M, z5.s, z14.s\n"
+            "fmla z27.s, p1/M, z6.s, z2.s\n"
+            "fmla z10.s, p1/M, z9.s, z23.s\n"
+            "bgt 3b\n"
+            "mov x20, %x[res_ptr]\n"
+            "subs x10, x10, #0x8\n"
+            "add %x[res_ptr], %x[res_ptr], #0x20\n"
+            "st1w { z24.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z15.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z12.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z0.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z13.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z1.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z20.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z25.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z11.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z16.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z19.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z26.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z8.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z29.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z27.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z10.s }, p1, [x20]\n"
+            "bne 2b\n"
+            "mov x20, #0x4\n"
+            "sub x13, x13, #0x10\n"
+            "cmp x13, #0x10\n"
+            "mov %x[res_ptr], x9\n"
+            "madd %x[a_ptr], x20, x12, %x[a_ptr]\n"
+            "bge 1b\n"
+            "4:"  // Row loop skip
+            "cbz x13, 9f\n"
+            "5:"  // Row tail: Row loop
+            "add x25, %x[b_ptr], #0x10\n"
+            "mov x24, %x[nc]\n"
+            "add x23, %x[res_ptr], %x[res_stride], LSL #2\n"
+            "6:"  // Row tail: Column loop
+            "mov z24.b, #0x0\n"
+            "mov z15.b, #0x0\n"
+            "add x28, %x[a_ptr], #0x8\n"
+            "mov x22, %x[nb]\n"
+            "mov z12.b, #0x0\n"
+            "mov z0.b, #0x0\n"
+            "7:"  // Row tail: Block loop
+            "ld1b { z3.b }, p1/Z, [x25]\n"
+            "ld1b { z6.b }, p1/Z, [x25, #1, MUL VL]\n"
+            "mov z2.s, #0x0\n"
+            "mov z25.s, #0x0\n"
+            "ld1rqb { z26.b }, p1/Z, [x28]\n"
+            "ld1rqb { z21.b }, p1/Z, [x28, #16]\n"
+            "mov z27.s, #0x0\n"
+            "mov z19.s, #0x0\n"
+            "ld1b { z29.b }, p1/Z, [x25, #2, MUL VL]\n"
+            "ld1b { z16.b }, p1/Z, [x25, #3, MUL VL]\n"
+            "sub x21, x25, #0x10\n"
+            "sub x20, x28, #0x8\n"
+            "lsl z20.b, z3.b, #0x4\n"
+            "lsl z4.b, z6.b, #0x4\n"
+            "ld1rqb { z10.b }, p1/Z, [x28, #32]\n"
+            "ld1rqb { z23.b }, p1/Z, [x28, #48]\n"
+            "and z3.b, z3.b, #0xf0\n"
+            "and z6.b, z6.b, #0xf0\n"
+            "ld1rqb { z11.b }, p1/Z, [x28, #64]\n"
+            "ld1rqb { z7.b }, p1/Z, [x28, #80]\n"
+            "lsl z8.b, z29.b, #0x4\n"
+            "lsl z14.b, z16.b, #0x4\n"
+            "ld1rqb { z18.b }, p1/Z, [x28, #96]\n"
+            "ld1rqb { z30.b }, p1/Z, [x28, #112]\n"
+            ".inst 0x45149b42  // smmla z2.s, z26.b, z20.b\n"
+            ".inst 0x45049b59  // smmla z25.s, z26.b, z4.b\n"
+            "and z29.b, z29.b, #0xf0\n"
+            "ld1h { z17.s }, p1/Z, [x21]\n"
+            ".inst 0x45149abb  // smmla z27.s, z21.b, z20.b\n"
+            ".inst 0x45049ab3  // smmla z19.s, z21.b, z4.b\n"
+            "and z16.b, z16.b, #0xf0\n"
+            "ld1h { z4.s }, p0/Z, [x20]\n"
+            "subs x22, x22, #0x1\n"
+            "add x28, x28, #0x88\n"
+            "fcvt z17.s, p1/m, z17.h\n"
+            "add x25, x25, #0x90\n"
+            ".inst 0x45089942  // smmla z2.s, z10.b, z8.b\n"
+            ".inst 0x450e9959  // smmla z25.s, z10.b, z14.b\n"
+            "fcvt z4.s, p1/m, z4.h\n"
+            ".inst 0x45089afb  // smmla z27.s, z23.b, z8.b\n"
+            ".inst 0x450e9af3  // smmla z19.s, z23.b, z14.b\n"
+            "fscale z17.s, p1/m, z17.s, z28.s\n"
+            "mov z4.q, z4.q[0]\n"
+            ".inst 0x45039962  // smmla z2.s, z11.b, z3.b\n"
+            ".inst 0x45069979  // smmla z25.s, z11.b, z6.b\n"
+            "fmul z23.s, z17.s, z4.s[0]\n"
+            "fmul z9.s, z17.s, z4.s[1]\n"
+            "fmul z21.s, z17.s, z4.s[2]\n"
+            "fmul z4.s, z17.s, z4.s[3]\n"
+            ".inst 0x450398fb  // smmla z27.s, z7.b, z3.b\n"
+            ".inst 0x450698f3  // smmla z19.s, z7.b, z6.b\n"
+            ".inst 0x451d9a42  // smmla z2.s, z18.b, z29.b\n"
+            ".inst 0x45109a59  // smmla z25.s, z18.b, z16.b\n"
+            ".inst 0x451d9bdb  // smmla z27.s, z30.b, z29.b\n"
+            ".inst 0x45109bd3  // smmla z19.s, z30.b, z16.b\n"
+            "uzp1 z31.d, z2.d, z25.d\n"
+            "uzp2 z13.d, z2.d, z25.d\n"
+            "scvtf z31.s, p1/m, z31.s\n"
+            "uzp1 z17.d, z27.d, z19.d\n"
+            "uzp2 z18.d, z27.d, z19.d\n"
+            "scvtf z13.s, p1/m, z13.s\n"
+            "fmla z24.s, p1/M, z31.s, z23.s\n"
+            "scvtf z17.s, p1/m, z17.s\n"
+            "scvtf z18.s, p1/m, z18.s\n"
+            "fmla z15.s, p1/M, z13.s, z9.s\n"
+            "fmla z12.s, p1/M, z17.s, z21.s\n"
+            "fmla z0.s, p1/M, z18.s, z4.s\n"
+            "bgt 7b\n"
+            "mov x20, %x[res_ptr]\n"
+            "cmp x13, #0x1\n"
+            "st1w { z24.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "ble 8f\n"
+            "cmp x13, #0x2\n"
+            "st1w { z15.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "ble 8f\n"
+            "cmp x13, #0x3\n"
+            "st1w { z12.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "ble 8f\n"
+            "st1w { z0.s }, p1, [x20]\n"
+            "8:"  // Row tail: Accumulator store skip
+            "subs x24, x24, #0x8\n"
+            "add %x[res_ptr], %x[res_ptr], #0x20\n"
+            "bne 6b\n"
+            "subs x13, x13, #0x4\n"
+            "add %x[a_ptr], %x[a_ptr], x12\n"
+            "mov %x[res_ptr], x23\n"
+            "bgt 5b\n"
+            "9:"  // Row tail: Row loop skip
+            : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr)
+            : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc)
+            : "cc", "memory", "p0", "p1", "x9", "x10", "x11", "x12", "x13", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"
+        );
+        return;
+    }
+#endif // #if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
+#elif defined(__AVX2__) || defined(__AVX512F__)
+    {
+        const block_q4_0x8 * b_ptr_start = (const block_q4_0x8 *)vx;
+        const block_q8_0x4 * a_ptr_start = (const block_q8_0x4 *)vy;
+        int64_t b_nb = n / QK4_0;
+        int64_t y = 0;
+        // Mask used to extract the low nibbles from the packed weight bytes
+        const __m256i m4b = _mm256_set1_epi8(0x0F);
+        const __m128i loadMask = _mm_blend_epi32(_mm_setzero_si128(), _mm_set1_epi32(0xFFFFFFFF), 3);
+        // Lookup table to convert signed nibbles to signed bytes
+        __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));
+        signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
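+        // Via _mm256_shuffle_epi8 this maps a masked nibble k in 0..15 to its
+        // two's-complement value: 0..7 stay as-is, 8..15 become -8..-1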
+        // Permute mask that swaps the two 128-bit halves; used below to interleave the B columns
+        __m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4);
+        int64_t xstart = 0;
+        int anr = nr - nr%16; // nr rounded down to a multiple of 16 (bound of the unrolled row loop)
+    #ifdef __AVX512F__
+        int anc = nc - nc%16; // nc rounded down to a multiple of 16 (bound of the unrolled column loop)
+        // Low-nibble extraction mask, widened to 512 bits
+        const __m512i m4bexpanded = _mm512_set1_epi8(0x0F);
+        // Sign-extension lookup table, widened to 512 bits
+        __m512i signextendlutexpanded = _mm512_inserti32x8(_mm512_castsi256_si512(signextendlut), signextendlut, 1);
+
+        // Process a group of four block_q8_0x4 structures (sixteen rows of A) per pass and accumulate their dot products
+        for (; y < anr / 4; y += 4) {
+
+            const block_q8_0x4 * a_ptrs[4];
+
+            a_ptrs[0] = a_ptr_start + (y * nb);
+            for (int i = 0; i < 3; ++i) {
+                a_ptrs[i + 1] = a_ptrs[i] + nb;
+            }
+
+            // Process a group of two block_q4_0x8 structures (sixteen columns of B) per pass and accumulate their dot products
+            for (int64_t x = 0; x < anc / 8; x += 2) {
+
+                const block_q4_0x8 * b_ptr_0 = b_ptr_start + ((x)     * b_nb);
+                const block_q4_0x8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb);
+
+                // Master FP accumulators
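+                // (one per LHS row: 4 block_q8_0x4 structures x 4 rows; each __m512
+                // holds the 16 B-column partial sums of this tile)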
+                __m512 acc_rows[16];
+                for (int i = 0; i < 16; i++) {
+                    acc_rows[i] = _mm512_setzero_ps();
+                }
+
+                for (int64_t b = 0; b < nb; b++) {
+                    // Load the sixteen interleaved block_q4_0 quantized columns in chunks of eight bytes - B0,B1, ..., BE,BF
+                    const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs));
+                    const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 32));
+                    const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 64));
+                    const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 96));
+
+                    const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs));
+                    const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 32));
+                    const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 64));
+                    const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 96));
+
+                    // Rearrange the values into the layouts B0B1B4B5B8B9BCBD and B2B3B6B7BABBBEBF for the widening and shuffles that follow
+                    const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
+                    const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
+
+                    const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240);
+                    const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240);
+
+                    const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1);
+                    const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1);
+                    const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1);
+                    const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1);
+
+                    // 4-bit -> 8-bit - Sign is maintained
+                    const __m512i rhs_mat_014589CD_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) B8(0-7) B9(0-7) BC(0-7) BD(0-7)
+                    const __m512i rhs_mat_2367ABEF_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) BA(0-7) BB(0-7) BE(0-7) BF(0-7)
+
+                    const __m512i rhs_mat_014589CD_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) B8(8-15) B9(8-15) BC(8-15) BD(8-15)
+                    const __m512i rhs_mat_2367ABEF_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) BA(8-15) BB(8-15) BE(8-15) BF(8-15)
+
+                    const __m512i rhs_mat_014589CD_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) B8(16-23) B9(16-23) BC(16-23) BD(16-23)
+                    const __m512i rhs_mat_2367ABEF_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) BA(16-23) BB(16-23) BE(16-23) BF(16-23)
+
+                    const __m512i rhs_mat_014589CD_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) B8(24-31) B9(24-31) BC(24-31) BD(24-31)
+                    const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31)
+
+                    // Shuffle pattern one - right side input
+                    const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
+                    const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
+
+                    const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
+                    const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
+
+                    const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
+                    const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
+
+                    const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
+                    const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
+
+                    // Shuffle pattern two - right side input
+
+                    const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
+                    const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
+
+                    const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
+                    const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
+
+                    const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
+                    const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
+
+                    const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
+                    const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
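+                    // Shuffle immediates: 136 (0x88) selects 32-bit elements {0,2,0,2} of each
+                    // 128-bit lane and 221 (0xDD) selects {1,3,1,3}, so the two patterns together
+                    // cover all four 4-byte groups of every column pair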
+
+                    // Scale values - Load the weight scale values of two block_q4_0x8
+                    const __m512 col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
+
+                    // Process LHS in pairs of rows
+                    for (int rp = 0; rp < 4; rp++) {
+
+                        // Load the four block_q8_0 rows' quantized values, interleaved with each other in chunks of eight - A0,A1,A2,A3
+                        // Each 128-bit half (A0A1 / A2A3) is then broadcast across a 256-bit vector and again across a 512-bit vector
+                        __m256i lhs_mat_ymm_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs)));
+                        __m256i lhs_mat_ymm_01_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 0);
+                        __m256i lhs_mat_ymm_23_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 17);
+                        __m256i lhs_mat_ymm_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 32)));
+                        __m256i lhs_mat_ymm_01_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 0);
+                        __m256i lhs_mat_ymm_23_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 17);
+                        __m256i lhs_mat_ymm_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 64)));
+                        __m256i lhs_mat_ymm_01_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 0);
+                        __m256i lhs_mat_ymm_23_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 17);
+                        __m256i lhs_mat_ymm_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 96)));
+                        __m256i lhs_mat_ymm_01_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 0);
+                        __m256i lhs_mat_ymm_23_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 17);
+
+                        __m512i lhs_mat_01_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_0), lhs_mat_ymm_01_0, 1);
+                        __m512i lhs_mat_23_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_0), lhs_mat_ymm_23_0, 1);
+                        __m512i lhs_mat_01_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_1), lhs_mat_ymm_01_1, 1);
+                        __m512i lhs_mat_23_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_1), lhs_mat_ymm_23_1, 1);
+                        __m512i lhs_mat_01_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_2), lhs_mat_ymm_01_2, 1);
+                        __m512i lhs_mat_23_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_2), lhs_mat_ymm_23_2, 1);
+                        __m512i lhs_mat_01_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_3), lhs_mat_ymm_01_3, 1);
+                        __m512i lhs_mat_23_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_3), lhs_mat_ymm_23_3, 1);
+
+                        // Shuffle pattern one - left side input
+
+                        const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, 160);  //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
+                        const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, 160);  //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
+
+                        const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, 160);  //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
+                        const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, 160);  //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
+
+                        const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, 160);  //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
+                        const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, 160);  //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
+
+                        const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, 160);  //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
+                        const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, 160);  //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
+
+                        // Shuffle pattern two - left side input
+
+                        const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, 245);  //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
+                        const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, 245);  //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
+
+                        const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, 245);  //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
+                        const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, 245);  //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
+
+                        const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, 245);  //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
+                        const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, 245);  //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
+
+                        const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, 245);  //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
+                        const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, 245);  //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
+
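+                        // Illustrative note on the shuffle immediates: _mm512_shuffle_epi32 uses
+                        // four 2 bit source-index fields per 128 bit lane, so
+                        //   160 = 0b10'10'00'00 -> [d0, d0, d2, d2]
+                        //   245 = 0b11'11'01'01 -> [d1, d1, d3, d3]
+                        // e.g. a lane holding [A0(0-3), A0(4-7), A1(0-3), A1(4-7)] becomes
+                        // [A0(0-3), A0(0-3), A1(0-3), A1(0-3)] under pattern one.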
+                        // The values arranged in the shuffle patterns are used in a dot product operation within each 32 bit lane: corresponding bytes are multiplied and the products are added into a 32 bit integer per lane
+                        // Resembles the MMLA accumulation into 2x2 matrices in the ARM version
+                        __m512i iacc_mat_00_sp1 =
+                            _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1));
+                        __m512i iacc_mat_01_sp1 =
+                            _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1));
+                        __m512i iacc_mat_10_sp1 =
+                            _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1));
+                        __m512i iacc_mat_11_sp1 =
+                            _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1));
+                        __m512i iacc_mat_00_sp2 =
+                            _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2));
+                        __m512i iacc_mat_01_sp2 =
+                            _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2));
+                        __m512i iacc_mat_10_sp2 =
+                            _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2));
+                        __m512i iacc_mat_11_sp2 =
+                            _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2));
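+
+                        // What mul_sum_i8_pairs_int32x16 (a helper defined earlier in this file)
+                        // computes: for every 32 bit lane, the four int8 products of corresponding
+                        // bytes are summed into one int32, i.e. lane k = sum over j of x[4k+j] * y[4k+j].
+                        // A hedged AVX2-style sketch of one common ggml implementation (the actual
+                        // helper may use AVX512-VNNI instead) -- illustrative only, not part of the build:
+                        //   ax  = _mm256_sign_epi8(x, x);      // |x|, unsigned operand for maddubs
+                        //   sy  = _mm256_sign_epi8(y, x);      // y with the signs of x applied
+                        //   dot = _mm256_madd_epi16(_mm256_maddubs_epi16(ax, sy), _mm256_set1_epi16(1));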
+
+                        // Outputs of both shuffle patterns are added in order to sum the dot product outputs of all 32 values in the block
+                        __m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
+                        __m512i iacc_mat_01 = _mm512_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2);
+                        __m512i iacc_mat_10 = _mm512_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2);
+                        __m512i iacc_mat_11 = _mm512_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2);
+
+
+                        // Straighten out to make 4 row vectors
+                        __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, 78));
+                        __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01);
+                        __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, 78));
+                        __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11);
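+                        // Worked example of the straightening: per 128 bit lane iacc_mat_00 holds
+                        // [A0.B0, A0.B1, A1.B0, A1.B1] and iacc_mat_01 holds the same rows against
+                        // B2/B3; shuffling with 78 (0b01'00'11'10 -> [d2, d3, d0, d1]) swaps the
+                        // 64 bit halves, and mask 0xCCCC picks dwords 2,3 of each lane from the
+                        // second operand, so iacc_row_0 becomes [A0.B0, A0.B1, A0.B2, A0.B3].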
+
+                        // Load the scale values (the d fields) of all four Q8_0 blocks and repeat them across lanes
+                        const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptrs[rp][b].d), loadMask), 68);
+                        const __m512 row_scale_f32 = GGML_F32Cx16_REPEAT_LOAD(row_scale_f16);
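+                        // Row scale handling: the four f16 scales occupy 8 bytes; _mm_shuffle_epi32
+                        // with 68 (0b01'00'01'00 -> [d0, d1, d0, d1]) duplicates them into both
+                        // 64 bit halves, and GGML_F32Cx16_REPEAT_LOAD presumably converts the f16s
+                        // to f32 repeated across every 128 bit lane. The _mm512_shuffle_ps
+                        // immediates 0/85/170/255 below then broadcast the scale of row 0/1/2/3.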
+
+                        // Multiply with the appropriate scales and accumulate
+                        acc_rows[rp * 4]     = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)),   acc_rows[rp * 4]);
+                        acc_rows[rp * 4 + 1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)),  acc_rows[rp * 4 + 1]);
+                        acc_rows[rp * 4 + 2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);
+                        acc_rows[rp * 4 + 3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]);
+                    }
+                }
+
+                // Store the accumulated values
+                for (int i = 0; i < 16; i++) {
+                    _mm512_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);
+                }
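+                // Output layout: s is row major with leading dimension bs, so row (y * 4 + i)
+                // starts at s + (y * 4 + i) * bs and each 512 bit store writes the 16 output
+                // columns beginning at column x * 8 (x advances by 2 in this path).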
+            }
+        }
+        // Take a block_q8_0x4 structure at each pass of the loop and perform the dot product operation
+        for (; y < nr / 4; y ++) {
+
+            const block_q8_0x4 * a_ptr = a_ptr_start + (y * nb);
+
+            // Take group of two block_q4_0x8 structures at each pass of the loop and perform dot product operation
+            for (int64_t x = 0; x < anc / 8; x += 2) {
+
+                const block_q4_0x8 * b_ptr_0 = b_ptr_start + ((x)     * b_nb);
+                const block_q4_0x8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb);
+
+                // Master FP accumulators
+                __m512 acc_rows[4];
+                for (int i = 0; i < 4; i++) {
+                    acc_rows[i] = _mm512_setzero_ps();
+                }
+
+                for (int64_t b = 0; b < nb; b++) {
+                    // Load the sixteen block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....BE,BF
+                    const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs));
+                    const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 32));
+                    const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 64));
+                    const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 96));
+
+                    const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs));
+                    const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 32));
+                    const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 64));
+                    const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 96));
+
+                    // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values
+                    const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
+                    const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
+
+                    const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240);
+                    const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240);
+
+                    const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1);
+                    const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1);
+                    const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1);
+                    const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1);
+
+                    // 4-bit -> 8-bit - Sign is maintained
+                    const __m512i rhs_mat_014589CD_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) B8(0-7) B9(0-7) BC(0-7) BD(0-7)
+                    const __m512i rhs_mat_2367ABEF_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) BA(0-7) BB(0-7) BE(0-7) BF(0-7)
+
+                    const __m512i rhs_mat_014589CD_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) B8(8-15) B9(8-15) BC(8-15) BD(8-15)
+                    const __m512i rhs_mat_2367ABEF_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) BA(8-15) BB(8-15) BE(8-15) BF(8-15)
+
+                    const __m512i rhs_mat_014589CD_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) B8(16-23) B9(16-23) BC(16-23) BD(16-23)
+                    const __m512i rhs_mat_2367ABEF_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) BA(16-23) BB(16-23) BE(16-23) BF(16-23)
+
+                    const __m512i rhs_mat_014589CD_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) B8(24-31) B9(24-31) BC(24-31) BD(24-31)
+                    const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31)
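+
+                    // How the expansion above works: m4bexpanded keeps the low nibble of every
+                    // byte (0x0F mask) and _mm512_shuffle_epi8 uses each nibble as an index into
+                    // signextendlutexpanded, a per-lane 16 entry table mapping the 4 bit quantized
+                    // value to its signed 8 bit equivalent (hence "sign is maintained"); the upper
+                    // nibbles are recovered by first shifting right by 4 with _mm512_srli_epi16.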
+
+                    // Shuffle pattern one - right side input
+                    const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
+                    const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
+
+                    const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
+                    const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
+
+                    const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
+                    const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
+
+                    const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
+                    const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
+
+                    // Shuffle pattern two - right side input
+
+                    const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
+                    const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
+
+                    const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
+                    const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
+
+                    const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
+                    const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
+
+                    const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
+                    const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
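+
+                    // Decoding the RHS shuffle immediates per 128 bit lane:
+                    //   136 = 0b10'00'10'00 -> [d0, d2, d0, d2]
+                    //   221 = 0b11'01'11'01 -> [d1, d3, d1, d3]
+                    // Paired with the LHS patterns 160/245, each lane lines up as
+                    // [A0, A0, A1, A1] x [B0, B1, B0, B1], so the per lane dot products yield the
+                    // 2x2 tile A0.B0, A0.B1, A1.B0, A1.B1.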
+
+
+                    // Scale values - Load the weight scale values of two block_q4_0x8
+                    const __m512 col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
+
+                    // Load the four block_q8_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
+                    // Loaded as a set of 128 bit vectors, repeated into a 256 bit vector and then repeated again into a 512 bit vector
+                    __m256i lhs_mat_ymm_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs)));
+                    __m256i lhs_mat_ymm_01_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 0);
+                    __m256i lhs_mat_ymm_23_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 17);
+                    __m256i lhs_mat_ymm_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 32)));
+                    __m256i lhs_mat_ymm_01_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 0);
+                    __m256i lhs_mat_ymm_23_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 17);
+                    __m256i lhs_mat_ymm_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 64)));
+                    __m256i lhs_mat_ymm_01_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 0);
+                    __m256i lhs_mat_ymm_23_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 17);
+                    __m256i lhs_mat_ymm_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 96)));
+                    __m256i lhs_mat_ymm_01_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 0);
+                    __m256i lhs_mat_ymm_23_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 17);
+
+                    __m512i lhs_mat_01_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_0), lhs_mat_ymm_01_0, 1);
+                    __m512i lhs_mat_23_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_0), lhs_mat_ymm_23_0, 1);
+                    __m512i lhs_mat_01_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_1), lhs_mat_ymm_01_1, 1);
+                    __m512i lhs_mat_23_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_1), lhs_mat_ymm_23_1, 1);
+                    __m512i lhs_mat_01_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_2), lhs_mat_ymm_01_2, 1);
+                    __m512i lhs_mat_23_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_2), lhs_mat_ymm_23_2, 1);
+                    __m512i lhs_mat_01_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_3), lhs_mat_ymm_01_3, 1);
+                    __m512i lhs_mat_23_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_3), lhs_mat_ymm_23_3, 1);
+
+                    // Shuffle pattern one - left side input
+
+                    const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, 160);  //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
+                    const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, 160);  //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
+
+                    const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, 160);  //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
+                    const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, 160);  //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
+
+                    const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, 160);  //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
+                    const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, 160);  //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
+
+                    const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, 160);  //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
+                    const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, 160);  //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
+
+                    // Shuffle pattern two - left side input
+
+                    const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, 245);  //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
+                    const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, 245);  //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
+
+                    const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, 245);  //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
+                    const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, 245);  //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
+
+                    const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, 245);  //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
+                    const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, 245);  //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
+
+                    const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, 245);  //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
+                    const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, 245);  //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
+
+                    // The values arranged in the shuffle patterns are used in a dot product operation within each 32 bit lane: corresponding bytes are multiplied and the products are added into a 32 bit integer per lane
+                    // Resembles the MMLA accumulation into 2x2 matrices in the ARM version
+                    __m512i iacc_mat_00_sp1 =
+                        _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1));
+                    __m512i iacc_mat_01_sp1 =
+                        _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1));
+                    __m512i iacc_mat_10_sp1 =
+                        _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1));
+                    __m512i iacc_mat_11_sp1 =
+                        _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1));
+                    __m512i iacc_mat_00_sp2 =
+                        _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2));
+                    __m512i iacc_mat_01_sp2 =
+                        _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2));
+                    __m512i iacc_mat_10_sp2 =
+                        _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2));
+                    __m512i iacc_mat_11_sp2 =
+                        _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2));
+
+                    // Outputs of both shuffle patterns are added in order to sum the dot product outputs of all 32 values in the block
+                    __m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
+                    __m512i iacc_mat_01 = _mm512_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2);
+                    __m512i iacc_mat_10 = _mm512_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2);
+                    __m512i iacc_mat_11 = _mm512_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2);
+
+
+                    // Straighten out to make 4 row vectors
+                    __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, 78));
+                    __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01);
+                    __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, 78));
+                    __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11);
+
+                    // Load the scale values (the d fields) of all four Q8_0 blocks and repeat them across lanes
+                    const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptr[b].d), loadMask), 68);
+                    const __m512 row_scale_f32 = GGML_F32Cx16_REPEAT_LOAD(row_scale_f16);
+
+                    // Multiply with the appropriate scales and accumulate
+                    acc_rows[0] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)),   acc_rows[0]);
+                    acc_rows[1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)),  acc_rows[1]);
+                    acc_rows[2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]);
+                    acc_rows[3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]);
+                }
+
+                // Store the accumulated values
+                for (int i = 0; i < 4; i++) {
+                    _mm512_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);
+                }
+            }
+        }
+        if (anc != nc) {
+            xstart = anc/8;
+            y = 0;
+        }
+    #endif // __AVX512F__
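+        // Handoff from the AVX512 path above: it covers the anr x anc sub-block, so when there
+        // are leftover columns (anc != nc) the 256 bit loops below restart at row 0 with
+        // xstart = anc/8 and finish only the remaining column blocks of the nr x nc output.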
+
+        // Take a group of four block_q8_0x4 structures at each pass of the loop and perform the dot product operation
+
+        for (; y < anr / 4; y += 4) {
+            const block_q8_0x4 * a_ptrs[4];
+
+            a_ptrs[0] = a_ptr_start + (y * nb);
+            for (int i = 0; i < 3; ++i) {
+                a_ptrs[i + 1] = a_ptrs[i] + nb;
+            }
+
+            // Take group of eight block_q4_0x8 structures at each pass of the loop and perform dot product operation
+            for (int64_t x = xstart; x < nc / 8; x++) {
+
+                const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb);
+
+                // Master FP accumulators
+                __m256 acc_rows[16];
+                for (int i = 0; i < 16; i++) {
+                    acc_rows[i] = _mm256_setzero_ps();
+                }
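+                // 16 accumulators: 4 LHS row blocks (rp = 0..3) x 4 rows each, every accumulator
+                // holding the 8 float32 output columns of this x block.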
+
+                for (int64_t b = 0; b < nb; b++) {
+                    // Load the eight block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7
+                    const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs));
+                    const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32));
+                    const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64));
+                    const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96));
+
+                    // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values
+                    const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
+                    const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
+
+                    // 4-bit -> 8-bit - Sign is maintained
+                    const __m256i rhs_mat_0145_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_0, m4b)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7)
+                    const __m256i rhs_mat_2367_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_0, m4b)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7)
+
+                    const __m256i rhs_mat_0145_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_1, m4b)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15)
+                    const __m256i rhs_mat_2367_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_1, m4b)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15)
+
+                    const __m256i rhs_mat_0145_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23)
+                    const __m256i rhs_mat_2367_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23)
+
+                    const __m256i rhs_mat_0145_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31)
+                    const __m256i rhs_mat_2367_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31)
+
+                    // Shuffle pattern one - right side input
+                    const __m256i rhs_mat_0145_0_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_0, 136);  //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3)
+                    const __m256i rhs_mat_2367_0_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_0, 136);  //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3)
+
+                    const __m256i rhs_mat_0145_1_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_1, 136);  //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11)
+                    const __m256i rhs_mat_2367_1_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_1, 136);  //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11)
+
+                    const __m256i rhs_mat_0145_2_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_2, 136);  //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19)
+                    const __m256i rhs_mat_2367_2_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_2, 136);  //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19)
+
+                    const __m256i rhs_mat_0145_3_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_3, 136);  //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27)
+                    const __m256i rhs_mat_2367_3_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_3, 136);  //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27)
+
+                    // Shuffle pattern two - right side input
+
+                    const __m256i rhs_mat_0145_0_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_0, 221);  //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7)
+                    const __m256i rhs_mat_2367_0_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_0, 221);  //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7)
+
+                    const __m256i rhs_mat_0145_1_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_1, 221);  //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15)
+                    const __m256i rhs_mat_2367_1_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_1, 221);  //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15)
+
+                    const __m256i rhs_mat_0145_2_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_2, 221);  //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23)
+                    const __m256i rhs_mat_2367_2_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_2, 221);  //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23)
+
+                    const __m256i rhs_mat_0145_3_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_3, 221);  //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31)
+                    const __m256i rhs_mat_2367_3_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_3, 221);  //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31)
+
+                    // Scale values - Load the weight scale values of block_q4_0x8
+                    const __m256 col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);
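+                    // GGML_F32Cx8_LOAD appears to load the eight f16 column scales of this
+                    // block_q4_0x8 and convert them to f32, one per output column; they are
+                    // multiplied with the per row scale before the final fmadd accumulate.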
+
+                    // Process LHS in groups of four
+                    for (int rp = 0; rp < 4; rp++) {
+                        // Load the four block_q8_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
+                        // Loaded as a set of 128 bit vectors and repeated into a 256 bit vector
+                        __m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs)));
+                        __m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0);
+                        __m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17);
+                        __m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 32)));
+                        __m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0);
+                        __m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17);
+                        __m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 64)));
+                        __m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0);
+                        __m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17);
+                        __m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 96)));
+                        __m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0);
+                        __m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17);
+
+                        // Shuffle pattern one - left side input
+                        const __m256i lhs_mat_01_0_sp1 = _mm256_shuffle_epi32(lhs_mat_01_0, 160);  //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
+                        const __m256i lhs_mat_23_0_sp1 = _mm256_shuffle_epi32(lhs_mat_23_0, 160);  //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
+
+                        const __m256i lhs_mat_01_1_sp1 = _mm256_shuffle_epi32(lhs_mat_01_1, 160);  //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
+                        const __m256i lhs_mat_23_1_sp1 = _mm256_shuffle_epi32(lhs_mat_23_1, 160);  //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
+
+                        const __m256i lhs_mat_01_2_sp1 = _mm256_shuffle_epi32(lhs_mat_01_2, 160);  //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
+                        const __m256i lhs_mat_23_2_sp1 = _mm256_shuffle_epi32(lhs_mat_23_2, 160);  //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
+
+                        const __m256i lhs_mat_01_3_sp1 = _mm256_shuffle_epi32(lhs_mat_01_3, 160);  //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
+                        const __m256i lhs_mat_23_3_sp1 = _mm256_shuffle_epi32(lhs_mat_23_3, 160);  //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
+
+                        // Shuffle pattern two - left side input
+                        const __m256i lhs_mat_01_0_sp2 = _mm256_shuffle_epi32(lhs_mat_01_0, 245);  //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
+                        const __m256i lhs_mat_23_0_sp2 = _mm256_shuffle_epi32(lhs_mat_23_0, 245);  //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
+
+                        const __m256i lhs_mat_01_1_sp2 = _mm256_shuffle_epi32(lhs_mat_01_1, 245);  //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
+                        const __m256i lhs_mat_23_1_sp2 = _mm256_shuffle_epi32(lhs_mat_23_1, 245);  //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
+
+                        const __m256i lhs_mat_01_2_sp2 = _mm256_shuffle_epi32(lhs_mat_01_2, 245);  //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
+                        const __m256i lhs_mat_23_2_sp2 = _mm256_shuffle_epi32(lhs_mat_23_2, 245);  //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
+
+                        const __m256i lhs_mat_01_3_sp2 = _mm256_shuffle_epi32(lhs_mat_01_3, 245);  //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
+                        const __m256i lhs_mat_23_3_sp2 = _mm256_shuffle_epi32(lhs_mat_23_3, 245);  //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
+
+                        // The values arranged in the shuffle patterns are used in a dot product operation within each 32 bit lane: corresponding bytes are multiplied and the products are added into a 32 bit integer per lane
+                        // Resembles the MMLA accumulation into 2x2 matrices in the ARM version
+                        __m256i iacc_mat_00_sp1 =
+                            _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1));
+                        __m256i iacc_mat_01_sp1 =
+                            _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1));
+                        __m256i iacc_mat_10_sp1 =
+                            _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1));
+                        __m256i iacc_mat_11_sp1 =
+                            _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1));
+                        __m256i iacc_mat_00_sp2 =
+                            _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2));
+                        __m256i iacc_mat_01_sp2 =
+                            _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2));
+                        __m256i iacc_mat_10_sp2 =
+                            _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2));
+                        __m256i iacc_mat_11_sp2 =
+                            _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2));
+
+                        // Outputs of both shuffle patterns are added in order to sum the dot product outputs of all 32 values in the block
+                        __m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
+                        __m256i iacc_mat_01 = _mm256_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2);
+                        __m256i iacc_mat_10 = _mm256_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2);
+                        __m256i iacc_mat_11 = _mm256_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2);
+
+                        // Straighten out to make 4 row vectors
+                        __m256i iacc_row_0 = _mm256_blend_epi32(iacc_mat_00, _mm256_shuffle_epi32(iacc_mat_01, 78), 204);
+                        __m256i iacc_row_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01, 204);
+                        __m256i iacc_row_2 = _mm256_blend_epi32(iacc_mat_10, _mm256_shuffle_epi32(iacc_mat_11, 78), 204);
+                        __m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204);
+
+                        // Load the scale values (the d fields) of all four Q8_0 blocks and repeat them across lanes
+                        const __m256 row_scale_f32 = GGML_F32Cx8_REPEAT_LOAD(a_ptrs[rp][b].d, loadMask);
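+                        // GGML_F32Cx8_REPEAT_LOAD presumably loads the four f16 row scales (the
+                        // loadMask limits the masked load to the 8 valid bytes), converts them to
+                        // f32 and repeats them across both 128 bit lanes; _mm256_shuffle_ps with
+                        // 0/85/170/255 then broadcasts the scale of row 0/1/2/3. The blend
+                        // immediate 204 (0b11001100) used above is the 256 bit analogue of the
+                        // 0xCCCC mask in the AVX512 path.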
+
+                        // Multiply with the appropriate scales and accumulate
+                        acc_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]);
+                        acc_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]);
+                        acc_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);
+                        acc_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32,  255)), acc_rows[rp * 4 + 3]);
+                    }
+                }
+
+                // Store the accumulated values
+                for (int i = 0; i < 16; i++) {
+                    _mm256_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);
+                }
+            }
+        }
+
+        // Take a block_q8_0x4 structure at each pass of the loop and perform the dot product operation
+        for (; y < nr / 4; y ++) {
+
+            const block_q8_0x4 * a_ptr = a_ptr_start + (y * nb);
+
+            // Take a block_q4_0x8 structure at each pass of the loop and perform the dot product operation
+            for (int64_t x = xstart; x < nc / 8; x++) {
+
+                const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb);
+
+                // Master FP accumulators
+                __m256 acc_rows[4];
+                for (int i = 0; i < 4; i++) {
+                    acc_rows[i] = _mm256_setzero_ps();
+                }
+
+                for (int64_t b = 0; b < nb; b++) {
+                    // Load the eight block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7
+                    const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs));
+                    const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32));
+                    const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64));
+                    const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96));
+
+                    // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values
+                    const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
+                    const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
+
+                    // 4-bit -> 8-bit - Sign is maintained
+                    const __m256i rhs_mat_0145_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_0, m4b));  //B0(0-7) B1(0-7) B4(0-7) B5(0-7)
+                    const __m256i rhs_mat_2367_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_0, m4b));  //B2(0-7) B3(0-7) B6(0-7) B7(0-7)
+
+                    const __m256i rhs_mat_0145_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_1, m4b));  //B0(8-15) B1(8-15) B4(8-15) B5(8-15)
+                    const __m256i rhs_mat_2367_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_1, m4b));  //B2(8-15) B3(8-15) B6(8-15) B7(8-15)
+
+                    const __m256i rhs_mat_0145_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b));  //B0(16-23) B1(16-23) B4(16-23) B5(16-23)
+                    const __m256i rhs_mat_2367_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b));  //B2(16-23) B3(16-23) B6(16-23) B7(16-23)
+
+                    const __m256i rhs_mat_0145_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b));  //B0(24-31) B1(24-31) B4(24-31) B5(24-31)
+                    const __m256i rhs_mat_2367_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b));  //B2(24-31) B3(24-31) B6(24-31) B7(24-31)
+
+                    // Shuffle pattern one - right side input
+                    const __m256i rhs_mat_0145_0_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_0, 136);  //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3)
+                    const __m256i rhs_mat_2367_0_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_0, 136);  //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3)
+
+                    const __m256i rhs_mat_0145_1_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_1, 136);  //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11)
+                    const __m256i rhs_mat_2367_1_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_1, 136);  //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11)
+
+                    const __m256i rhs_mat_0145_2_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_2, 136);  //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19)
+                    const __m256i rhs_mat_2367_2_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_2, 136);  //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19)
+
+                    const __m256i rhs_mat_0145_3_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_3, 136);  //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27)
+                    const __m256i rhs_mat_2367_3_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_3, 136);  //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27)
+
+                    // Shuffle pattern two - right side input
+
+                    const __m256i rhs_mat_0145_0_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_0, 221);  //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7)
+                    const __m256i rhs_mat_2367_0_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_0, 221);  //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7)
+
+                    const __m256i rhs_mat_0145_1_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_1, 221);  //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15)
+                    const __m256i rhs_mat_2367_1_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_1, 221);  //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15)
+
+                    const __m256i rhs_mat_0145_2_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_2, 221);  //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23)
+                    const __m256i rhs_mat_2367_2_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_2, 221);  //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23)
+
+                    const __m256i rhs_mat_0145_3_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_3, 221);  //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31)
+                    const __m256i rhs_mat_2367_3_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_3, 221);  //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31)
+
+                    // Scale values - Load the weight scale values of block_q4_0x8
+                    const __m256 col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);
+
+                    // Load the four block_q8_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
+                    // Loaded as a set of 128 bit vectors and repeated into a 256 bit vector
+                    __m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs)));
+                    __m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0);
+                    __m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17);
+                    __m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 32)));
+                    __m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0);
+                    __m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17);
+                    __m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 64)));
+                    __m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0);
+                    __m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17);
+                    __m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 96)));
+                    __m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0);
+                    __m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17);
+
+                    // Shuffle pattern one - left side input
+
+                    const __m256i lhs_mat_01_0_sp1 = _mm256_shuffle_epi32(lhs_mat_01_0, 160);  //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
+                    const __m256i lhs_mat_23_0_sp1 = _mm256_shuffle_epi32(lhs_mat_23_0, 160);  //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
+
+                    const __m256i lhs_mat_01_1_sp1 = _mm256_shuffle_epi32(lhs_mat_01_1, 160);  //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
+                    const __m256i lhs_mat_23_1_sp1 = _mm256_shuffle_epi32(lhs_mat_23_1, 160);  //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
+
+                    const __m256i lhs_mat_01_2_sp1 = _mm256_shuffle_epi32(lhs_mat_01_2, 160);  //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
+                    const __m256i lhs_mat_23_2_sp1 = _mm256_shuffle_epi32(lhs_mat_23_2, 160);  //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
+
+                    const __m256i lhs_mat_01_3_sp1 = _mm256_shuffle_epi32(lhs_mat_01_3, 160);  //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
+                    const __m256i lhs_mat_23_3_sp1 = _mm256_shuffle_epi32(lhs_mat_23_3, 160);  //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
+
+                    // Shuffle pattern two - left side input
+
+                    const __m256i lhs_mat_01_0_sp2 = _mm256_shuffle_epi32(lhs_mat_01_0, 245);  //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
+                    const __m256i lhs_mat_23_0_sp2 = _mm256_shuffle_epi32(lhs_mat_23_0, 245);  //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
+
+                    const __m256i lhs_mat_01_1_sp2 = _mm256_shuffle_epi32(lhs_mat_01_1, 245);  //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
+                    const __m256i lhs_mat_23_1_sp2 = _mm256_shuffle_epi32(lhs_mat_23_1, 245);  //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
+
+                    const __m256i lhs_mat_01_2_sp2 = _mm256_shuffle_epi32(lhs_mat_01_2, 245);  //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
+                    const __m256i lhs_mat_23_2_sp2 = _mm256_shuffle_epi32(lhs_mat_23_2, 245);  //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
+
+                    const __m256i lhs_mat_01_3_sp2 = _mm256_shuffle_epi32(lhs_mat_01_3, 245);  //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
+                    const __m256i lhs_mat_23_3_sp2 = _mm256_shuffle_epi32(lhs_mat_23_3, 245);  //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
+
+                    // Dot products are taken within each 32-bit lane of the shuffled vectors: corresponding bytes are multiplied and the products accumulated into one 32-bit integer per lane
+                    // This resembles the MMLA accumulation into 2x2 tiles used in the ARM version
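+                    // e.g. _mm256_shuffle_epi32(v, 221): 221 == 0b11011101 == _MM_SHUFFLE(3,1,3,1),
+                    // picking the odd dwords of each 128-bit lane; likewise 160 == _MM_SHUFFLE(2,2,0,0)
+                    // and 245 == _MM_SHUFFLE(3,3,1,1) duplicate the even and odd dwords respectively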
+                    __m256i iacc_mat_00_sp1 =
+                        _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1));
+                    __m256i iacc_mat_01_sp1 =
+                        _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1));
+                    __m256i iacc_mat_10_sp1 =
+                        _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1));
+                    __m256i iacc_mat_11_sp1 =
+                        _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1));
+                    __m256i iacc_mat_00_sp2 =
+                        _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2));
+                    __m256i iacc_mat_01_sp2 =
+                        _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2));
+                    __m256i iacc_mat_10_sp2 =
+                        _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2));
+                    __m256i iacc_mat_11_sp2 =
+                        _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2));
+
+                    // Outputs of both shuffle patterns are added together to cover the dot products of all 32 values in the block
+                    __m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
+                    __m256i iacc_mat_01 = _mm256_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2);
+                    __m256i iacc_mat_10 = _mm256_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2);
+                    __m256i iacc_mat_11 = _mm256_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2);
+
+
+                    // Straighten out to make 4 row vectors
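+                    // shuffle imm 78 == _MM_SHUFFLE(1,0,3,2) swaps the two 64-bit halves of each
+                    // 128-bit lane and blend mask 204 == 0b11001100 keeps dwords 2,3 and 6,7 from
+                    // the second operand, turning the 2x2 accumulator tiles back into plain rows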
+                    __m256i iacc_row_0 = _mm256_blend_epi32(iacc_mat_00, _mm256_shuffle_epi32(iacc_mat_01, 78), 204);
+                    __m256i iacc_row_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01, 204);
+                    __m256i iacc_row_2 = _mm256_blend_epi32(iacc_mat_10, _mm256_shuffle_epi32(iacc_mat_11, 78), 204);
+                    __m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204);
+
+                    // Load the scale values for all 4 Q8_0 blocks and repeat them across lanes
+                    const __m256 row_scale_f32 = GGML_F32Cx8_REPEAT_LOAD(a_ptr[b].d, loadMask);
+
+                    // Multiply with the appropriate scales and accumulate
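+                    // (the _mm256_shuffle_ps immediates 0/85/170/255 broadcast row scale 0/1/2/3
+                    // across all lanes)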
+                    acc_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]);
+                    acc_rows[1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]);
+                    acc_rows[2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]);
+                    acc_rows[3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]);
+                }
+
+                // Store the accumulated values
+                for (int i = 0; i < 4; i++) {
+                    _mm256_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);
+                }
+            }
+        }
+        return;
+    }
+#elif defined(__riscv_v_intrinsic)
+    if (__riscv_vlenb() >= QK4_0) {
+        const size_t vl = QK4_0;
+
+        for (int y = 0; y < nr / 4; y++) {
+            const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+            for (int x = 0; x < nc / ncols_interleaved; x++) {
+                const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
+                vfloat32m1_t sumf0 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
+                vfloat32m1_t sumf1 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
+                vfloat32m1_t sumf2 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
+                vfloat32m1_t sumf3 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
+                for (int l = 0; l < nb; l++) {
+                    const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4);
+                    const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4);
+                    const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4);
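+                    // shifting left then arithmetic-shifting right by 4 sign-extends the low
+                    // nibbles, and a single arithmetic shift extracts the high nibbles; the
+                    // repack step (see make_block_q4_0x8) already XORed each nibble with 8,
+                    // so these shifts yield the signed quantized values directly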
+                    const vint8m2_t rhs_vec_lo_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 0);
+                    const vint8m2_t rhs_vec_lo_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 1);
+                    const vint8m2_t rhs_vec_hi_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 0);
+                    const vint8m2_t rhs_vec_hi_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 1);
+
+                    // vector version needs Zvfhmin extension
+                    const float a_scales[4] = {
+                        GGML_FP16_TO_FP32(a_ptr[l].d[0]),
+                        GGML_FP16_TO_FP32(a_ptr[l].d[1]),
+                        GGML_FP16_TO_FP32(a_ptr[l].d[2]),
+                        GGML_FP16_TO_FP32(a_ptr[l].d[3])
+                    };
+                    const float b_scales[8] = {
+                        GGML_FP16_TO_FP32(b_ptr[l].d[0]),
+                        GGML_FP16_TO_FP32(b_ptr[l].d[1]),
+                        GGML_FP16_TO_FP32(b_ptr[l].d[2]),
+                        GGML_FP16_TO_FP32(b_ptr[l].d[3]),
+                        GGML_FP16_TO_FP32(b_ptr[l].d[4]),
+                        GGML_FP16_TO_FP32(b_ptr[l].d[5]),
+                        GGML_FP16_TO_FP32(b_ptr[l].d[6]),
+                        GGML_FP16_TO_FP32(b_ptr[l].d[7])
+                    };
+                    const vfloat32m1_t b_scales_vec = __riscv_vle32_v_f32m1(b_scales, vl / 4);
+
+                    const int64_t A0 = *(const int64_t *)&a_ptr[l].qs[0];
+                    const int64_t A4 = *(const int64_t *)&a_ptr[l].qs[32];
+                    const int64_t A8 = *(const int64_t *)&a_ptr[l].qs[64];
+                    const int64_t Ac = *(const int64_t *)&a_ptr[l].qs[96];
+                    __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
+                    vint16m4_t sumi_l0;
+                    {
+                        const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A0, vl / 4));
+                        const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A4, vl / 4));
+                        const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A8, vl / 4));
+                        const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ac, vl / 4));
+                        const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
+                        const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
+                        const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
+                        const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
+
+                        sumi_l0 = sumi_hi_m;
+                    }
+
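+                    // each reduction block below folds the 16-bit partial sums down to one 32-bit
+                    // sum per output column: reinterpret as 32-bit elements, split the even/odd
+                    // 16-bit halves with narrowing shifts (vnsrl by 0 and 16), add, and repeat
+                    // until 8 sums remain, then widen to 32 bits and convert to float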
+                    {
+                        const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l0));
+                        const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
+                        const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
+                        const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
+                        const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
+                        const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
+                        const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
+                        const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
+                        const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
+                        const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
+                        const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
+                        const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
+                        const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
+
+                        const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[0], vl / 4);
+                        sumf0 = __riscv_vfmacc_vv_f32m1(sumf0, tmp1, b_scales_vec, vl / 4);
+                    }
+
+                    const int64_t A1 = *(const int64_t *)&a_ptr[l].qs[8];
+                    const int64_t A5 = *(const int64_t *)&a_ptr[l].qs[40];
+                    const int64_t A9 = *(const int64_t *)&a_ptr[l].qs[72];
+                    const int64_t Ad = *(const int64_t *)&a_ptr[l].qs[104];
+                    __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
+                    vint16m4_t sumi_l1;
+                    {
+                        const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A1, vl / 4));
+                        const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A5, vl / 4));
+                        const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A9, vl / 4));
+                        const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ad, vl / 4));
+                        const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
+                        const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
+                        const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
+                        const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
+
+                        sumi_l1 = sumi_hi_m;
+                    }
+
+                    {
+                        const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l1));
+                        const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
+                        const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
+                        const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
+                        const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
+                        const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
+                        const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
+                        const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
+                        const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
+                        const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
+                        const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
+                        const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
+                        const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
+
+                        const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[1], vl / 4);
+                        sumf1 = __riscv_vfmacc_vv_f32m1(sumf1, tmp1, b_scales_vec, vl / 4);
+                    }
+
+                    const int64_t A2 = *(const int64_t *)&a_ptr[l].qs[16];
+                    const int64_t A6 = *(const int64_t *)&a_ptr[l].qs[48];
+                    const int64_t Aa = *(const int64_t *)&a_ptr[l].qs[80];
+                    const int64_t Ae = *(const int64_t *)&a_ptr[l].qs[112];
+                    __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
+                    vint16m4_t sumi_l2;
+                    {
+                        const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A2, vl / 4));
+                        const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A6, vl / 4));
+                        const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Aa, vl / 4));
+                        const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ae, vl / 4));
+                        const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
+                        const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
+                        const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
+                        const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
+
+                        sumi_l2 = sumi_hi_m;
+                    }
+
+                    {
+                        const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l2));
+                        const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
+                        const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
+                        const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
+                        const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
+                        const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
+                        const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
+                        const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
+                        const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
+                        const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
+                        const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
+                        const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
+                        const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
+
+                        const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[2], vl / 4);
+                        sumf2 = __riscv_vfmacc_vv_f32m1(sumf2, tmp1, b_scales_vec, vl / 4);
+                    }
+
+                    const int64_t A3 = *(const int64_t *)&a_ptr[l].qs[24];
+                    const int64_t A7 = *(const int64_t *)&a_ptr[l].qs[56];
+                    const int64_t Ab = *(const int64_t *)&a_ptr[l].qs[88];
+                    const int64_t Af = *(const int64_t *)&a_ptr[l].qs[120];
+                    __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
+                    vint16m4_t sumi_l3;
+                    {
+                        const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A3, vl / 4));
+                        const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A7, vl / 4));
+                        const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ab, vl / 4));
+                        const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Af, vl / 4));
+                        const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
+                        const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
+                        const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
+                        const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
+
+                        sumi_l3 = sumi_hi_m;
+                    }
+
+                    {
+                        const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l3));
+                        const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
+                        const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
+                        const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
+                        const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
+                        const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
+                        const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
+                        const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
+                        const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
+                        const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
+                        const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
+                        const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
+                        const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
+
+                        const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[3], vl / 4);
+                        sumf3 = __riscv_vfmacc_vv_f32m1(sumf3, tmp1, b_scales_vec, vl / 4);
+                    }
+                }
+                __riscv_vse32_v_f32m1(&s[(y * 4 + 0) * bs + x * ncols_interleaved], sumf0, vl / 4);
+                __riscv_vse32_v_f32m1(&s[(y * 4 + 1) * bs + x * ncols_interleaved], sumf1, vl / 4);
+                __riscv_vse32_v_f32m1(&s[(y * 4 + 2) * bs + x * ncols_interleaved], sumf2, vl / 4);
+                __riscv_vse32_v_f32m1(&s[(y * 4 + 3) * bs + x * ncols_interleaved], sumf3, vl / 4);
+            }
+        }
+
+        return;
+    }
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
+    float sumf[4][8];
+    int sumi;
+
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
+            }
+            for (int l = 0; l < nb; l++) {
+                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                    for (int m = 0; m < 4; m++) {
+                        for (int j = 0; j < ncols_interleaved; j++) {
+                            sumi = 0;
+                            for (int i = 0; i < blocklen; ++i) {
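+                                // the interleaved quants were XORed with 0x88 during repacking, so the
+                                // shift/mask below recover each signed value scaled by 16; both products
+                                // are multiples of 16 and the final >> 4 divides that factor back out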
+                                const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
+                                const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
+                                sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
+                                         (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
+                            }
+                            sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]);
+                        }
+                    }
+                }
+            }
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++)
+                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+            }
+        }
+    }
+}
+
+void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert (n % qk == 0);
+    assert (nr % 4 == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
+        const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl);
+
+        for (int y = 0; y < nr / 4; y++) {
+            const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+            for (int x = 0; x < nc / ncols_interleaved; x++) {
+                const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
+
+                float32x4_t sumf[4];
+                for (int m = 0; m < 4; m++) {
+                    sumf[m] = vdupq_n_f32(0);
+                }
+
+                for (int l = 0; l < nb; l++) {
+                    float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d));
+                    float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
+
+                    int32x4_t sumi_0 = vdupq_n_s32(0);
+                    int32x4_t sumi_1 = vdupq_n_s32(0);
+                    int32x4_t sumi_2 = vdupq_n_s32(0);
+                    int32x4_t sumi_3 = vdupq_n_s32(0);
+
+                    for (int k = 0; k < 4; k++) {
+                        int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0);
+                        int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64);
+
+                        uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);
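+                        // iq4_nl nibbles index a 16-entry table of non-linear levels;
+                        // vqtbl1q_s8 maps each 4-bit index to its dequantized int8 value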
+                        int8x16_t b_hi = vqtbl1q_s8(kvalues, b >> 4);
+                        int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF);
+
+                        sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0);
+                        sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1);
+                        sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2);
+                        sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3);
+                        sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0);
+                        sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1);
+                        sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2);
+                        sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3);
+                    }
+
+                    sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0));
+                    sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1));
+                    sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));
+                    sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));
+                }
+
+                for (int m = 0; m < 4; m++) {
+                    vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
+                }
+            }
+        }
+        return;
+    }
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    {
+        float sumf[4][4];
+        int sumi;
+
+        for (int y = 0; y < nr / 4; y++) {
+            const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+            for (int x = 0; x < nc / ncols_interleaved; x++) {
+                const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
+                for (int m = 0; m < 4; m++) {
+                    for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
+                }
+                for (int l = 0; l < nb; l++) {
+                    for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                        for (int m = 0; m < 4; m++) {
+                            for (int j = 0; j < ncols_interleaved; j++) {
+                                sumi = 0;
+                                for (int i = 0; i < blocklen; ++i) {
+                                    const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
+                                    const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
+                                    sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
+                                            (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4]));
+                                }
+                                sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]);
+                            }
+                        }
+                    }
+                }
+                for (int m = 0; m < 4; m++) {
+                    for (int j = 0; j < ncols_interleaved; j++)
+                        s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+                }
+            }
+        }
+    }
+}
+
+// FIXME: this code is duplicated from ggml-aarch64.c
+static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) {
+    block_q4_0x4 out;
+
+    for (int i = 0; i < 4; i++) {
+        out.d[i] = in[i].d;
+    }
+
+    const int end = QK4_0 * 2 / blck_size_interleave;
+
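+    // with blck_size_interleave == 8 the output byte layout is
+    // row0[0..7] row1[0..7] row2[0..7] row3[0..7] row0[8..15] row1[8..15] ...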
+    if (blck_size_interleave == 8) {
+        const uint64_t xor_mask = 0x8888888888888888ULL;
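+        // XORing every nibble with 8 re-centers the unsigned [0,15] encoding so the
+        // SIMD kernels can recover the signed value (pre-scaled by 16) with shifts
+        // alone, without subtracting 8 per element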
+        for (int i = 0; i < end; ++i) {
+            int src_id = i % 4;
+            int src_offset = (i / 4) * blck_size_interleave;
+            int dst_offset = i * blck_size_interleave;
+
+            uint64_t elems;
+            // Using memcpy to avoid unaligned memory accesses
+            memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
+            elems ^= xor_mask;
+            memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
+        }
+    } else if (blck_size_interleave == 4) {
+        const uint32_t xor_mask = 0x88888888;
+        for (int i = 0; i < end; ++i) {
+            int src_id = i % 4;
+            int src_offset = (i / 4) * blck_size_interleave;
+            int dst_offset = i * blck_size_interleave;
+
+            uint32_t elems;
+            memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint32_t));
+            elems ^= xor_mask;
+            memcpy(&out.qs[dst_offset], &elems, sizeof(uint32_t));
+        }
+    } else {
+        GGML_ASSERT(false);
+    }
+
+    return out;
+}
+
+// interleave 8 block_q4_0s in blocks of blck_size_interleave
+// returns an interleaved block_q4_0x8
+// in the interleaved block_q4_0x8, place deltas for 8 block_q4_0 blocks
+// first, then interleave quants from 8 block_q4_0s in blocks of blck_size_interleave
+static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_interleave) {
+    block_q4_0x8 out;
+
+    for (int i = 0; i < 8; i++) {
+        out.d[i] = in[i].d;
+    }
+
+    const int end = QK4_0 * 4 / blck_size_interleave;
+    const uint64_t xor_mask = 0x8888888888888888ULL;
+
+    for (int i = 0; i < end; ++i) {
+        int src_id = i % 8;
+        int src_offset = (i / 8) * blck_size_interleave;
+        int dst_offset = i * blck_size_interleave;
+
+        uint64_t elems;
+        memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
+        elems ^= xor_mask;
+        memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
+    }
+
+    return out;
+}
+
+static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * restrict data, size_t data_size) {
+    GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
+    GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
+
+    block_q4_0x4 * dst = (block_q4_0x4 *)t->data;
+    const block_q4_0 * src = (const block_q4_0 *)data;
+    block_q4_0 dst_tmp[4];
+    int nrow = t->ne[1]; // Number of rows
+    int nrows_interleaved = 4;
+    int nblocks = t->ne[0] / QK4_0;
+
+    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));
+
+    if (nrow % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
+        return -1;
+    }
+
+    for (int b = 0; b < nrow; b += nrows_interleaved) {
+        for (int64_t x = 0; x < nblocks; x++) {
+            for (int i = 0; i < nrows_interleaved; i++) {
+                dst_tmp[i] = src[x + i * nblocks];
+            }
+            *dst++ = make_block_q4_0x4(dst_tmp, interleave_block);
+        }
+        src += nrows_interleaved * nblocks;
+    }
+    return 0;
+
+    GGML_UNUSED(data_size);
+}
+
+static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor *t, int interleave_block, const void * restrict data, size_t data_size) {
+    GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
+    GGML_ASSERT(interleave_block == 8);
+
+    block_q4_0x8 * dst = (block_q4_0x8*)t->data;
+    const block_q4_0 * src = (const block_q4_0*) data;
+    block_q4_0 dst_tmp[8];
+    int nrow = t->ne[1]; // Number of rows
+    int nrows_interleaved = 8;
+    int nblocks = t->ne[0] / QK4_0;
+
+    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));
+
+    if (nrow % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
+        return -1;
+    }
+
+    for (int b = 0; b < nrow; b += nrows_interleaved) {
+        for (int64_t x = 0; x < nblocks; x++) {
+            for (int i = 0; i < nrows_interleaved; i++) {
+                dst_tmp[i] = src[x + i * nblocks];
+            }
+            *dst++ = make_block_q4_0x8(dst_tmp, interleave_block);
+        }
+        src += nrows_interleaved * nblocks;
+    }
+    return 0;
+
+    GGML_UNUSED(data_size);
+}
+
+static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_size_interleave) {
+    block_iq4_nlx4 out;
+
+    for (int i = 0; i < 4; i++) {
+        out.d[i] = in[i].d;
+    }
+
+    const int end = QK4_NL * 2 / blck_size_interleave;
+
+    if (blck_size_interleave == 8) {
+        for (int i = 0; i < end; ++i) {
+            int src_id = i % 4;
+            int src_offset = (i / 4) * blck_size_interleave;
+            int dst_offset = i * blck_size_interleave;
+
+            // Using memcpy to avoid unaligned memory accesses
+            memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t));
+        }
+    } else if (blck_size_interleave == 4) {
+        for (int i = 0; i < end; ++i) {
+            int src_id = i % 4;
+            int src_offset = (i / 4) * blck_size_interleave;
+            int dst_offset = i * blck_size_interleave;
+
+            memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint32_t));
+        }
+    } else {
+        GGML_ASSERT(false);
+    }
+
+    return out;
+}
+
+static int repack_iq4_nl_to_iq4_nl_4_bl(struct ggml_tensor * t, int interleave_block, const void * restrict data, size_t data_size) {
+    GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL);
+    GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
+
+    block_iq4_nlx4 * dst = (block_iq4_nlx4 *)t->data;
+    const block_iq4_nl * src = (const block_iq4_nl *)data;
+    block_iq4_nl dst_tmp[4];
+    int nrow = t->ne[1]; // Number of rows
+    int nrows_interleaved = 4;
+    int nblocks = t->ne[0] / QK4_0;
+
+    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl));
+
+    if (nrow % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
+        return -1;
+    }
+
+    for (int b = 0; b < nrow; b += nrows_interleaved) {
+        for (int64_t x = 0; x < nblocks; x++) {
+            for (int i = 0; i < nrows_interleaved; i++) {
+                dst_tmp[i] = src[x + i * nblocks];
+            }
+            *dst++ = make_block_iq4_nlx4(dst_tmp, interleave_block);
+        }
+        src += nrows_interleaved * nblocks;
+    }
+    return 0;
+
+    GGML_UNUSED(data_size);
+}
+
+// Prepare for optimized kernels if applicable
+void ggml_aarch64_repack_tensor(struct ggml_tensor * cur, enum ggml_type repack_type, const void * restrict data, size_t data_size) {
+    if (cur->type == repack_type) {
+        memcpy(cur->data, data, data_size);
+        return;
+    }
+
+    if (cur->type == GGML_TYPE_Q4_0) {
+        switch (repack_type) {
+            case GGML_TYPE_Q4_0_8_8:
+                repack_q4_0_to_q4_0_8_bl(cur, 8, data, data_size);
+                break;
+            case GGML_TYPE_Q4_0_4_8:
+                repack_q4_0_to_q4_0_4_bl(cur, 8, data, data_size);
+                break;
+            case GGML_TYPE_Q4_0_4_4:
+                repack_q4_0_to_q4_0_4_bl(cur, 4, data, data_size);
+                break;
+            default:
+                GGML_ABORT("Unsupported type");
+        }
+    } else if (cur->type == GGML_TYPE_IQ4_NL) {
+        switch (repack_type) {
+            case GGML_TYPE_IQ4_NL_4_4:
+                repack_iq4_nl_to_iq4_nl_4_bl(cur, 4, data, data_size);
+                break;
+            default:
+                GGML_ABORT("Unsupported type");
+        }
+    } else {
+        GGML_ABORT("Unsupported type");
+    }
+}
+
+enum ggml_type ggml_aarch64_get_optimal_repack_type(const struct ggml_tensor * cur) {
+    if (cur->type == GGML_TYPE_Q4_0) {
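+        // pick the widest interleaved layout the CPU supports: 8x8 needs 256-bit SVE
+        // with int8 matmul (ggml_cpu_get_sve_cnt() presumably reports the vector
+        // length in bytes, so QK8_0 == 32 bytes == 256 bits), 4x8 needs NEON i8mm,
+        // and 4x4 needs the NEON dot product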
+        // TODO: enable for AVX2 - currently disabled due to bad gemv performance
+        if (/* ggml_cpu_has_avx2() || */ (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)) {
+            return GGML_TYPE_Q4_0_8_8;
+        }
+        if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
+            return GGML_TYPE_Q4_0_4_8;
+        }
+        if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
+            return GGML_TYPE_Q4_0_4_4;
+        }
+    } else if (cur->type == GGML_TYPE_IQ4_NL) {
+        if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
+            return GGML_TYPE_IQ4_NL_4_4;
+        }
+    }
+
+    return cur->type;
+}

+ 58 - 0
llama/ggml-cpu-aarch64.h

@@ -0,0 +1,58 @@
+/**
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include "ggml.h"
+
+// GGML internal header
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// Quantization
+void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nrows, int64_t n_per_row, int64_t blck_size_interleave);
+
+// GEMV
+void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+
+// GEMM
+void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+
+void           ggml_aarch64_repack_tensor(struct ggml_tensor * cur, enum ggml_type repack_type, const void * data, size_t data_size);
+enum ggml_type ggml_aarch64_get_optimal_repack_type(const struct ggml_tensor * cur);
+
+#ifdef __cplusplus
+}
+#endif
+

+ 15 - 243
llama/ggml-cpu-impl.h

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *
@@ -41,6 +41,18 @@
 extern "C" {
 #endif
 
+struct ggml_compute_params {
+    // ith = thread index, nth = number of threads
+    int ith, nth;
+
+    // work buffer for all threads
+    size_t wsize;
+    void * wdata;
+
+    struct ggml_threadpool * threadpool;
+};
+
+
 #if defined(_MSC_VER)
 
 #define m512bh(p) p
@@ -53,80 +65,6 @@ extern "C" {
 
 #endif
 
-/**
- * Converts brain16 to float32.
- *
- * The bfloat16 floating point format has the following structure:
- *
- *       ┌sign
- *       │
- *       │   ┌exponent
- *       │   │
- *       │   │      ┌mantissa
- *       │   │      │
- *       │┌──┴───┐┌─┴───┐
- *     0b0000000000000000 brain16
- *
- * Since bf16 has the same number of exponent bits as a 32bit float,
- * encoding and decoding numbers becomes relatively straightforward.
- *
- *       ┌sign
- *       │
- *       │   ┌exponent
- *       │   │
- *       │   │      ┌mantissa
- *       │   │      │
- *       │┌──┴───┐┌─┴───────────────────┐
- *     0b00000000000000000000000000000000 IEEE binary32
- *
- * For comparison, the standard fp16 format has fewer exponent bits.
- *
- *       ┌sign
- *       │
- *       │  ┌exponent
- *       │  │
- *       │  │    ┌mantissa
- *       │  │    │
- *       │┌─┴─┐┌─┴──────┐
- *     0b0000000000000000 IEEE binary16
- *
- * @see IEEE 754-2008
- */
-static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) {
-    union {
-        float f;
-        uint32_t i;
-    } u;
-    u.i = (uint32_t)h.bits << 16;
-    return u.f;
-}
-
-/**
- * Converts float32 to brain16.
- *
- * This is binary identical with Google Brain float conversion.
- * Floats shall round to nearest even, and NANs shall be quiet.
- * Subnormals aren't flushed to zero, except perhaps when used.
- * This code should vectorize nicely if using modern compilers.
- */
-static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
-    ggml_bf16_t h;
-    union {
-        float f;
-        uint32_t i;
-    } u;
-    u.f = s;
-    if ((u.i & 0x7fffffff) > 0x7f800000) { /* nan */
-        h.bits = (u.i >> 16) | 64; /* force to quiet */
-        return h;
-    }
-    h.bits = (u.i + (0x7fff + ((u.i >> 16) & 1))) >> 16;
-    return h;
-}
-
-#define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
-#define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)
-
 // __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
 #if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))
 #ifndef __FMA__
@@ -414,28 +352,6 @@ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b)
 
 #endif // defined(__ARM_NEON)
 
-#if defined(__ARM_NEON) && !defined(_MSC_VER)
-
-#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
-#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
-
-#define GGML_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
-
-static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
-    ggml_fp16_internal_t tmp;
-    memcpy(&tmp, &h, sizeof(ggml_fp16_t));
-    return (float)tmp;
-}
-
-static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
-    ggml_fp16_t res;
-    ggml_fp16_internal_t tmp = f;
-    memcpy(&res, &tmp, sizeof(ggml_fp16_t));
-    return res;
-}
-
-#else
-
 #ifdef __wasm_simd128__
 #include <wasm_simd128.h>
 #else
@@ -488,152 +404,8 @@ static __m256 __lasx_xvreplfr2vr_s(float val) {
 }
 #endif
 
-#ifdef __F16C__
-
-#ifdef _MSC_VER
-#define GGML_COMPUTE_FP16_TO_FP32(x) _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x)))
-#define GGML_COMPUTE_FP32_TO_FP16(x) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0)
-#else
-#define GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x)
-#define GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0)
-#endif
-
-#elif defined(__POWER9_VECTOR__)
-
-#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
-#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
-/* the inline asm below is about 12% faster than the lookup method */
-#define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x)
-#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
-
-static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
-    register float f;
-    register double d;
-    __asm__(
-        "mtfprd %0,%2\n"
-        "xscvhpdp %0,%0\n"
-        "frsp %1,%0\n" :
-        /* temp */ "=d"(d),
-        /* out */  "=f"(f):
-        /* in */   "r"(h));
-    return f;
-}
-
-static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
-    register double d;
-    register ggml_fp16_t r;
-    __asm__( /* xscvdphp can work on double or single precision */
-        "xscvdphp %0,%2\n"
-        "mffprd %1,%0\n" :
-        /* temp */ "=d"(d),
-        /* out */  "=r"(r):
-        /* in */   "f"(f));
-    return r;
-}
-
-#else
-
-// FP16 <-> FP32
-// ref: https://github.com/Maratyszcza/FP16
-
-static inline float fp32_from_bits(uint32_t w) {
-    union {
-        uint32_t as_bits;
-        float as_value;
-    } fp32;
-    fp32.as_bits = w;
-    return fp32.as_value;
-}
-
-static inline uint32_t fp32_to_bits(float f) {
-    union {
-        float as_value;
-        uint32_t as_bits;
-    } fp32;
-    fp32.as_value = f;
-    return fp32.as_bits;
-}
-
-static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
-    const uint32_t w = (uint32_t) h << 16;
-    const uint32_t sign = w & UINT32_C(0x80000000);
-    const uint32_t two_w = w + w;
-
-    const uint32_t exp_offset = UINT32_C(0xE0) << 23;
-#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
-    const float exp_scale = 0x1.0p-112f;
-#else
-    const float exp_scale = fp32_from_bits(UINT32_C(0x7800000));
-#endif
-    const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;
-
-    const uint32_t magic_mask = UINT32_C(126) << 23;
-    const float magic_bias = 0.5f;
-    const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;
-
-    const uint32_t denormalized_cutoff = UINT32_C(1) << 27;
-    const uint32_t result = sign |
-        (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value));
-    return fp32_from_bits(result);
-}
-
-static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
-#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
-    const float scale_to_inf = 0x1.0p+112f;
-    const float scale_to_zero = 0x1.0p-110f;
-#else
-    const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000));
-    const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000));
-#endif
-    float base = (fabsf(f) * scale_to_inf) * scale_to_zero;
-
-    const uint32_t w = fp32_to_bits(f);
-    const uint32_t shl1_w = w + w;
-    const uint32_t sign = w & UINT32_C(0x80000000);
-    uint32_t bias = shl1_w & UINT32_C(0xFF000000);
-    if (bias < UINT32_C(0x71000000)) {
-        bias = UINT32_C(0x71000000);
-    }
-
-    base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
-    const uint32_t bits = fp32_to_bits(base);
-    const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
-    const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
-    const uint32_t nonsign = exp_bits + mantissa_bits;
-    return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign);
-}
-
-#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
-#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
-
-#endif // __F16C__
-
-#endif // defined(__ARM_NEON) && (!defined(__MSC_VER)
-
-#ifdef __ARM_FEATURE_SVE
-#include <arm_sve.h>
-#endif // __ARM_FEATURE_SVE
-
-// precomputed f32 table for f16 (256 KB)
-// defined in ggml.c, initialized in ggml_init()
-extern float ggml_table_f32_f16[1 << 16];
-
-// On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
-// so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON.
-// This is also true for POWER9.
-#if !defined(GGML_FP16_TO_FP32)
-inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
-    uint16_t s;
-    memcpy(&s, &f, sizeof(uint16_t));
-    return ggml_table_f32_f16[s];
-}
-
-#define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x)
-#endif
-
-#if !defined(GGML_FP32_TO_FP16)
-#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
-#endif
+// TODO: move to ggml-threading
+void ggml_barrier(struct ggml_threadpool * tp);
 
 #ifdef __cplusplus
 }

+ 10861 - 0
llama/ggml-cpu-quants.c

@@ -0,0 +1,10861 @@
+/**
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#define GGML_COMMON_IMPL_C
+#include "ggml-common.h"
+
+#include "ggml-quants.h"
+#include "ggml-cpu-quants.h"
+#include "ggml-impl.h"
+#include "ggml-cpu-impl.h"
+#include "ggml-cpu.h"
+
+#include <math.h>
+#include <string.h>
+#include <assert.h>
+#include <float.h>
+#include <stdlib.h> // for qsort
+#include <stdio.h>  // for GGML_ASSERT
+
+#define GROUP_MAX_EPS 1e-15f
+#define GROUP_MAX_EPS_IQ3_XXS 1e-8f
+#define GROUP_MAX_EPS_IQ2_S 1e-8f
+#define GROUP_MAX_EPS_IQ1_M 1e-7f
+#define GROUP_MAX_EPS_IQ1_S 1e-12f
+
+#if defined(_MSC_VER)
+// disable "possible loss of data" to avoid warnings for hundreds of casts
+// we should just be careful :)
+#pragma warning(disable: 4244 4267)
+#endif
+
+#define UNUSED GGML_UNUSED
+
+// some compilers don't provide _mm256_set_m128i, e.g. gcc 7
+#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
+
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
+// multiply int8_t, add results pairwise twice
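+// (_mm_maddubs_epi16 multiplies unsigned by signed bytes, so x is made non-negative
+// and its sign is transferred onto y, leaving each per-element product unchanged)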
+static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
+    // Get absolute values of x vectors
+    const __m128i ax = _mm_sign_epi8(x, x);
+    // Sign the values of the y vectors
+    const __m128i sy = _mm_sign_epi8(y, x);
+    // Perform multiplication and create 16-bit values
+    const __m128i dot = _mm_maddubs_epi16(ax, sy);
+    const __m128i ones = _mm_set1_epi16(1);
+    return _mm_madd_epi16(ones, dot);
+}
+
+#if __AVX__ || __AVX2__ || __AVX512F__
+// horizontally add 8 floats
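+// (fold the upper 128-bit half onto the lower, then reduce 4 -> 2 -> 1 floats with
+// movehl/movehdup shuffles)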
+static inline float hsum_float_8(const __m256 x) {
+    __m128 res = _mm256_extractf128_ps(x, 1);
+    res = _mm_add_ps(res, _mm256_castps256_ps128(x));
+    res = _mm_add_ps(res, _mm_movehl_ps(res, res));
+    res = _mm_add_ss(res, _mm_movehdup_ps(res));
+    return _mm_cvtss_f32(res);
+}
+
+// horizontally add 8 int32_t
+static inline int hsum_i32_8(const __m256i a) {
+    const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
+    const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
+    const __m128i sum64 = _mm_add_epi32(hi64, sum128);
+    const __m128i hi32  = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+}
+
+// horizontally add 4 int32_t
+static inline int hsum_i32_4(const __m128i a) {
+    const __m128i hi64 = _mm_unpackhi_epi64(a, a);
+    const __m128i sum64 = _mm_add_epi32(hi64, a);
+    const __m128i hi32  = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+}
+
+#if defined(__AVX2__) || defined(__AVX512F__)
+// spread 32 bits to 32 bytes { 0x00, 0xFF }
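+// (each source byte is replicated across 8 destination bytes; OR-ing with a mask that
+// clears a different bit per byte leaves a byte all-ones exactly when its tested bit
+// was set, which the compare against -1 turns into 0x00/0xFF)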
+static inline __m256i bytes_from_bits_32(const uint8_t * x) {
+    uint32_t x32;
+    memcpy(&x32, x, sizeof(uint32_t));
+    const __m256i shuf_mask = _mm256_set_epi64x(
+            0x0303030303030303, 0x0202020202020202,
+            0x0101010101010101, 0x0000000000000000);
+    __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask);
+    const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);
+    bytes = _mm256_or_si256(bytes, bit_mask);
+    return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1));
+}
+
+// Unpack 32 4-bit fields into 32 bytes
+// The output vector contains 32 bytes, each one in the [0..15] interval
+static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
+{
+    const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi);
+    const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp);
+    const __m256i lowMask = _mm256_set1_epi8( 0xF );
+    return _mm256_and_si256(lowMask, bytes);
+}
+
+// add int16_t pairwise and return as float vector
+static inline __m256 sum_i16_pairs_float(const __m256i x) {
+    const __m256i ones = _mm256_set1_epi16(1);
+    const __m256i summed_pairs = _mm256_madd_epi16(ones, x);
+    return _mm256_cvtepi32_ps(summed_pairs);
+}
+
+static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
+#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
+    const __m256i zero = _mm256_setzero_si256();
+    const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
+    return _mm256_cvtepi32_ps(summed_pairs);
+#else
+    // Perform multiplication and create 16-bit values
+    const __m256i dot = _mm256_maddubs_epi16(ax, sy);
+    return sum_i16_pairs_float(dot);
+#endif
+}
+
+// multiply int8_t, add results pairwise twice and return as float vector
+static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
+#if __AVXVNNIINT8__
+    const __m256i zero = _mm256_setzero_si256();
+    const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y);
+    return _mm256_cvtepi32_ps(summed_pairs);
+#else
+    // Get absolute values of x vectors
+    const __m256i ax = _mm256_sign_epi8(x, x);
+    // Sign the values of the y vectors
+    const __m256i sy = _mm256_sign_epi8(y, x);
+    return mul_sum_us8_pairs_float(ax, sy);
+#endif
+}
+
+static inline __m128i packNibbles( __m256i bytes )
+{
+    // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
+#if __AVX512F__
+    const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4);   // 0000_0000_abcd_0000
+    bytes = _mm256_or_si256(bytes, bytes_srli_4);               // 0000_abcd_abcd_efgh
+    return _mm256_cvtepi16_epi8(bytes);                         // abcd_efgh
+#else
+    const __m256i lowByte = _mm256_set1_epi16( 0xFF );
+    __m256i high = _mm256_andnot_si256( lowByte, bytes );
+    __m256i low = _mm256_and_si256( lowByte, bytes );
+    high = _mm256_srli_epi16( high, 4 );
+    bytes = _mm256_or_si256( low, high );
+
+    // Compress uint16_t lanes into bytes
+    __m128i r0 = _mm256_castsi256_si128( bytes );
+    __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
+    return _mm_packus_epi16( r0, r1 );
+#endif
+}
+#elif defined(__AVX__)
+static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
+{
+    // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
+    const __m128i lowByte = _mm_set1_epi16( 0xFF );
+    __m128i high = _mm_andnot_si128( lowByte, bytes1 );
+    __m128i low = _mm_and_si128( lowByte, bytes1 );
+    high = _mm_srli_epi16( high, 4 );
+    bytes1 = _mm_or_si128( low, high );
+    high = _mm_andnot_si128( lowByte, bytes2 );
+    low = _mm_and_si128( lowByte, bytes2 );
+    high = _mm_srli_epi16( high, 4 );
+    bytes2 = _mm_or_si128( low, high );
+
+    return _mm_packus_epi16( bytes1, bytes2);
+}
+
+static inline __m128i mul_add_epi8_sse(const __m128i x, const __m128i y) {
+    const __m128i ax = _mm_sign_epi8(x, x);
+    const __m128i sy = _mm_sign_epi8(y, x);
+    return _mm_maddubs_epi16(ax, sy);
+}
+
+// spread 32 bits to 32 bytes { 0x00, 0xFF }
+static inline __m256i bytes_from_bits_32(const uint8_t * x) {
+    uint32_t x32;
+    memcpy(&x32, x, sizeof(uint32_t));
+    const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
+    const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303, 0x0202020202020202);
+    __m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl);
+    __m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh);
+    const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfe);
+    bytesl = _mm_or_si128(bytesl, bit_mask);
+    bytesh = _mm_or_si128(bytesh, bit_mask);
+    bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1));
+    bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1));
+    return MM256_SET_M128I(bytesh, bytesl);
+}
+
+// Unpack 32 4-bit fields into 32 bytes
+// The output vector contains 32 bytes, each one in the [0 .. 15] interval
+static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
+{
+    // Load 16 bytes from memory
+    __m128i tmpl = _mm_loadu_si128((const __m128i *)rsi);
+    __m128i tmph = _mm_srli_epi16(tmpl, 4);
+    const __m128i lowMask = _mm_set1_epi8(0xF);
+    tmpl = _mm_and_si128(lowMask, tmpl);
+    tmph = _mm_and_si128(lowMask, tmph);
+    return MM256_SET_M128I(tmph, tmpl);
+}
+
+// add int16_t pairwise and return as float vector
+static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
+    const __m128i ones = _mm_set1_epi16(1);
+    const __m128i summed_pairsl = _mm_madd_epi16(ones, xl);
+    const __m128i summed_pairsh = _mm_madd_epi16(ones, xh);
+    const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl);
+    return _mm256_cvtepi32_ps(summed_pairs);
+}
+
+static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
+    const __m128i axl = _mm256_castsi256_si128(ax);
+    const __m128i axh = _mm256_extractf128_si256(ax, 1);
+    const __m128i syl = _mm256_castsi256_si128(sy);
+    const __m128i syh = _mm256_extractf128_si256(sy, 1);
+    // Perform multiplication and create 16-bit values
+    const __m128i dotl = _mm_maddubs_epi16(axl, syl);
+    const __m128i doth = _mm_maddubs_epi16(axh, syh);
+    return sum_i16_pairs_float(doth, dotl);
+}
+
+// multiply int8_t, add results pairwise twice and return as float vector
+static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
+    const __m128i xl = _mm256_castsi256_si128(x);
+    const __m128i xh = _mm256_extractf128_si256(x, 1);
+    const __m128i yl = _mm256_castsi256_si128(y);
+    const __m128i yh = _mm256_extractf128_si256(y, 1);
+    // Get absolute values of x vectors
+    const __m128i axl = _mm_sign_epi8(xl, xl);
+    const __m128i axh = _mm_sign_epi8(xh, xh);
+    // Sign the values of the y vectors
+    const __m128i syl = _mm_sign_epi8(yl, xl);
+    const __m128i syh = _mm_sign_epi8(yh, xh);
+    // Perform multiplication and create 16-bit values
+    const __m128i dotl = _mm_maddubs_epi16(axl, syl);
+    const __m128i doth = _mm_maddubs_epi16(axh, syh);
+    return sum_i16_pairs_float(doth, dotl);
+}
+
+// larger version of mul_sum_i8_pairs_float where x and y are each represented by four 128-bit vectors
+static inline __m256 mul_sum_i8_quad_float(const __m128i x_1_0, const __m128i x_1_1, const __m128i x_2_0, const __m128i x_2_1,
+                                           const __m128i y_1_0, const __m128i y_1_1, const __m128i y_2_0, const __m128i y_2_1) {
+    const __m128i mone = _mm_set1_epi16(1);
+
+    const __m128i p16_1_0 = mul_add_epi8_sse(x_1_0, y_1_0);
+    const __m128i p16_1_1 = mul_add_epi8_sse(x_1_1, y_1_1);
+    const __m128i p16_2_0 = mul_add_epi8_sse(x_2_0, y_2_0);
+    const __m128i p16_2_1 = mul_add_epi8_sse(x_2_1, y_2_1);
+    const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, mone);
+    const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, mone);
+    const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, mone);
+    const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, mone);
+    const __m128i p_1 = _mm_add_epi32(p_1_0, p_1_1);
+    const __m128i p_2 = _mm_add_epi32(p_2_0, p_2_1);
+    return _mm256_cvtepi32_ps(MM256_SET_M128I(p_2, p_1));
+}
+
+// quad fp16 delta calculation
+static inline __m256 quad_fp16_delta_float(const float x0, const float y0, const float x1, const float y1) {
+    // GGML_FP16_TO_FP32 is faster than Intel F16C
+    return _mm256_set_m128(_mm_set1_ps(GGML_FP16_TO_FP32(x1) * GGML_FP16_TO_FP32(y1)),
+                           _mm_set1_ps(GGML_FP16_TO_FP32(x0) * GGML_FP16_TO_FP32(y0)));
+}
+#endif
+#elif defined(__SSSE3__)
+// horizontally add 4x4 floats
+static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {
+    __m128 res_0 =_mm_hadd_ps(a, b);
+    __m128 res_1 =_mm_hadd_ps(c, d);
+    __m128 res =_mm_hadd_ps(res_0, res_1);
+    res =_mm_hadd_ps(res, res);
+    res =_mm_hadd_ps(res, res);
+
+    return _mm_cvtss_f32(res);
+}
+#endif // __AVX__ || __AVX2__ || __AVX512F__
+#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
+
+#if defined(__ARM_NEON) || defined(__wasm_simd128__) || defined(__POWER9_VECTOR__)
+#define B1(c,s,n)  0x ## n ## c ,  0x ## n ## s
+#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
+#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)
+#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)
+#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)
+#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)
+#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)
+#define B8(c,s  ) B7(c,s,     c), B7(c,s,     s)
+
+// precomputed tables for expanding 8 bits to 8 bytes:
+static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4
+static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
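+
+// table_b2b_0[i] has byte j equal to 0x10 when bit j of i is set (and 0x00
+// otherwise); table_b2b_1 is its complement. The B1..B8 macros build the 256
+// 64-bit entries by token-pasting one hex digit pair per bit.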
+#endif
+
+#if defined(__loongarch_asx)
+
+#ifdef __clang__
+#define VREGS_PREFIX "$vr"
+#define XREGS_PREFIX "$xr"
+#else // GCC
+#define VREGS_PREFIX "$f"
+#define XREGS_PREFIX "$f"
+#endif
+#define __ALL_REGS "0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31"
+// Convert __m128i to __m256i
+static inline __m256i ____m256i(__m128i in) {
+    __m256i out = __lasx_xvldi(0);
+    __asm__ volatile (
+        ".irp i," __ALL_REGS                "\n\t"
+        " .ifc %[out], " XREGS_PREFIX"\\i    \n\t"
+        "  .irp j," __ALL_REGS              "\n\t"
+        "   .ifc %[in], " VREGS_PREFIX "\\j  \n\t"
+        "    xvpermi.q $xr\\i, $xr\\j, 0x20  \n\t"
+        "   .endif                           \n\t"
+        "  .endr                             \n\t"
+        " .endif                             \n\t"
+        ".endr                               \n\t"
+        : [out] "+f" (out) : [in] "f" (in)
+    );
+    return out;
+}
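+
+// The .irp/.ifc assembler loops above and below iterate over every register
+// name and emit the single instruction whose operands match the registers
+// the compiler actually allocated; this works around LASX <-> LSX conversion
+// intrinsics that are missing from some toolchains.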
+// Convert two __m128i to __m256i
+static inline __m256i lasx_set_q(__m128i inhi, __m128i inlo) {
+    __m256i out;
+    __asm__ volatile (
+        ".irp i," __ALL_REGS                "\n\t"
+        " .ifc %[hi], " VREGS_PREFIX "\\i    \n\t"
+        "  .irp j," __ALL_REGS              "\n\t"
+        "   .ifc %[lo], " VREGS_PREFIX "\\j  \n\t"
+        "    xvpermi.q $xr\\i, $xr\\j, 0x20  \n\t"
+        "   .endif                           \n\t"
+        "  .endr                             \n\t"
+        " .endif                             \n\t"
+        ".endr                               \n\t"
+        ".ifnc %[out], %[hi]                 \n\t"
+        ".irp i," __ALL_REGS                "\n\t"
+        " .ifc %[out], " XREGS_PREFIX "\\i   \n\t"
+        "  .irp j," __ALL_REGS              "\n\t"
+        "   .ifc %[hi], " VREGS_PREFIX "\\j  \n\t"
+        "    xvori.b $xr\\i, $xr\\j, 0       \n\t"
+        "   .endif                           \n\t"
+        "  .endr                             \n\t"
+        " .endif                             \n\t"
+        ".endr                               \n\t"
+        ".endif                              \n\t"
+        : [out] "=f" (out), [hi] "+f" (inhi)
+        : [lo] "f" (inlo)
+    );
+    return out;
+}
+// Convert __m256i low part to __m128i
+static inline __m128i lasx_extracti128_lo(__m256i in) {
+    __m128i out;
+    __asm__ volatile (
+        ".ifnc %[out], %[in]                 \n\t"
+        ".irp i," __ALL_REGS                "\n\t"
+        " .ifc %[out], " VREGS_PREFIX "\\i   \n\t"
+        "  .irp j," __ALL_REGS              "\n\t"
+        "   .ifc %[in], " XREGS_PREFIX "\\j  \n\t"
+        "    vori.b $vr\\i, $vr\\j, 0        \n\t"
+        "   .endif                           \n\t"
+        "  .endr                             \n\t"
+        " .endif                             \n\t"
+        ".endr                               \n\t"
+        ".endif                              \n\t"
+        : [out] "=f" (out) : [in] "f" (in)
+    );
+    return out;
+}
+// Convert __m256i high part to __m128i
+static inline __m128i lasx_extracti128_hi(__m256i in) {
+    __m128i out;
+    __asm__ volatile (
+        ".irp i," __ALL_REGS                "\n\t"
+        " .ifc %[out], " VREGS_PREFIX "\\i   \n\t"
+        "  .irp j," __ALL_REGS              "\n\t"
+        "   .ifc %[in], " XREGS_PREFIX "\\j  \n\t"
+        "    xvpermi.q $xr\\i, $xr\\j, 0x11  \n\t"
+        "   .endif                           \n\t"
+        "  .endr                             \n\t"
+        " .endif                             \n\t"
+        ".endr                               \n\t"
+        : [out] "=f" (out) : [in] "f" (in)
+    );
+    return out;
+}
+
+static __m256i lasx_set_w(int e7, int e6, int e5, int e4, int e3, int e2, int e1, int e0) {
+    v8i32 __ret = {e0, e1, e2, e3, e4, e5, e6, e7};
+    return (__m256i)__ret;
+}
+
+static __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d) {
+    v4i32 __ret = {d, c, b, a};
+    return (__m128i)__ret;
+}
+
+static __m256i lasx_set_d(int64_t a, int64_t b, int64_t c, int64_t d) {
+    v4i64 __ret = {d, c, b, a};
+    return (__m256i)__ret;
+}
+
+static __m256i lasx_insertf128( __m128i x, __m128i y) {
+    return lasx_set_q(x, y);
+}
+
+static __m128i lsx_shuffle_b(__m128i a, __m128i b) {
+    __m128i mask_f, zero, tmp0, tmp2, mask;
+    int f = 0x8f;
+    mask_f = __lsx_vreplgr2vr_b(f);
+    zero = __lsx_vldi(0);
+    tmp0 = __lsx_vand_v(b, mask_f); // keep the low 4 index bits and the sign bit
+    tmp0 = __lsx_vori_b(tmp0, 0x10); // bias indices into 16..31 so vshuf_b selects from 'a'
+    mask = __lsx_vsle_b(zero, tmp0); // all-ones where the sign bit is clear
+    tmp2 = __lsx_vand_v(tmp0, mask); // negative control bytes collapse to index 0, a zero byte
+    return __lsx_vshuf_b(a, zero, tmp2);
+}
+
+static __m256i lasx_shuffle_b(__m256i a, __m256i b) {
+    __m256i mask_f, zero, tmp0, tmp2, mask;
+    int f = 0x8f;
+    mask_f = __lasx_xvreplgr2vr_b(f);
+    zero = __lasx_xvldi(0);
+    tmp0 = __lasx_xvand_v(b, mask_f); // keep the low 4 index bits and the sign bit (per 128-bit lane)
+    tmp0 = __lasx_xvori_b(tmp0, 0x10); // bias indices into 16..31 so xvshuf_b selects from 'a'
+    mask = __lasx_xvsle_b(zero, tmp0); // all-ones where the sign bit is clear
+    tmp2 = __lasx_xvand_v(tmp0, mask); // negative control bytes collapse to index 0, a zero byte
+    return __lasx_xvshuf_b(a, zero, tmp2);
+}
+
+static __m256i lasx_extu8_16(__m128i a) {
+    __m128i zero = __lsx_vldi(0);
+    __m128i vlo = __lsx_vilvl_b(zero, a);
+    __m128i vhi = __lsx_vilvh_b(zero, a);
+    return lasx_set_q(vhi, vlo);
+}
+
+static __m256i lasx_ext8_16(__m128i a) {
+     __m128i sign = __lsx_vslti_b(a, 0);
+     __m128i vlo = __lsx_vilvl_b(sign, a);
+     __m128i vhi = __lsx_vilvh_b(sign, a);
+     return lasx_set_q(vhi, vlo);
+}
+
+static __m256i lasx_ext16_32(__m128i a) {
+    __m256i tmp1 = __lasx_xvldi(0); // zero-init so no lane is read while indeterminate
+    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 0), 0);
+    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 1), 1);
+    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 2), 2);
+    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 3), 3);
+    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 4), 4);
+    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 5), 5);
+    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 6), 6);
+    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 7), 7);
+    return tmp1;
+}
+
+static __m128i lasx_extracti128( __m256i a, int pos) {
+    return pos == 0 ? lasx_extracti128_lo(a) : lasx_extracti128_hi(a);
+}
+
+static __m128 lasx_extractf128( __m256 a, int pos) {
+    return pos == 0 ? (__m128)lasx_extracti128_lo((__m256i)a)
+                    : (__m128)lasx_extracti128_hi((__m256i)a);
+}
+
+static __m128i lsx_hadd_h(__m128i a, __m128i b) {
+    __m128i tmp1 = __lsx_vpickev_h(b, a);
+    __m128i tmp2 = __lsx_vpickod_h(b, a);
+    return __lsx_vadd_h(tmp1, tmp2);
+}
+
+static __m128i lsx_hadd_w(__m128i a, __m128i b) {
+    __m128i tmp1 = __lsx_vpickev_w(b, a);
+    __m128i tmp2 = __lsx_vpickod_w(b, a);
+    return __lsx_vadd_w(tmp1, tmp2);
+}
+
+static __m128 lsx_hadd_s(__m128 a, __m128 b) {
+    __m128 tmp1 = (__m128)__lsx_vpickev_w((__m128i)b, (__m128i)a);
+    __m128 tmp2 = (__m128)__lsx_vpickod_w((__m128i)b, (__m128i)a);
+
+    return __lsx_vfadd_s(tmp1, tmp2);
+}
+
+static __m256i lasx_maddubs_h(__m256i a, __m256i b) {
+    __m256i tmp1, tmp2;
+    tmp1 = __lasx_xvmulwev_h_b(a, b);
+    tmp2 = __lasx_xvmulwod_h_b(a, b);
+    return __lasx_xvsadd_h(tmp1, tmp2);
+}
+
+static __m256i lasx_madd_h(__m256i a, __m256i b) {
+    __m256i tmp1, tmp2;
+    tmp1 = __lasx_xvmulwev_w_h(a, b);
+    tmp2 = __lasx_xvmulwod_w_h(a, b);
+    return __lasx_xvadd_w(tmp1, tmp2);
+}
+
+static __m256i lasx_packs_w(__m256i a, __m256i b) {
+    __m256i tmp, tmp1;
+    tmp = __lasx_xvsat_w(a, 15);
+    tmp1 = __lasx_xvsat_w(b, 15);
+    return __lasx_xvpickev_h(tmp1, tmp);
+}
+
+static __m256i lasx_packs_h(__m256i a, __m256i b) {
+    __m256i tmp, tmp1;
+    tmp = __lasx_xvsat_h(a, 7);
+    tmp1 = __lasx_xvsat_h(b, 7);
+    return __lasx_xvpickev_b(tmp1, tmp);
+}
+
+static __m128i lsx_packs_w(__m128i a, __m128i b) {
+    __m128i tmp, tmp1;
+    tmp = __lsx_vsat_w(a, 15);
+    tmp1 = __lsx_vsat_w(b, 15);
+    return __lsx_vpickev_h(tmp1, tmp);
+}
+
+static __m128i lsx_packs_h(__m128i a, __m128i b) {
+    __m128i tmp, tmp1;
+    tmp = __lsx_vsat_h(a, 7);
+    tmp1 = __lsx_vsat_h(b, 7);
+    return __lsx_vpickev_b(tmp1, tmp);
+}
+
+static __m128i lsx_packus_h(__m128i a, __m128i b) {
+    __m128i tmp, tmp1;
+    tmp = __lsx_vsat_hu(a, 7);
+    tmp1 = __lsx_vsat_hu(b, 7);
+    return __lsx_vpickev_b(tmp1, tmp);
+}
+
+static __m128i lsx_maddubs_h(__m128i a, __m128i b) {
+    __m128i tmp1, tmp2;
+    tmp1 = __lsx_vmulwev_h_b(a, b);
+    tmp2 = __lsx_vmulwod_h_b(a, b);
+    return __lsx_vsadd_h(tmp1, tmp2);
+}
+
+static __m128i lsx_madd_h(__m128i a, __m128i b) {
+    __m128i tmp1, tmp2;
+    tmp1 = __lsx_vmulwev_w_h(a, b);
+    tmp2 = __lsx_vmulwod_w_h(a, b);
+    return __lsx_vadd_w(tmp1, tmp2);
+}
+
+// multiply int8_t, add results pairwise twice
+static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
+    // Get absolute values of x vectors
+    const __m128i ax = __lsx_vsigncov_b(x, x);
+    // Sign the values of the y vectors
+    const __m128i sy = __lsx_vsigncov_b(x, y);
+    // Perform multiplication and create 16-bit values
+    const __m128i dot = lsx_maddubs_h(ax, sy);
+    const __m128i ones = __lsx_vreplgr2vr_h(1);
+    return lsx_madd_h(ones, dot);
+}
+
+// horizontally add 8 floats
+static inline float hsum_float_8(const __m256 x) {
+    __m128 res = lasx_extractf128(x, 1);
+    ft_union tmp;
+    res = __lsx_vfadd_s(res, lasx_extractf128(x, 0));
+    res = __lsx_vfadd_s(res, (__m128)__lsx_vpickod_d((__m128i)res, (__m128i)res));
+    res = __lsx_vfadd_s(res, (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w(res, 1), 0));
+    tmp.i = __lsx_vpickve2gr_w(res, 0);
+    return tmp.f;
+}
+
+// horizontally add 8 int32_t
+static inline int hsum_i32_8(const __m256i a) {
+
+    __m256i tmp1 = __lasx_xvpermi_q(a, a, 0x11);
+    __m256i tmp2 = __lasx_xvpermi_q(a, a, 0x00);
+
+    __m128i  tmp1_128 = lasx_extracti128_lo(tmp1);
+    __m128i  tmp2_128 = lasx_extracti128_lo(tmp2);
+
+    __m128i sum128 = __lsx_vadd_w(tmp1_128, tmp2_128);
+
+    __m128i ev = __lsx_vpickev_w(sum128, sum128);
+    __m128i od = __lsx_vpickod_w(sum128, sum128);
+    __m128i sum64 = __lsx_vadd_w(ev, od);
+
+    int sum64_1, sum64_2;
+    sum64_1 = __lsx_vpickve2gr_w(sum64, 0);
+    sum64_2 = __lsx_vpickve2gr_w(sum64, 1);
+
+    return  sum64_1 + sum64_2;
+}
+
+// horizontally add 4 int32_t
+static inline int hsum_i32_4(const __m128i a) {
+    __m128i ev = __lsx_vpickev_w(a, a);
+    __m128i od = __lsx_vpickod_w(a, a);
+    __m128i sum64 = __lsx_vadd_w(ev, od);
+
+    int sum64_1, sum64_2;
+    sum64_1 = __lsx_vpickve2gr_w(sum64, 0);
+    sum64_2 = __lsx_vpickve2gr_w(sum64, 1);
+
+    return  sum64_1 + sum64_2;
+}
+
+// spread 32 bits to 32 bytes { 0x00, 0xFF }
+static inline __m256i bytes_from_bits_32(const uint8_t * x) {
+
+    uint32_t x32;
+    memcpy(&x32, x, sizeof(uint32_t));
+    const __m256i shuf_mask = lasx_set_d(
+            0x0303030303030303, 0x0202020202020202,
+            0x0101010101010101, 0x0000000000000000);
+
+    __m256i bytes = lasx_shuffle_b(__lasx_xvreplgr2vr_w(x32), shuf_mask);
+    const __m256i bit_mask = __lasx_xvreplgr2vr_d(0x7fbfdfeff7fbfdfe);
+    bytes = __lasx_xvor_v(bytes, bit_mask);
+    return __lasx_xvseq_b(bytes, __lasx_xvreplgr2vr_d(-1));
+}
+
+// Unpack 32 4-bit fields into 32 bytes
+// The output vector contains 32 bytes, each one in the [0 .. 15] interval
+static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) {
+    const __m128i lo = __lsx_vld((const __m128i *)rsi, 0);
+    __m128i hi = __lsx_vsrli_h(lo, 4);
+    return __lasx_xvandi_b(lasx_insertf128(hi, lo), 0xf);
+}
+
+// add int16_t pairwise and return as float vector
+static inline __m256 sum_i16_pairs_float(const __m256i x) {
+    __m256i v = __lasx_xvpackod_h(x, x);
+    __m256i summed_pairs = __lasx_xvaddwev_w_h(x, v);
+    return __lasx_xvffint_s_w(summed_pairs);
+}
+
+static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
+    // Perform multiplication and create 16-bit values
+    const __m256i dot = lasx_maddubs_h(ax, sy);
+    return sum_i16_pairs_float(dot);
+}
+
+// multiply int8_t, add results pairwise twice and return as float vector
+static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
+
+    // Get absolute values of x vectors
+    const __m256i ax = __lasx_xvsigncov_b(x, x);
+    // Sign the values of the y vectors
+    const __m256i sy = __lasx_xvsigncov_b(x, y);
+
+    return mul_sum_us8_pairs_float(ax, sy);
+}
+
+static inline __m128i packNibbles( __m256i bytes ) {
+    // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
+    const __m256i lowByte = __lasx_xvreplgr2vr_h(0xFF);
+     __m256i high = __lasx_xvandn_v(lowByte, bytes);
+    __m256i low = __lasx_xvand_v(lowByte, bytes);
+    high = __lasx_xvsrli_h(high, 4);
+    bytes = __lasx_xvor_v(low, high);
+    // Compress uint16_t lanes into bytes
+    __m128i *r0 = (__m128i *)&bytes;
+    __m256i tmp_h128 = __lasx_xvpermi_q(bytes, bytes, 0x11);
+    __m128i *r1 = (__m128i *)&tmp_h128;
+
+    __m128i zero = __lsx_vldi(0);
+    __m128i tmp, tmp2, tmp3;
+
+    tmp = __lsx_vmax_h(zero, *r0);
+    tmp2 = __lsx_vsat_hu(tmp, 7);
+
+    tmp = __lsx_vmax_h(zero, *r1);
+    tmp3 = __lsx_vsat_hu(tmp, 7);
+    return  __lsx_vpickev_b(tmp3, tmp2);
+}
+#endif // __loongarch_asx
+
+void quantize_row_q4_0(const float * restrict x, void * restrict y, int64_t k) {
+    quantize_row_q4_0_ref(x, y, k);
+}
+
+void quantize_row_q4_1(const float * restrict x, void * restrict y, int64_t k) {
+    quantize_row_q4_1_ref(x, y, k);
+}
+
+void quantize_row_q5_0(const float * restrict x, void * restrict y, int64_t k) {
+    quantize_row_q5_0_ref(x, y, k);
+}
+
+void quantize_row_q5_1(const float * restrict x, void * restrict y, int64_t k) {
+    quantize_row_q5_1_ref(x, y, k);
+}
+
+void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) {
+    assert(QK8_0 == 32);
+    assert(k % QK8_0 == 0);
+    const int nb = k / QK8_0;
+
+    block_q8_0 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+    for (int i = 0; i < nb; i++) {
+        float32x4_t srcv [8];
+        float32x4_t asrcv[8];
+        float32x4_t amaxv[8];
+
+        for (int j = 0; j < 8; j++) srcv[j]  = vld1q_f32(x + i*32 + 4*j);
+        for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);
+
+        for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
+        for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
+        for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);
+
+        const float amax = vmaxvq_f32(amaxv[0]);
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        for (int j = 0; j < 8; j++) {
+            const float32x4_t v  = vmulq_n_f32(srcv[j], id);
+            const int32x4_t   vi = vcvtnq_s32_f32(v);
+
+            y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
+            y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
+            y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
+            y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
+        }
+    }
+#elif defined(__wasm_simd128__)
+    for (int i = 0; i < nb; i++) {
+        v128_t srcv [8];
+        v128_t asrcv[8];
+        v128_t amaxv[8];
+
+        for (int j = 0; j < 8; j++) srcv[j]  = wasm_v128_load(x + i*32 + 4*j);
+        for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);
+
+        for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);
+        for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);
+        for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);
+
+        const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),
+                                   wasm_f32x4_extract_lane(amaxv[0], 1)),
+                               MAX(wasm_f32x4_extract_lane(amaxv[0], 2),
+                                   wasm_f32x4_extract_lane(amaxv[0], 3)));
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        for (int j = 0; j < 8; j++) {
+            const v128_t v  = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));
+            const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);
+
+            y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);
+            y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);
+            y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);
+            y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);
+        }
+    }
+#elif defined(__AVX2__) || defined(__AVX__)
+    for (int i = 0; i < nb; i++) {
+        // Load elements into 4 AVX vectors
+        __m256 v0 = _mm256_loadu_ps( x );
+        __m256 v1 = _mm256_loadu_ps( x + 8 );
+        __m256 v2 = _mm256_loadu_ps( x + 16 );
+        __m256 v3 = _mm256_loadu_ps( x + 24 );
+        x += 32;
+
+        // Compute max(abs(e)) for the block
+        const __m256 signBit = _mm256_set1_ps( -0.0f );
+        __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+
+        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+        max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+        max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+        const float maxScalar = _mm_cvtss_f32( max4 );
+
+        // Quantize these floats
+        const float d = maxScalar / 127.f;
+        y[i].d = GGML_FP32_TO_FP16(d);
+        const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
+        const __m256 mul = _mm256_set1_ps( id );
+
+        // Apply the multiplier
+        v0 = _mm256_mul_ps( v0, mul );
+        v1 = _mm256_mul_ps( v1, mul );
+        v2 = _mm256_mul_ps( v2, mul );
+        v3 = _mm256_mul_ps( v3, mul );
+
+        // Round to nearest integer
+        v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+        v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+        v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+        v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+        // Convert floats to integers
+        __m256i i0 = _mm256_cvtps_epi32( v0 );
+        __m256i i1 = _mm256_cvtps_epi32( v1 );
+        __m256i i2 = _mm256_cvtps_epi32( v2 );
+        __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+#if defined(__AVX2__)
+        // Convert int32 to int16
+        i0 = _mm256_packs_epi32( i0, i1 );	// 0, 1, 2, 3,  8, 9, 10, 11,  4, 5, 6, 7, 12, 13, 14, 15
+        i2 = _mm256_packs_epi32( i2, i3 );	// 16, 17, 18, 19,  24, 25, 26, 27,  20, 21, 22, 23, 28, 29, 30, 31
+                                            // Convert int16 to int8
+        i0 = _mm256_packs_epi16( i0, i2 );	// 0, 1, 2, 3,  8, 9, 10, 11,  16, 17, 18, 19,  24, 25, 26, 27,  4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+
+        // We got our precious signed bytes, but the order is now wrong:
+        // these AVX2 pack instructions process each 16-byte half independently.
+        // The following permute fixes the order.
+        const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+        i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+        _mm256_storeu_si256((__m256i *)y[i].qs, i0);
+#else
+        // AVX lacks some of the integer instructions we need here,
+        // so we split the registers in half and use their SSE counterparts
+        __m128i ni0 = _mm256_castsi256_si128( i0 );
+        __m128i ni1 = _mm256_extractf128_si256( i0, 1);
+        __m128i ni2 = _mm256_castsi256_si128( i1 );
+        __m128i ni3 = _mm256_extractf128_si256( i1, 1);
+        __m128i ni4 = _mm256_castsi256_si128( i2 );
+        __m128i ni5 = _mm256_extractf128_si256( i2, 1);
+        __m128i ni6 = _mm256_castsi256_si128( i3 );
+        __m128i ni7 = _mm256_extractf128_si256( i3, 1);
+
+        // Convert int32 to int16
+        ni0 = _mm_packs_epi32( ni0, ni1 );
+        ni2 = _mm_packs_epi32( ni2, ni3 );
+        ni4 = _mm_packs_epi32( ni4, ni5 );
+        ni6 = _mm_packs_epi32( ni6, ni7 );
+        // Convert int16 to int8
+        ni0 = _mm_packs_epi16( ni0, ni2 );
+        ni4 = _mm_packs_epi16( ni4, ni6 );
+
+        _mm_storeu_si128((__m128i *)(y[i].qs +  0), ni0);
+        _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
+#endif
+    }
+#elif defined(__riscv_v_intrinsic)
+
+    size_t vl = __riscv_vsetvl_e32m4(QK8_0);
+
+    for (int i = 0; i < nb; i++) {
+        // load elements
+        vfloat32m4_t v_x   = __riscv_vle32_v_f32m4(x+i*QK8_0, vl);
+
+        vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
+        vfloat32m1_t tmp   = __riscv_vfmv_v_f_f32m1(0.0f, vl);
+        vfloat32m1_t vmax  = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
+        float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
+
+        // convert to integer
+        vint16m2_t   vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
+        vint8m1_t    vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
+
+        // store result
+        __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
+    }
+
+#elif defined(__POWER9_VECTOR__)
+    for (int i = 0; i < nb; i++) {
+        vector float srcv [8];
+        vector float asrcv[8];
+        vector float amaxv[8];
+        vector signed int vi[8];
+
+        for (int j = 0; j < 8; j++) srcv[j]  = vec_xl(0, x + i*32 + 4*j);
+        for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]);
+
+        for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]);
+        for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]);
+        for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]);
+
+        const float amax = MAX(MAX(vec_extract(amaxv[0], 0),
+                                   vec_extract(amaxv[0], 1)),
+                               MAX(vec_extract(amaxv[0], 2),
+                                   vec_extract(amaxv[0], 3)));
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+        const vector float vid = vec_splats(id);
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        for (int j = 0; j < 8; j++) {
+            const vector float v  = vec_round(vec_mul(srcv[j], vid));
+            vi[j] = vec_cts(v, 0);
+        }
+        vec_xst(vec_pack(vec_pack(vi[0], vi[1]), vec_pack(vi[2], vi[3])),  0, &y[i].qs[0]);
+        vec_xst(vec_pack(vec_pack(vi[4], vi[5]), vec_pack(vi[6], vi[7])), 16, &y[i].qs[0]);
+    }
+
+#elif defined(__loongarch_asx)
+    for (int i = 0; i < nb; i++) {
+        ft_union fi;
+        __m256 v0 = (__m256)__lasx_xvld( x , 0);
+        __m256 v1 = (__m256)__lasx_xvld( x , 32);
+        __m256 v2 = (__m256)__lasx_xvld( x , 64);
+        __m256 v3 = (__m256)__lasx_xvld( x , 96);
+        x += 32;
+
+        // Compute max(abs(e)) for the block
+        const __m256 sign_bit = __lasx_xvreplfr2vr_s( -0.0f );
+        __m256 max_abs = (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v0 );
+        max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v1 ) );
+        max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v2 ) );
+        max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v3 ) );
+
+        __m128 max4 = __lsx_vfmax_s( lasx_extractf128( max_abs, 1 ), lasx_extractf128( max_abs , 0) );
+        max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) );
+        __m128 tmp = max4;
+        max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vinsgr2vr_w(tmp, __lsx_vpickve2gr_w( max4, 1 ), 0 ));
+        fi.i = __lsx_vpickve2gr_w( (__m128i)max4, 0 );
+        const float max_scalar = fi.f;
+
+        // Quantize these floats
+        const float d = max_scalar / 127.f;
+        y[i].d = GGML_FP32_TO_FP16(d);
+        const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f;
+        const __m256 mul = (__m256)__lasx_xvreplfr2vr_s( id );
+
+        // Apply the multiplier
+        v0 = __lasx_xvfmul_s( v0, mul );
+        v1 = __lasx_xvfmul_s( v1, mul );
+        v2 = __lasx_xvfmul_s( v2, mul );
+        v3 = __lasx_xvfmul_s( v3, mul );
+
+        // Round to nearest integer
+        __m256i i0 = __lasx_xvftintrne_w_s( v0 );
+        __m256i i1 = __lasx_xvftintrne_w_s( v1 );
+        __m256i i2 = __lasx_xvftintrne_w_s( v2 );
+        __m256i i3 = __lasx_xvftintrne_w_s( v3 );
+
+        __m128i ni0 = lasx_extracti128( i0, 0 );
+        __m128i ni1 = lasx_extracti128( i0, 1);
+        __m128i ni2 = lasx_extracti128( i1, 0);
+        __m128i ni3 = lasx_extracti128( i1, 1);
+        __m128i ni4 = lasx_extracti128( i2, 0);
+        __m128i ni5 = lasx_extracti128( i2, 1);
+        __m128i ni6 = lasx_extracti128( i3, 0);
+        __m128i ni7 = lasx_extracti128( i3, 1);
+
+        // Convert int32 to int16
+        ni0 = lsx_packs_w( ni0, ni1 );
+        ni2 = lsx_packs_w( ni2, ni3 );
+        ni4 = lsx_packs_w( ni4, ni5 );
+        ni6 = lsx_packs_w( ni6, ni7 );
+        // Convert int16 to int8
+        ni0 = lsx_packs_h( ni0, ni2 );
+        ni4 = lsx_packs_h( ni4, ni6 );
+
+        __lsx_vst(ni0, (__m128i *)(y[i].qs +  0), 0);
+        __lsx_vst(ni4, (__m128i *)(y[i].qs + 16), 0);
+    }
+#else
+    GGML_UNUSED(nb);
+    // scalar
+    quantize_row_q8_0_ref(x, y, k);
+#endif
+}
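+
+// Every vector path above matches the scalar reference: for each block of 32
+// values, d = max_j |x_j| / 127 is stored as fp16 and qs[j] = round(x_j / d)
+// (all zeros when the block is zero), i.e. q_j = round(127 * x_j / max_k |x_k|).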
+
+void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) {
+    assert(k % QK8_1 == 0);
+    const int nb = k / QK8_1;
+
+    block_q8_1 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+    for (int i = 0; i < nb; i++) {
+        float32x4_t srcv [8];
+        float32x4_t asrcv[8];
+        float32x4_t amaxv[8];
+
+        for (int j = 0; j < 8; j++) srcv[j]  = vld1q_f32(x + i*32 + 4*j);
+        for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);
+
+        for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
+        for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
+        for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);
+
+        const float amax = vmaxvq_f32(amaxv[0]);
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        int32x4_t accv = vdupq_n_s32(0);
+
+        for (int j = 0; j < 8; j++) {
+            const float32x4_t v  = vmulq_n_f32(srcv[j], id);
+            const int32x4_t   vi = vcvtnq_s32_f32(v);
+
+            y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
+            y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
+            y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
+            y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
+
+            accv = vaddq_s32(accv, vi);
+        }
+
+        y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv));
+    }
+#elif defined(__wasm_simd128__)
+    for (int i = 0; i < nb; i++) {
+        v128_t srcv [8];
+        v128_t asrcv[8];
+        v128_t amaxv[8];
+
+        for (int j = 0; j < 8; j++) srcv[j]  = wasm_v128_load(x + i*32 + 4*j);
+        for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);
+
+        for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);
+        for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);
+        for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);
+
+        const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),
+                                   wasm_f32x4_extract_lane(amaxv[0], 1)),
+                               MAX(wasm_f32x4_extract_lane(amaxv[0], 2),
+                                   wasm_f32x4_extract_lane(amaxv[0], 3)));
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        v128_t accv = wasm_i32x4_splat(0);
+
+        for (int j = 0; j < 8; j++) {
+            const v128_t v  = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));
+            const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);
+
+            y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);
+            y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);
+            y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);
+            y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);
+
+            accv = wasm_i32x4_add(accv, vi);
+        }
+
+        y[i].s = GGML_FP32_TO_FP16(
+                d * (wasm_i32x4_extract_lane(accv, 0) +
+                     wasm_i32x4_extract_lane(accv, 1) +
+                     wasm_i32x4_extract_lane(accv, 2) +
+                     wasm_i32x4_extract_lane(accv, 3)));
+    }
+#elif defined(__AVX2__) || defined(__AVX__)
+    for (int i = 0; i < nb; i++) {
+        // Load elements into 4 AVX vectors
+        __m256 v0 = _mm256_loadu_ps( x );
+        __m256 v1 = _mm256_loadu_ps( x + 8 );
+        __m256 v2 = _mm256_loadu_ps( x + 16 );
+        __m256 v3 = _mm256_loadu_ps( x + 24 );
+        x += 32;
+
+        // Compute max(abs(e)) for the block
+        const __m256 signBit = _mm256_set1_ps( -0.0f );
+        __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+
+        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+        max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+        max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+        const float max_scalar = _mm_cvtss_f32( max4 );
+
+        // Quantize these floats
+        const float d = max_scalar / 127.f;
+        y[i].d = GGML_FP32_TO_FP16(d);
+        const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f;
+        const __m256 mul = _mm256_set1_ps( id );
+
+        // Apply the multiplier
+        v0 = _mm256_mul_ps( v0, mul );
+        v1 = _mm256_mul_ps( v1, mul );
+        v2 = _mm256_mul_ps( v2, mul );
+        v3 = _mm256_mul_ps( v3, mul );
+
+        // Round to nearest integer
+        v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+        v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+        v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+        v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+        // Convert floats to integers
+        __m256i i0 = _mm256_cvtps_epi32( v0 );
+        __m256i i1 = _mm256_cvtps_epi32( v1 );
+        __m256i i2 = _mm256_cvtps_epi32( v2 );
+        __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+#if defined(__AVX2__)
+        // Compute the sum of the quants and set y[i].s
+        y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))));
+
+        // Convert int32 to int16
+        i0 = _mm256_packs_epi32( i0, i1 );	// 0, 1, 2, 3,  8, 9, 10, 11,  4, 5, 6, 7, 12, 13, 14, 15
+        i2 = _mm256_packs_epi32( i2, i3 );	// 16, 17, 18, 19,  24, 25, 26, 27,  20, 21, 22, 23, 28, 29, 30, 31
+                                            // Convert int16 to int8
+        i0 = _mm256_packs_epi16( i0, i2 );	// 0, 1, 2, 3,  8, 9, 10, 11,  16, 17, 18, 19,  24, 25, 26, 27,  4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+
+        // We got our precious signed bytes, but the order is now wrong:
+        // these AVX2 pack instructions process each 16-byte half independently.
+        // The following permute fixes the order.
+        const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+        i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+        _mm256_storeu_si256((__m256i *)y[i].qs, i0);
+#else
+        // AVX lacks some of the integer instructions we need here,
+        // so we split the registers in half and use their SSE counterparts
+        __m128i ni0 = _mm256_castsi256_si128( i0 );
+        __m128i ni1 = _mm256_extractf128_si256( i0, 1);
+        __m128i ni2 = _mm256_castsi256_si128( i1 );
+        __m128i ni3 = _mm256_extractf128_si256( i1, 1);
+        __m128i ni4 = _mm256_castsi256_si128( i2 );
+        __m128i ni5 = _mm256_extractf128_si256( i2, 1);
+        __m128i ni6 = _mm256_castsi256_si128( i3 );
+        __m128i ni7 = _mm256_extractf128_si256( i3, 1);
+
+        // Compute the sum of the quants and set y[i].s
+        const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3));
+        const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7));
+        y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_4(_mm_add_epi32(s0, s1)));
+
+        // Convert int32 to int16
+        ni0 = _mm_packs_epi32( ni0, ni1 );
+        ni2 = _mm_packs_epi32( ni2, ni3 );
+        ni4 = _mm_packs_epi32( ni4, ni5 );
+        ni6 = _mm_packs_epi32( ni6, ni7 );
+        // Convert int16 to int8
+        ni0 = _mm_packs_epi16( ni0, ni2 );
+        ni4 = _mm_packs_epi16( ni4, ni6 );
+
+        _mm_storeu_si128((__m128i *)(y[i].qs +  0), ni0);
+        _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
+#endif
+    }
+#elif defined(__riscv_v_intrinsic)
+
+    size_t vl = __riscv_vsetvl_e32m4(QK8_1);
+
+    for (int i = 0; i < nb; i++) {
+        // load elements
+        vfloat32m4_t v_x   = __riscv_vle32_v_f32m4(x+i*QK8_1, vl);
+
+        vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
+        vfloat32m1_t tmp   = __riscv_vfmv_v_f_f32m1(0.0, vl);
+        vfloat32m1_t vmax  = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
+        float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
+
+        const float d  = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
+
+        // convert to integer
+        vint16m2_t   vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
+        vint8m1_t    vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
+
+        // store result
+        __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
+
+        // compute sum for y[i].s
+        vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl);
+        vint16m1_t vwrs = __riscv_vwredsum_vs_i8m1_i16m1(vs, tmp2, vl);
+
+        // set y[i].s
+        int sum = __riscv_vmv_x_s_i16m1_i16(vwrs);
+        y[i].s = GGML_FP32_TO_FP16(sum*d);
+    }
+
+#elif defined(__POWER9_VECTOR__)
+    for (int i = 0; i < nb; i++) {
+        vector float srcv [8];
+        vector float asrcv[8];
+        vector float amaxv[8];
+        vector signed int vi[8];
+
+        for (int j = 0; j < 8; j++) srcv[j]  = vec_xl(0, x + i*32 + 4*j);
+        for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]);
+
+        for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]);
+        for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]);
+        for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]);
+
+        const float amax = MAX(MAX(vec_extract(amaxv[0], 0),
+                                   vec_extract(amaxv[0], 1)),
+                               MAX(vec_extract(amaxv[0], 2),
+                                   vec_extract(amaxv[0], 3)));
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+        const vector float vid = vec_splats(id);
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        vector int accv = vec_splats(0);
+
+        for (int j = 0; j < 8; j++) {
+            const vector float v  = vec_round(vec_mul(srcv[j], vid));
+            vi[j] = vec_cts(v, 0);
+
+            accv = vec_add(accv, vi[j]);
+        }
+        vec_xst(vec_pack(vec_pack(vi[0], vi[1]), vec_pack(vi[2], vi[3])),  0, &y[i].qs[0]);
+        vec_xst(vec_pack(vec_pack(vi[4], vi[5]), vec_pack(vi[6], vi[7])), 16, &y[i].qs[0]);
+
+        accv = vec_add(accv, vec_sld(accv, accv, 4));
+        accv = vec_add(accv, vec_sld(accv, accv, 8));
+        y[i].s = GGML_FP32_TO_FP16(d * vec_extract(accv, 0));
+    }
+
+#elif defined(__loongarch_asx)
+    for (int i = 0; i < nb; i++) {
+        ft_union ft;
+        __m256 v0 = (__m256)__lasx_xvld( x , 0 );
+        __m256 v1 = (__m256)__lasx_xvld( x , 32 );
+        __m256 v2 = (__m256)__lasx_xvld( x , 64 );
+        __m256 v3 = (__m256)__lasx_xvld( x , 96 );
+        x += 32;
+
+        // Compute max(abs(e)) for the block
+        const __m256 sign_bit = __lasx_xvreplfr2vr_s( -0.0f );
+        __m256 max_abs = (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v0 );
+        max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v1 ) );
+        max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v2 ) );
+        max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v3 ) );
+
+        __m128 max4 = __lsx_vfmax_s( lasx_extractf128( max_abs, 1 ), lasx_extractf128( max_abs, 0) );
+        max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) );
+        __m128 tmp = max4;
+        max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vextrins_w((__m128i)tmp, (__m128i)max4, 0x10 ));
+        ft.i = __lsx_vpickve2gr_w( (__m128i)max4, 0 );
+        const float max_scalar = ft.f;
+
+        // Quantize these floats
+        const float d = max_scalar / 127.f;
+        y[i].d = GGML_FP32_TO_FP16(d);
+        const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f;
+        const __m256 mul = __lasx_xvreplfr2vr_s( id );
+
+        // Apply the multiplier
+        v0 = __lasx_xvfmul_s( v0, mul );
+        v1 = __lasx_xvfmul_s( v1, mul );
+        v2 = __lasx_xvfmul_s( v2, mul );
+        v3 = __lasx_xvfmul_s( v3, mul );
+
+        // Round to nearest integer
+        __m256i i0 = __lasx_xvftintrne_w_s( v0 );
+        __m256i i1 = __lasx_xvftintrne_w_s( v1 );
+        __m256i i2 = __lasx_xvftintrne_w_s( v2 );
+        __m256i i3 = __lasx_xvftintrne_w_s( v3 );
+
+        __m128i ni0 = lasx_extracti128(i0, 0);
+        __m128i ni1 = lasx_extracti128( i0, 1);
+        __m128i ni2 = lasx_extracti128( i1, 0);
+        __m128i ni3 = lasx_extracti128( i1, 1);
+        __m128i ni4 = lasx_extracti128( i2, 0 );
+        __m128i ni5 = lasx_extracti128( i2, 1);
+        __m128i ni6 = lasx_extracti128( i3, 0);
+        __m128i ni7 = lasx_extracti128( i3, 1);
+
+        // Compute the sum of the quants and set y[i].s
+        const __m128i s0 = __lsx_vadd_w(__lsx_vadd_w(ni0, ni1), __lsx_vadd_w(ni2, ni3));
+        const __m128i s1 = __lsx_vadd_w(__lsx_vadd_w(ni4, ni5), __lsx_vadd_w(ni6, ni7));
+        y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_4(__lsx_vadd_w(s0, s1)));
+
+        // Convert int32 to int16
+        ni0 = lsx_packs_w( ni0, ni1 );
+        ni2 = lsx_packs_w( ni2, ni3 );
+        ni4 = lsx_packs_w( ni4, ni5 );
+        ni6 = lsx_packs_w( ni6, ni7 );
+        // Convert int16 to int8
+        ni0 = lsx_packs_h( ni0, ni2 );
+        ni4 = lsx_packs_h( ni4, ni6 );
+
+        __lsx_vst(ni0, (__m128i *)(y[i].qs +  0), 0);
+        __lsx_vst(ni4, (__m128i *)(y[i].qs + 16), 0);
+    }
+#else
+    GGML_UNUSED(nb);
+    // scalar
+    quantize_row_q8_1_ref(x, y, k);
+#endif
+}
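+
+// q8_1 quantizes exactly like q8_0 but additionally stores s = d * sum(qs)
+// per block, so dot products against formats with a per-block minimum (q4_1,
+// q5_1) can fold the minimum term in as m * s instead of re-summing the quants.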
+
+//
+// 2-6 bit quantization in super-blocks
+//
+
+//
+// ===================== Helper functions
+//
+static inline int nearest_int(float fval) {
+    assert(fabsf(fval) <= 4194303.f);
+    float val = fval + 12582912.f;
+    int i; memcpy(&i, &val, sizeof(int));
+    return (i & 0x007fffff) - 0x00400000;
+}
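+
+// The magic constant above is 1.5 * 2^23: adding it places the sum in
+// [2^23, 2^24), where the 23 mantissa bits hold round-to-nearest(fval) + 2^22,
+// so masking the mantissa and subtracting 0x00400000 (2^22) recovers the
+// rounded integer without a convert instruction. For example, 3.7f + 12582912.f
+// rounds to 12582916.0f, whose mantissa field is 4194308, and
+// 4194308 - 4194304 = 4. The assert bounds |fval| so the trick stays exact.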
+
+static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, int rmse_type,
+        const float * restrict qw) {
+    float max = 0;
+    float amax = 0;
+    for (int i = 0; i < n; ++i) {
+        float ax = fabsf(x[i]);
+        if (ax > amax) { amax = ax; max = x[i]; }
+    }
+    if (amax < GROUP_MAX_EPS) { // all zero
+        for (int i = 0; i < n; ++i) {
+            L[i] = 0;
+        }
+        return 0.f;
+    }
+    float iscale = -nmax / max;
+    if (rmse_type == 0) {
+        for (int i = 0; i < n; ++i) {
+            int l = nearest_int(iscale * x[i]);
+            L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
+        }
+        return 1/iscale;
+    }
+    bool return_early = false;
+    if (rmse_type < 0) {
+        rmse_type = -rmse_type;
+        return_early = true;
+    }
+    float sumlx = 0;
+    float suml2 = 0;
+#ifdef HAVE_BUGGY_APPLE_LINKER
+    // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7
+    for (volatile int i = 0; i < n; ++i) {
+#else
+    for (int i = 0; i < n; ++i) {
+#endif
+        int l = nearest_int(iscale * x[i]);
+        l = MAX(-nmax, MIN(nmax-1, l));
+        L[i] = l + nmax;
+        float w = qw ? qw[i] : rmse_type == 1 ? x[i] * x[i] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[i]) : sqrtf(fabsf(x[i]));
+        sumlx += w*x[i]*l;
+        suml2 += w*l*l;
+    }
+    float scale = suml2 ? sumlx/suml2 : 0.0f;
+    if (return_early) return suml2 > 0 ? 0.5f*(scale + 1/iscale) : 1/iscale;
+    float best = scale * sumlx;
+    for (int is = -9; is <= 9; ++is) {
+        if (is == 0) {
+            continue;
+        }
+        iscale = -(nmax + 0.1f*is) / max;
+        sumlx = suml2 = 0;
+        for (int i = 0; i < n; ++i) {
+            int l = nearest_int(iscale * x[i]);
+            l = MAX(-nmax, MIN(nmax-1, l));
+            float w = qw ? qw[i] : rmse_type == 1 ? x[i] * x[i] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[i]) : sqrtf(fabsf(x[i]));
+            sumlx += w*x[i]*l;
+            suml2 += w*l*l;
+        }
+        if (suml2 > 0 && sumlx*sumlx > best*suml2) {
+            for (int i = 0; i < n; ++i) {
+                int l = nearest_int(iscale * x[i]);
+                L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
+            }
+            scale = sumlx/suml2; best = scale*sumlx;
+        }
+    }
+    return scale;
+}
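+
+// For fixed quants l_i, the scale minimizing sum_i w_i * (x_i - scale*l_i)^2
+// is the closed-form least-squares solution
+//   scale = sum_i(w_i * x_i * l_i) / sum_i(w_i * l_i^2)
+// which is what the sumlx/suml2 accumulators compute; the is-loop then
+// perturbs iscale around the initial guess looking for a rounding with a
+// strictly better weighted error.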
+
+static float make_q3_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, bool do_rmse) {
+    float max = 0;
+    float amax = 0;
+    for (int i = 0; i < n; ++i) {
+        float ax = fabsf(x[i]);
+        if (ax > amax) { amax = ax; max = x[i]; }
+    }
+    if (amax < GROUP_MAX_EPS) { // all zero
+        for (int i = 0; i < n; ++i) { L[i] = 0; }
+        return 0.f;
+    }
+    float iscale = -nmax / max;
+    if (do_rmse) {
+        float sumlx = 0;
+        float suml2 = 0;
+        for (int i = 0; i < n; ++i) {
+            int l = nearest_int(iscale * x[i]);
+            l = MAX(-nmax, MIN(nmax-1, l));
+            L[i] = l;
+            float w = x[i]*x[i];
+            sumlx += w*x[i]*l;
+            suml2 += w*l*l;
+        }
+        for (int itry = 0; itry < 5; ++itry) {
+            int n_changed = 0;
+            for (int i = 0; i < n; ++i) {
+                float w = x[i]*x[i];
+                float slx = sumlx - w*x[i]*L[i];
+                if (slx > 0) {
+                    float sl2 = suml2 - w*L[i]*L[i];
+                    int new_l = nearest_int(x[i] * sl2 / slx);
+                    new_l = MAX(-nmax, MIN(nmax-1, new_l));
+                    if (new_l != L[i]) {
+                        slx += w*x[i]*new_l;
+                        sl2 += w*new_l*new_l;
+                        if (sl2 > 0 && slx*slx*suml2 > sumlx*sumlx*sl2) {
+                            L[i] = new_l; sumlx = slx; suml2 = sl2;
+                            ++n_changed;
+                        }
+                    }
+                }
+            }
+            if (!n_changed) {
+                break;
+            }
+        }
+        for (int i = 0; i < n; ++i) {
+            L[i] += nmax;
+        }
+        return sumlx / suml2;
+    }
+    for (int i = 0; i < n; ++i) {
+        int l = nearest_int(iscale * x[i]);
+        l = MAX(-nmax, MIN(nmax-1, l));
+        L[i] = l + nmax;
+    }
+    return 1/iscale;
+}
+
+static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, float * restrict the_min,
+        int ntry, float alpha) {
+    float min = x[0];
+    float max = x[0];
+    for (int i = 1; i < n; ++i) {
+        if (x[i] < min) min = x[i];
+        if (x[i] > max) max = x[i];
+    }
+    if (max == min) {
+        for (int i = 0; i < n; ++i) L[i] = 0;
+        *the_min = 0;
+        return 0.f;
+    }
+    if (min > 0) min = 0;
+    float iscale = nmax/(max - min);
+    float scale = 1/iscale;
+    for (int itry = 0; itry < ntry; ++itry) {
+        float sumlx = 0; int suml2 = 0;
+        bool did_change = false;
+        for (int i = 0; i < n; ++i) {
+            int l = nearest_int(iscale*(x[i] - min));
+            l = MAX(0, MIN(nmax, l));
+            if (l != L[i]) {
+                L[i] = l;
+                did_change = true;
+            }
+            sumlx += (x[i] - min)*l;
+            suml2 += l*l;
+        }
+        scale = sumlx/suml2;
+        float sum = 0;
+        for (int i = 0; i < n; ++i) {
+            sum += x[i] - scale*L[i];
+        }
+        min = alpha*min + (1 - alpha)*sum/n;
+        if (min > 0) min = 0;
+        iscale = 1/scale;
+        if (!did_change) break;
+    }
+    *the_min = -min;
+    return scale;
+}
+
+static float make_qkx2_quants(int n, int nmax, const float * restrict x, const float * restrict weights,
+        uint8_t * restrict L, float * restrict the_min, uint8_t * restrict Laux,
+        float rmin, float rdelta, int nstep, bool use_mad) {
+    float min = x[0];
+    float max = x[0];
+    float sum_w = weights[0];
+    float sum_x = sum_w * x[0];
+#ifdef HAVE_BUGGY_APPLE_LINKER
+    // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7
+    for (volatile int i = 1; i < n; ++i) {
+#else
+    for (int i = 1; i < n; ++i) {
+#endif
+        if (x[i] < min) min = x[i];
+        if (x[i] > max) max = x[i];
+        float w = weights[i];
+        sum_w += w;
+        sum_x += w * x[i];
+    }
+    if (min > 0) min = 0;
+    if (max == min) {
+        for (int i = 0; i < n; ++i) L[i] = 0;
+        *the_min = -min;
+        return 0.f;
+    }
+    float iscale = nmax/(max - min);
+    float scale = 1/iscale;
+    float best_mad = 0;
+    for (int i = 0; i < n; ++i) {
+        int l = nearest_int(iscale*(x[i] - min));
+        L[i] = MAX(0, MIN(nmax, l));
+        float diff = scale * L[i] + min - x[i];
+        diff = use_mad ? fabsf(diff) : diff * diff;
+        float w = weights[i];
+        best_mad += w * diff;
+    }
+    if (nstep < 1) {
+        *the_min = -min;
+        return scale;
+    }
+    for (int is = 0; is <= nstep; ++is) {
+        iscale = (rmin + rdelta*is + nmax)/(max - min);
+        float sum_l = 0, sum_l2 = 0, sum_xl = 0;
+        for (int i = 0; i < n; ++i) {
+            int l = nearest_int(iscale*(x[i] - min));
+            l = MAX(0, MIN(nmax, l));
+            Laux[i] = l;
+            float w = weights[i];
+            sum_l += w*l;
+            sum_l2 += w*l*l;
+            sum_xl += w*l*x[i];
+        }
+        float D = sum_w * sum_l2 - sum_l * sum_l;
+        if (D > 0) {
+            float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D;
+            float this_min   = (sum_l2 * sum_x - sum_l * sum_xl)/D;
+            if (this_min > 0) {
+                this_min = 0;
+                this_scale = sum_xl / sum_l2;
+            }
+            float mad = 0;
+            for (int i = 0; i < n; ++i) {
+                float diff = this_scale * Laux[i] + this_min - x[i];
+                diff = use_mad ? fabsf(diff) : diff * diff;
+                float w = weights[i];
+                mad += w * diff;
+            }
+            if (mad < best_mad) {
+                for (int i = 0; i < n; ++i) {
+                    L[i] = Laux[i];
+                }
+                best_mad = mad;
+                scale = this_scale;
+                min = this_min;
+            }
+        }
+    }
+    *the_min = -min;
+    return scale;
+}
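+
+// The grid search above fits (scale, min) by weighted least squares: for
+// each candidate iscale it quantizes x to levels l_i and solves the 2x2
+// normal equations of  minimize sum_i w_i (scale*l_i + min - x_i)^2:
+//
+//     scale = (sum_w*sum_xl - sum_x*sum_l) / D
+//     min   = (sum_l2*sum_x - sum_l*sum_xl) / D,  D = sum_w*sum_l2 - sum_l^2
+//
+// keeping the candidate with the smallest weighted error (absolute or
+// squared, per use_mad). A positive min is clamped to 0, with the scale
+// refit as sum_xl/sum_l2, because the formats store -min as a non-negative
+// value.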
+
+static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
+    if (j < 4) {
+        *d = q[j] & 63; *m = q[j + 4] & 63;
+    } else {
+        *d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
+        *m = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);
+    }
+}
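+
+// Illustrative inverse of the unpacking above (a sketch, not part of the
+// vendored sources; the helper name is hypothetical): q4_K/q5_K pack
+// 8 six-bit scales and 8 six-bit mins into 12 bytes. Bytes 0..3 hold
+// scales 0..3 in their low 6 bits, bytes 4..7 hold mins 0..3 likewise,
+// their top 2 bits carry the high bits of scales/mins 4..7, and the low
+// 4 bits of scales/mins 4..7 share bytes 8..11 as nibble pairs.
+static inline void pack_scale_min_k4(const uint8_t * restrict d, const uint8_t * restrict m, uint8_t * restrict q) {
+    for (int j = 0; j < 4; ++j) {
+        q[j    ] = (d[j] & 63)      | ((d[j + 4] >> 4) << 6);
+        q[j + 4] = (m[j] & 63)      | ((m[j + 4] >> 4) << 6);
+        q[j + 8] = (d[j + 4] & 0xF) | ((m[j + 4] & 0xF) << 4);
+    }
+}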
+
+//========================= 2-bit (de)-quantization
+
+void quantize_row_q2_K(const float * restrict x, void * restrict vy, int64_t k) {
+    quantize_row_q2_K_ref(x, vy, k);
+}
+
+//========================= 3-bit (de)-quantization
+
+void quantize_row_q3_K(const float * restrict x, void * restrict vy, int64_t k) {
+    quantize_row_q3_K_ref(x, vy, k);
+}
+
+// ====================== 4-bit (de)-quantization
+
+void quantize_row_q4_K(const float * restrict x, void * restrict vy, int64_t k) {
+    assert(k % QK_K == 0);
+    block_q4_K * restrict y = vy;
+    quantize_row_q4_K_ref(x, y, k);
+}
+
+// ====================== 5-bit (de)-quantization
+
+void quantize_row_q5_K(const float * restrict x, void * restrict vy, int64_t k) {
+    assert(k % QK_K == 0);
+    block_q5_K * restrict y = vy;
+    quantize_row_q5_K_ref(x, y, k);
+}
+
+// ====================== 6-bit (de)-quantization
+
+void quantize_row_q6_K(const float * restrict x, void * restrict vy, int64_t k) {
+    assert(k % QK_K == 0);
+    block_q6_K * restrict y = vy;
+    quantize_row_q6_K_ref(x, y, k);
+}
+
+// ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs)
+
+void quantize_row_tq1_0(const float * restrict x, void * restrict vy, int64_t k) {
+    assert(k % QK_K == 0);
+    block_tq1_0 * restrict y = vy;
+    quantize_row_tq1_0_ref(x, y, k);
+}
+
+void quantize_row_tq2_0(const float * restrict x, void * restrict vy, int64_t k) {
+    assert(k % QK_K == 0);
+    block_tq2_0 * restrict y = vy;
+    quantize_row_tq2_0_ref(x, y, k);
+}
+
+// non-uniform codebook used by the IQ4_NL 4-bit format: 16 signed levels,
+// spaced more densely near zero where weight values concentrate
+static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
+
+//===================================== Q8_K ==============================================
+
+void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) {
+    quantize_row_q8_K_ref(x, y, k);
+}
+
+//===================================== Dot products =================================
+
+//
+// Helper functions
+//
+#if __AVX__ || __AVX2__ || __AVX512F__
+
+// shuffles to pick the required scales in dot products
+static inline __m256i get_scale_shuffle_q3k(int i) {
+    static const uint8_t k_shuffle[128] = {
+         0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,     2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
+         4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,     6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
+         8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,    10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
+        12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,    14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,
+    };
+    return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
+}
+static inline __m256i get_scale_shuffle_k4(int i) {
+    static const uint8_t k_shuffle[256] = {
+         0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
+         2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
+         4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,
+         6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
+         8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,
+        10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
+        12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,
+        14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15
+    };
+    return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
+}
+static inline __m128i get_scale_shuffle(int i) {
+    static const uint8_t k_shuffle[128] = {
+         0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
+         2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
+         4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5,
+         6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7,
+         8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9,
+        10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11,
+        12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13,
+        14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15
+    };
+    return _mm_loadu_si128((const __m128i*)k_shuffle + i);
+}
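+
+#if defined(__AVX2__)
+// Sketch of how these tables are consumed (illustrative, not part of the
+// vendored sources; the helper name is hypothetical): with the same 16
+// scale bytes replicated into both 128-bit lanes, row i of k_shuffle
+// broadcasts the i-th 16-bit scale pair across the whole register, ready
+// to be multiplied against 16-bit dot-product partials.
+static inline __m256i broadcast_scale_pair_k4(__m128i sc128, int i) {
+    const __m256i scales = MM256_SET_M128I(sc128, sc128);
+    return _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(i));
+}
+#endif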
+#elif defined(__loongarch_asx)
+// shuffles to pick the required scales in dot products
+static inline __m256i get_scale_shuffle_q3k(int i) {
+    static const uint8_t k_shuffle[128] = {
+         0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,     2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
+         4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,     6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
+         8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,    10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
+        12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,    14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,
+    };
+    return __lasx_xvld((const __m256i*)k_shuffle + i, 0);
+}
+static inline __m256i get_scale_shuffle_k4(int i) {
+    static const uint8_t k_shuffle[256] = {
+         0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
+         2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
+         4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,
+         6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
+         8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,
+        10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
+        12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,
+        14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15
+    };
+    return __lasx_xvld((const __m256i*)k_shuffle + i, 0);
+}
+static inline __m128i get_scale_shuffle(int i) {
+    static const uint8_t k_shuffle[128] = {
+         0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
+         2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
+         4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5,
+         6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7,
+         8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9,
+        10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11,
+        12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13,
+        14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15
+    };
+    return __lsx_vld((const __m128i*)k_shuffle + i, 0);
+}
+#endif
+
+void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+
+    assert(n % qk == 0);
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    assert((nrc == 2) || (nrc == 1));
+#else
+    assert(nrc == 1);
+#endif
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_q4_0 * restrict x = vx;
+    const block_q8_0 * restrict y = vy;
+
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    if (nrc == 2) {
+        const block_q4_0 * restrict vx0 = vx;
+        const block_q4_0 * restrict vx1 = (const block_q4_0 *) ((const uint8_t*)vx + bx);
+        const block_q8_0 * restrict vy0 = vy;
+        const block_q8_0 * restrict vy1 = (const block_q8_0 *) ((const uint8_t*)vy + by);
+
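+        // nrc == 2 produces a 2x2 tile of results: vmmlaq_s32 accumulates
+        // the product of a 2x8 with an 8x2 int8 tile into a 2x2 int32 tile,
+        // so interleaving the two x rows and two y columns below yields all
+        // four row/column dot products at once; the vext/vzip at the end
+        // untangles them into s[0..1] and s[bs..bs+1]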
+        float32x4_t sumv0 = vdupq_n_f32(0.0f);
+
+        for (int i = 0; i < nb; i++) {
+            const block_q4_0 * restrict b_x0 = &vx0[i];
+            const block_q4_0 * restrict b_x1 = &vx1[i];
+            const block_q8_0 * restrict b_y0 = &vy0[i];
+            const block_q8_0 * restrict b_y1 = &vy1[i];
+
+            const uint8x16_t m4b = vdupq_n_u8(0x0F);
+            const int8x16_t  s8b = vdupq_n_s8(0x8);
+
+            const uint8x16_t v0_0 = vld1q_u8(b_x0->qs);
+            const uint8x16_t v0_1 = vld1q_u8(b_x1->qs);
+
+            // 4-bit -> 8-bit
+            const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
+            const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+            const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
+            const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+            // sub 8
+            const int8x16_t x0_l = vsubq_s8(v0_0l, s8b);
+            const int8x16_t x0_h = vsubq_s8(v0_0h, s8b);
+            const int8x16_t x1_l = vsubq_s8(v0_1l, s8b);
+            const int8x16_t x1_h = vsubq_s8(v0_1h, s8b);
+
+            // load y
+            const int8x16_t y0_l = vld1q_s8(b_y0->qs);
+            const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
+            const int8x16_t y1_l = vld1q_s8(b_y1->qs);
+            const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
+
+            float32_t _scale[4] = {
+                GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
+                GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
+                GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
+                GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)
+            };
+            float32x4_t scale = vld1q_f32(_scale);
+
+            int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
+            int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
+
+            int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
+            int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
+
+            int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
+            int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
+
+            int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
+            int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
+
+            sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
+                                                l1, r1)), l2, r2)), l3, r3))), scale);
+        }
+
+        float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);
+        float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
+
+        vst1_f32(s,      vget_low_f32 (sumv2));
+        vst1_f32(s + bs, vget_high_f32(sumv2));
+
+        return;
+    }
+#endif
+
+    int ib = 0;
+    float sumf = 0;
+
+#if defined(__ARM_FEATURE_SVE)
+    svfloat32_t sumv0 = svdup_n_f32(0.0f);
+    svfloat32_t sumv1 = svdup_n_f32(0.0f);
+
+    const int vector_length = ggml_cpu_get_sve_cnt()*8;
+
+    // VLA implementation: dispatch on the runtime SVE vector length
+    switch (vector_length) {
+        case 128:
+            {
+                // predicate for activating higher lanes for 4 float32 elements
+                const svbool_t ph4 = svptrue_pat_b32(SV_VL4);
+
+                for (; ib + 1 < nb; ib += 2) {
+                    const block_q4_0 * restrict x0 = &x[ib + 0];
+                    const block_q4_0 * restrict x1 = &x[ib + 1];
+                    const block_q8_0 * restrict y0 = &y[ib + 0];
+                    const block_q8_0 * restrict y1 = &y[ib + 1];
+
+                    // load x
+                    const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
+                    const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
+
+                    // 4-bit -> 8-bit
+                    const svint8_t qx0l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(), qx0r, 0x0F));
+                    const svint8_t qx0h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx0r, 0x04));
+                    const svint8_t qx1l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(), qx1r, 0x0F));
+                    const svint8_t qx1h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx1r, 0x04));
+
+                    // sub 8
+                    const svint8_t qx0ls = svsub_n_s8_x(svptrue_b8(), qx0h, 8);
+                    const svint8_t qx0hs = svsub_n_s8_x(svptrue_b8(), qx0l, 8);
+                    const svint8_t qx1ls = svsub_n_s8_x(svptrue_b8(), qx1h, 8);
+                    const svint8_t qx1hs = svsub_n_s8_x(svptrue_b8(), qx1l, 8);
+
+                    // load y
+                    const svint8_t qy0h = svld1_s8(svptrue_b8(), y0->qs);
+                    const svint8_t qy0l = svld1_s8(svptrue_b8(), y0->qs + 16);
+                    const svint8_t qy1h = svld1_s8(svptrue_b8(), y1->qs);
+                    const svint8_t qy1l = svld1_s8(svptrue_b8(), y1->qs + 16);
+
+                    // dot product
+                    sumv0 = svmla_n_f32_x(ph4, sumv0, svcvt_f32_s32_x(ph4, svadd_x(ph4,
+                                    svdot_s32(svdup_n_s32(0), qx0ls, qy0l),
+                                    svdot_s32(svdup_n_s32(0), qx0hs, qy0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+                    sumv1 = svmla_n_f32_x(ph4, sumv1, svcvt_f32_s32_x(ph4, svadd_x(ph4,
+                                    svdot_s32(svdup_n_s32(0), qx1ls, qy1l),
+                                    svdot_s32(svdup_n_s32(0), qx1hs, qy1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+                }
+
+                sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
+            } break;
+        case 256:
+            {
+                // predicate for activating higher lanes for 16 int8 elements
+                const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
+                // predicate for activating lower lanes for 16 int8 elements
+                const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16);
+
+                for (; ib + 1 < nb; ib += 2) {
+                    const block_q4_0 * restrict x0 = &x[ib + 0];
+                    const block_q4_0 * restrict x1 = &x[ib + 1];
+                    const block_q8_0 * restrict y0 = &y[ib + 0];
+                    const block_q8_0 * restrict y1 = &y[ib + 1];
+
+                    // load x
+                    const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
+                    const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
+
+                    // 4-bit -> 8-bit
+                    const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04));
+                    const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04));
+
+                    // sub 8
+                    const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8);
+                    const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8);
+
+                    // load y
+                    const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
+                    const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
+
+                    // dot product
+                    sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(),
+                                svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+                    sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(),
+                                svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+                }
+
+                sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
+            } break;
+        case 512:
+            {
+                // predicate for activating higher lanes for 32 int8 elements
+                const svbool_t ph32 = svptrue_pat_b8(SV_VL32);
+
+                // predicate for activating higher lanes for 16 int8 elements
+                const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
+                // predicate for activating lower lanes for 16 int8 elements out of the first 32 activated int8 lanes
+                const svbool_t pl16 = svnot_b_z(ph32, ph16);
+
+                for (; ib + 1 < nb; ib += 2) {
+                    const block_q4_0 * restrict x0 = &x[ib + 0];
+                    const block_q4_0 * restrict x1 = &x[ib + 1];
+                    const block_q8_0 * restrict y0 = &y[ib + 0];
+                    const block_q8_0 * restrict y1 = &y[ib + 1];
+
+                    // load x
+                    const svuint8_t qx0r = svld1rq_u8(ph32, x0->qs);
+                    const svuint8_t qx1r = svld1rq_u8(ph32, x1->qs);
+
+                    // 4-bit -> 8-bit
+                    const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04));
+                    const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04));
+
+                    // sub 8
+                    const svint8_t qx0s = svsub_n_s8_x(ph32, qx0, 8);
+                    const svint8_t qx1s = svsub_n_s8_x(ph32, qx1, 8);
+
+                    // load y
+                    const svint8_t qy0 = svld1_s8(ph32, y0->qs);
+                    const svint8_t qy1 = svld1_s8(ph32, y1->qs);
+
+                    // dot product
+                    sumv0 = svmla_n_f32_x(ph32, sumv0, svcvt_f32_s32_x(ph32,
+                                svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+                    sumv1 = svmla_n_f32_x(ph32, sumv1, svcvt_f32_s32_x(ph32,
+                                svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+                }
+
+                sumf = svaddv_f32(ph32, svadd_f32_x(ph32, sumv0, sumv1));
+            } break;
+        default:
+            assert(false && "Unsupported vector length");
+            break;
+    }
+
+#elif defined(__ARM_NEON)
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
+
+    for (; ib + 1 < nb; ib += 2) {
+        const block_q4_0 * restrict x0 = &x[ib + 0];
+        const block_q4_0 * restrict x1 = &x[ib + 1];
+        const block_q8_0 * restrict y0 = &y[ib + 0];
+        const block_q8_0 * restrict y1 = &y[ib + 1];
+
+        const uint8x16_t m4b = vdupq_n_u8(0x0F);
+        const int8x16_t  s8b = vdupq_n_s8(0x8);
+
+        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+
+        // 4-bit -> 8-bit
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
+        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
+        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+        // sub 8
+        const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
+        const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
+        const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
+        const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
+
+        // load y
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+        // dot product into int32x4_t
+        const int32x4_t p_0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
+        const int32x4_t p_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+    }
+
+    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (; ib < nb; ++ib) {
+        /* Compute combined scale for the block */
+        const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) );
+
+        __m256i qx = bytes_from_nibbles_32(x[ib].qs);
+
+        // Now we have a vector with bytes in the [ 0 .. 15 ] interval. Offset them into the [ -8 .. +7 ] interval.
+        const __m256i off = _mm256_set1_epi8( 8 );
+        qx = _mm256_sub_epi8( qx, off );
+
+        __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs);
+
+        const __m256 q = mul_sum_i8_pairs_float(qx, qy);
+
+        /* Multiply q with scale and accumulate */
+        acc = _mm256_fmadd_ps( d, q, acc );
+    }
+
+    sumf = hsum_float_8(acc);
+#elif defined(__AVX__)
+    __m256 accum = _mm256_setzero_ps();
+    for (; ib + 1 < nb; ib += 2) {
+        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs);
+        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
+        const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs);
+        const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1);
+        const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
+        const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);
+
+        const __m128i q4b_1_0 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), q4bits_1), _mm_set1_epi8(8));
+        const __m128i q4b_1_1 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(q4bits_1, 4)), _mm_set1_epi8(8));
+        const __m128i q4b_2_0 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), q4bits_2), _mm_set1_epi8(8));
+        const __m128i q4b_2_1 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(q4bits_2, 4)), _mm_set1_epi8(8));
+
+        const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0);
+        const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1);
+        const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0);
+        const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1);
+        const __m128i p_1 = _mm_add_epi16(p16_1_0, p16_1_1);
+        const __m128i p_2 = _mm_add_epi16(p16_2_0, p16_2_1);
+        const __m256 p =  sum_i16_pairs_float(p_2, p_1);
+
+        const __m256 deltas = quad_fp16_delta_float(x[ib].d, y[ib].d, x[ib + 1].d, y[ib + 1].d);
+        accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum);
+    }
+
+    sumf = hsum_float_8(accum);
+#elif defined(__SSSE3__)
+    // set constants
+    const __m128i lowMask = _mm_set1_epi8(0xF);
+    const __m128i off = _mm_set1_epi8(8);
+
+    // Initialize accumulator with zeros
+    __m128 acc_0 = _mm_setzero_ps();
+    __m128 acc_1 = _mm_setzero_ps();
+    __m128 acc_2 = _mm_setzero_ps();
+    __m128 acc_3 = _mm_setzero_ps();
+
+    for (; ib + 1 < nb; ib += 2) {
+        _mm_prefetch(&x[ib] + sizeof(block_q4_0), _MM_HINT_T0);
+        _mm_prefetch(&y[ib] + sizeof(block_q8_0), _MM_HINT_T0);
+
+        // Compute the combined scale for blocks 0 and 1
+        const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) );
+
+        const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[ib].qs);
+
+        __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
+        __m128i by_0 = _mm_loadu_si128((const __m128i *)y[ib].qs);
+        bx_0 = _mm_sub_epi8(bx_0, off);
+        const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
+
+        __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
+        __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[ib].qs + 16));
+        bx_1 = _mm_sub_epi8(bx_1, off);
+        const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
+
+        _mm_prefetch(&x[ib] + 2 * sizeof(block_q4_0), _MM_HINT_T0);
+        _mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
+
+        // Compute the combined scale for blocks 2 and 3
+        const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[ib + 1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) );
+
+        const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
+
+        __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
+        __m128i by_2 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
+        bx_2 = _mm_sub_epi8(bx_2, off);
+        const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
+
+        __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
+        __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[ib + 1].qs + 16));
+        bx_3 = _mm_sub_epi8(bx_3, off);
+        const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
+
+        // Convert int32_t to float
+        __m128 p0 = _mm_cvtepi32_ps(i32_0);
+        __m128 p1 = _mm_cvtepi32_ps(i32_1);
+        __m128 p2 = _mm_cvtepi32_ps(i32_2);
+        __m128 p3 = _mm_cvtepi32_ps(i32_3);
+
+        // Apply the scale
+        __m128 p0_d = _mm_mul_ps( d_0_1, p0 );
+        __m128 p1_d = _mm_mul_ps( d_0_1, p1 );
+        __m128 p2_d = _mm_mul_ps( d_2_3, p2 );
+        __m128 p3_d = _mm_mul_ps( d_2_3, p3 );
+
+        // Accumulate
+        acc_0 = _mm_add_ps(p0_d, acc_0);
+        acc_1 = _mm_add_ps(p1_d, acc_1);
+        acc_2 = _mm_add_ps(p2_d, acc_2);
+        acc_3 = _mm_add_ps(p3_d, acc_3);
+    }
+
+    sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
+#elif defined(__riscv_v_intrinsic)
+    size_t vl = __riscv_vsetvl_e8m1(qk/2);
+
+    for (; ib < nb; ++ib) {
+        // load elements
+        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl);
+
+        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl);
+        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl);
+
+        // mask and store lower part of x, and then upper part
+        vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
+        vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
+
+        vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
+        vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
+
+        // subtract offset
+        vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 8, vl);
+        vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 8, vl);
+
+        vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
+        vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
+
+        vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
+
+        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
+        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
+
+        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
+
+        sumf += sumi*GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d);
+    }
+
+#elif defined(__POWER9_VECTOR__)
+    const vector signed char lowMask = vec_splats((signed char)0xF);
+    const vector signed int v0 = vec_splats((int32_t)0);
+    const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+    const vector signed char v8 = vec_splats((signed char)0x8);
+
+    vector float vsumf0 = vec_splats(0.0f);
+
+#pragma GCC unroll 8
+    for (; ib < nb; ++ib) {
+        __builtin_prefetch(x[ib].qs, 0, 1);
+        __builtin_prefetch(y[ib].qs, 0, 1);
+
+        vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d));
+        vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d));
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs);
+        vector signed char q8y0 = vec_xl( 0, y[ib].qs);
+        vector signed char q8y1 = vec_xl(16, y[ib].qs);
+
+        vector signed char q4x0 = vec_and(qxs, lowMask);
+        vector signed char q4x1 = vec_sr(qxs, v4);
+
+        q4x0 = vec_sub(q4x0, v8);
+        q4x1 = vec_sub(q4x1, v8);
+
+        vector signed short qv0 = vec_add(vec_mule(q4x0, q8y0), vec_mulo(q4x0, q8y0));
+        vector signed short qv1 = vec_add(vec_mule(q4x1, q8y1), vec_mulo(q4x1, q8y1));
+
+        vector signed int vsumi0 = v0;
+
+        vsumi0 = vec_sum4s(qv0, vsumi0);
+        vsumi0 = vec_sum4s(qv1, vsumi0);
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+    }
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    sumf = vec_extract(vsumf0, 0);
+
+#elif defined(__loongarch_asx)
+    // Initialize accumulator with zeros
+    __m256 acc = (__m256)__lasx_xvldi(0);
+
+    // Main loop
+    for (; ib < nb; ++ib) {
+        /* Compute combined scale for the block */
+        const __m256 d = __lasx_xvreplfr2vr_s( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) );
+
+        __m256i qx = bytes_from_nibbles_32(x[ib].qs);
+
+        // Now we have a vector with bytes in the [ 0 .. 15 ] interval. Offset them into the [ -8 .. +7 ] interval.
+        const __m256i off = __lasx_xvreplgr2vr_b( 8 );
+        qx = __lasx_xvsub_b( qx, off );
+
+        __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0);
+
+        const __m256 q = mul_sum_i8_pairs_float(qx, qy);
+
+        /* Multiply q with scale and accumulate */
+        acc = __lasx_xvfmadd_s( d, q, acc );
+    }
+
+    sumf = hsum_float_8(acc);
+#elif defined(__loongarch_sx)
+    // set constants
+    const __m128i low_mask = __lsx_vreplgr2vr_b(0xF);
+    const __m128i off = __lsx_vreplgr2vr_b(8);
+
+    // Initialize accumulator with zeros
+    __m128 acc_0 = __lsx_vldi(0);
+    __m128 acc_1 = __lsx_vldi(0);
+    __m128 acc_2 = __lsx_vldi(0);
+    __m128 acc_3 = __lsx_vldi(0);
+
+    for (; ib + 1 < nb; ib += 2) {
+
+        // Compute the combined scale for blocks 0 and 1
+        const __m128 d_0_1 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) );
+
+        const __m128i tmp_0_1 = __lsx_vld((const __m128i *)x[ib].qs, 0);
+
+        __m128i bx_0 = __lsx_vand_v(low_mask, tmp_0_1);
+        __m128i by_0 = __lsx_vld((const __m128i *)y[ib].qs, 0);
+        bx_0 = __lsx_vsub_b(bx_0, off);
+        const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
+
+        __m128i bx_1 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_0_1, 4));
+        __m128i by_1 = __lsx_vld((const __m128i *)(y[ib].qs + 16), 0);
+        bx_1 = __lsx_vsub_b(bx_1, off);
+        const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
+
+        //_mm_prefetch(&x[ib] + 2 * sizeof(block_q4_0), _MM_HINT_T0);
+        //_mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
+
+        // Compute the combined scale for blocks 2 and 3
+        const __m128 d_2_3 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib + 1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) );
+
+        const __m128i tmp_2_3 = __lsx_vld((const __m128i *)x[ib + 1].qs, 0);
+
+        __m128i bx_2 = __lsx_vand_v(low_mask, tmp_2_3);
+        __m128i by_2 = __lsx_vld((const __m128i *)y[ib + 1].qs, 0);
+        bx_2 = __lsx_vsub_b(bx_2, off);
+        const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
+
+        __m128i bx_3 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_2_3, 4));
+        __m128i by_3 = __lsx_vld((const __m128i *)(y[ib + 1].qs + 16), 0);
+        bx_3 = __lsx_vsub_b(bx_3, off);
+        const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
+
+        // Convert int32_t to float
+        __m128 p0 = __lsx_vffint_s_w(i32_0);
+        __m128 p1 = __lsx_vffint_s_w(i32_1);
+        __m128 p2 = __lsx_vffint_s_w(i32_2);
+        __m128 p3 = __lsx_vffint_s_w(i32_3);
+
+        // Apply the scale
+        __m128 p0_d = __lsx_vfmul_s( d_0_1, p0 );
+        __m128 p1_d = __lsx_vfmul_s( d_0_1, p1 );
+        __m128 p2_d = __lsx_vfmul_s( d_2_3, p2 );
+        __m128 p3_d = __lsx_vfmul_s( d_2_3, p3 );
+
+        // Accumulate
+        acc_0 = __lsx_vfadd_s(p0_d, acc_0);
+        acc_1 = __lsx_vfadd_s(p1_d, acc_1);
+        acc_2 = __lsx_vfadd_s(p2_d, acc_2);
+        acc_3 = __lsx_vfadd_s(p3_d, acc_3);
+    }
+
+    sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
+#endif
+    for (; ib < nb; ++ib) {
+        int sumi0 = 0;
+        int sumi1 = 0;
+
+        for (int j = 0; j < qk/2; ++j) {
+            const int v0 = (x[ib].qs[j] & 0x0F) - 8;
+            const int v1 = (x[ib].qs[j] >>   4) - 8;
+
+            sumi0 += (v0 * y[ib].qs[j]);
+            sumi1 += (v1 * y[ib].qs[j + qk/2]);
+        }
+
+        int sumi = sumi0 + sumi1;
+        sumf += sumi*GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d);
+    }
+
+    *s = sumf;
+}
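+
+// Every branch above evaluates the same per-block identity: q4_0 stores
+// x[j] ~= d_x * (qx[j] - 8) with 4-bit codes qx in [0, 15], and q8_0
+// stores y[j] ~= d_y * qy[j] with int8 codes, so for each block
+//
+//     sum_j x[j]*y[j] ~= d_x * d_y * sum_j (qx[j] - 8) * qy[j]
+//
+// with the scalar tail loop doubling as the reference implementation.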
+
+void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    const int qk = QK8_1;
+    const int nb = n / qk;
+
+    assert(n % qk == 0);
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    assert((nrc == 2) || (nrc == 1));
+#else
+    assert(nrc == 1);
+#endif
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_q4_1 * restrict x = vx;
+    const block_q8_1 * restrict y = vy;
+
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    if (nrc == 2) {
+        const block_q4_1 * restrict vx0 = vx;
+        const block_q4_1 * restrict vx1 = (const block_q4_1 *) ((const uint8_t*)vx + bx);
+        const block_q8_1 * restrict vy0 = vy;
+        const block_q8_1 * restrict vy1 = (const block_q8_1 *) ((const uint8_t*)vy + by);
+
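+        // as in ggml_vec_dot_q4_0_q8_0, nrc == 2 uses vmmlaq_s32 to form a
+        // 2x2 tile of row/column dot products in one pass, with the m*s
+        // offset terms accumulated separately in summs0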
+        float32x4_t sumv0 = vdupq_n_f32(0.0f);
+        float32x4_t summs0 = vdupq_n_f32(0.0f);
+
+        for (int i = 0; i < nb; i++) {
+            const block_q4_1 * restrict b_x0 = &vx0[i];
+            const block_q4_1 * restrict b_x1 = &vx1[i];
+            const block_q8_1 * restrict b_y0 = &vy0[i];
+            const block_q8_1 * restrict b_y1 = &vy1[i];
+
+            float32_t summs_t[4] = {
+                GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y0->s),
+                GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y0->s),
+                GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y1->s),
+                GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y1->s)
+            };
+            summs0 = vaddq_f32(summs0, vld1q_f32(summs_t));
+
+            const uint8x16_t m4b = vdupq_n_u8(0x0F);
+
+            const uint8x16_t v0_0 = vld1q_u8(b_x0->qs);
+            const uint8x16_t v0_1 = vld1q_u8(b_x1->qs);
+
+            // 4-bit -> 8-bit
+            const int8x16_t x0_l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
+            const int8x16_t x0_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+            const int8x16_t x1_l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
+            const int8x16_t x1_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+            // load y
+            const int8x16_t y0_l = vld1q_s8(b_y0->qs);
+            const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
+            const int8x16_t y1_l = vld1q_s8(b_y1->qs);
+            const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
+
+            // mmla into int32x4_t
+            float32_t _scale[4] = {
+                GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
+                GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
+                GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
+                GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)
+            };
+            float32x4_t scale = vld1q_f32(_scale);
+
+            int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
+            int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
+
+            int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
+            int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
+
+            int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
+            int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
+
+            int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
+            int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
+            sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
+                                                l1, r1)), l2, r2)), l3, r3))), scale);
+        }
+
+        float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);
+        float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
+
+        sumv2 = vaddq_f32(sumv2, summs0);
+
+        vst1_f32(s,      vget_low_f32 (sumv2));
+        vst1_f32(s + bs, vget_high_f32(sumv2));
+
+        return;
+    }
+#endif
+
+    int ib = 0;
+    float sumf = 0;
+
+    // TODO: add WASM SIMD
+#if defined(__ARM_NEON)
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
+
+    float summs = 0;
+
+    for (; ib + 1 < nb; ib += 2) {
+        const block_q4_1 * restrict x0 = &x[ib + 0];
+        const block_q4_1 * restrict x1 = &x[ib + 1];
+        const block_q8_1 * restrict y0 = &y[ib + 0];
+        const block_q8_1 * restrict y1 = &y[ib + 1];
+
+        summs += GGML_FP16_TO_FP32(x0->m) * GGML_FP16_TO_FP32(y0->s) + GGML_FP16_TO_FP32(x1->m) * GGML_FP16_TO_FP32(y1->s);
+
+        const uint8x16_t m4b = vdupq_n_u8(0x0F);
+
+        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+
+        // 4-bit -> 8-bit
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
+        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
+        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+        // load y
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+        // dot product into int32x4_t
+        const int32x4_t p_0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
+        const int32x4_t p_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+    }
+
+    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
+#elif defined(__AVX2__) || defined(__AVX__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    float summs = 0;
+
+    // Main loop
+    for (; ib < nb; ++ib) {
+        const float d0 = GGML_FP16_TO_FP32(x[ib].d);
+        const float d1 = GGML_FP16_TO_FP32(y[ib].d);
+
+        summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s);
+
+        const __m256 d0v = _mm256_set1_ps( d0 );
+        const __m256 d1v = _mm256_set1_ps( d1 );
+
+        // Compute combined scales
+        const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
+
+        // Load 16 bytes and unpack the 4-bit fields into bytes, making 32 bytes
+        const __m256i qx = bytes_from_nibbles_32(x[ib].qs);
+        const __m256i qy = _mm256_loadu_si256( (const __m256i *)y[ib].qs );
+
+        const __m256 xy = mul_sum_us8_pairs_float(qx, qy);
+
+        // Accumulate d0*d1*x*y
+#if defined(__AVX2__)
+        acc = _mm256_fmadd_ps( d0d1, xy, acc );
+#else
+        acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc );
+#endif
+    }
+
+    sumf = hsum_float_8(acc) + summs;
+#elif defined(__riscv_v_intrinsic)
+    size_t vl = __riscv_vsetvl_e8m1(qk/2);
+
+    for (; ib < nb; ++ib) {
+        // load elements
+        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl);
+
+        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl);
+        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl);
+
+        // mask and store lower part of x, and then upper part
+        vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
+        vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
+
+        vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
+        vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
+
+        vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
+        vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
+
+        vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
+
+        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
+        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
+
+        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
+
+        sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s);
+    }
+
+#elif defined(__POWER9_VECTOR__)
+    const vector signed char lowMask = vec_splats((signed char)0xF);
+    const vector signed int v0 = vec_splats((int32_t)0);
+    const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+
+    vector float vsumf0 = vec_splats(0.0f);
+
+#pragma GCC unroll 4
+    for (; ib < nb; ++ib) {
+        __builtin_prefetch(x[ib].qs, 0, 1);
+        __builtin_prefetch(y[ib].qs, 0, 1);
+
+        vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d));
+        vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d));
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[ib].m));
+        vector float vys = {GGML_FP16_TO_FP32(y[ib].s), 0.0f, 0.0f, 0.0f};
+        vsumf0 = vec_madd(vxmin, vys, vsumf0);
+
+        vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs);
+        vector signed char q8y0 = vec_xl( 0, y[ib].qs);
+        vector signed char q8y1 = vec_xl(16, y[ib].qs);
+
+        vector unsigned char q4x0 = (vector unsigned char)vec_and(qxs, lowMask);
+        vector unsigned char q4x1 = (vector unsigned char)vec_sr(qxs, v4);
+
+        vector signed int vsumi0 = v0;
+
+        vsumi0 = vec_msum(q8y0, q4x0, vsumi0);
+        vsumi0 = vec_msum(q8y1, q4x1, vsumi0);
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+    }
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    sumf = vec_extract(vsumf0, 0);
+
+#elif defined(__loongarch_asx)
+    // Initialize accumulator with zeros
+    __m256 acc = (__m256)__lasx_xvldi(0);
+
+    float summs = 0;
+
+    // Main loop
+    for (; ib < nb; ++ib) {
+        const float d0 = GGML_FP16_TO_FP32(x[ib].d);
+        const float d1 = GGML_FP16_TO_FP32(y[ib].d);
+
+        summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s);
+
+        const __m256 d0v = __lasx_xvreplfr2vr_s( d0 );
+        const __m256 d1v = __lasx_xvreplfr2vr_s( d1 );
+
+        // Compute combined scales
+        const __m256 d0d1 = __lasx_xvfmul_s( d0v, d1v );
+
+        // Load 16 bytes and unpack the 4-bit fields into bytes, making 32 bytes
+        const __m256i qx = bytes_from_nibbles_32(x[ib].qs);
+        const __m256i qy = __lasx_xvld( (const __m256i *)y[ib].qs, 0);
+
+        const __m256 xy = mul_sum_us8_pairs_float(qx, qy);
+
+        // Accumulate d0*d1*x*y
+        acc = __lasx_xvfmadd_s( d0d1, xy, acc );
+    }
+
+    sumf = hsum_float_8(acc) + summs;
+#endif
+    for (; ib < nb; ++ib) {
+        int sumi0 = 0;
+        int sumi1 = 0;
+
+        for (int j = 0; j < qk/2; ++j) {
+            const int v0 = (x[ib].qs[j] & 0x0F);
+            const int v1 = (x[ib].qs[j] >>   4);
+
+            sumi0 += (v0 * y[ib].qs[j]);
+            sumi1 += (v1 * y[ib].qs[j + qk/2]);
+        }
+
+        int sumi = sumi0 + sumi1;
+        sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s);
+    }
+
+    *s = sumf;
+}
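+
+// q4_1 stores x[j] ~= d_x*qx[j] + m_x with unsigned 4-bit codes, and q8_1
+// additionally carries s_y = d_y * sum_j qy[j], so each block contributes
+//
+//     sum_j x[j]*y[j] ~= d_x*d_y * sum_j qx[j]*qy[j] + m_x*s_y
+//
+// which is why every branch accumulates the code dot product separately
+// from the m*s correction (summs), with the scalar tail as the reference.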
+
+void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+
+    int ib = 0;
+    float sumf = 0;
+
+    assert(n % qk == 0);
+    assert(qk == QK5_0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_q5_0 * restrict x = vx;
+    const block_q8_0 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
+
+    uint32_t qh0;
+    uint32_t qh1;
+
+    uint64_t tmp0[4];
+    uint64_t tmp1[4];
+
+    for (; ib + 1 < nb; ib += 2) {
+        const block_q5_0 * restrict x0 = &x[ib];
+        const block_q5_0 * restrict x1 = &x[ib + 1];
+        const block_q8_0 * restrict y0 = &y[ib];
+        const block_q8_0 * restrict y1 = &y[ib + 1];
+
+        const uint8x16_t m4b = vdupq_n_u8(0x0F);
+
+        // extract the 5th bit via lookup table ((!b) << 4)
+        memcpy(&qh0, x0->qh, sizeof(qh0));
+        memcpy(&qh1, x1->qh, sizeof(qh1));
+
+        tmp0[0] = table_b2b_1[(qh0 >>  0) & 0xFF];
+        tmp0[1] = table_b2b_1[(qh0 >>  8) & 0xFF];
+        tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF];
+        tmp0[3] = table_b2b_1[(qh0 >> 24)       ];
+
+        tmp1[0] = table_b2b_1[(qh1 >>  0) & 0xFF];
+        tmp1[1] = table_b2b_1[(qh1 >>  8) & 0xFF];
+        tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF];
+        tmp1[3] = table_b2b_1[(qh1 >> 24)       ];
+
+        const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0));
+        const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2));
+        const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0));
+        const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2));
+
+        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+
+        // 4-bit -> 8-bit
+        int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
+        int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
+        int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+        // add the high bit and subtract 16 in one step: q - ((!b) << 4) == (q | (b << 4)) - 16
+        const int8x16_t v0_0lf = vsubq_s8(v0_0l, qhl0);
+        const int8x16_t v0_0hf = vsubq_s8(v0_0h, qhh0);
+        const int8x16_t v0_1lf = vsubq_s8(v0_1l, qhl1);
+        const int8x16_t v0_1hf = vsubq_s8(v0_1h, qhh1);
+
+        // load y
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
+                        ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
+                        ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
+                        ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
+                        ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+    }
+
+    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined(__wasm_simd128__)
+    v128_t sumv = wasm_f32x4_splat(0.0f);
+
+    uint32_t qh;
+    uint64_t tmp[4];
+
+    // TODO: check if unrolling this is better
+    for (; ib < nb; ++ib) {
+        const block_q5_0 * restrict x0 = &x[ib];
+        const block_q8_0 * restrict y0 = &y[ib];
+
+        const v128_t m4b  = wasm_i8x16_splat(0x0F);
+
+        // extract the 5th bit
+        memcpy(&qh, x0->qh, sizeof(qh));
+
+        tmp[0] = table_b2b_1[(qh >>  0) & 0xFF];
+        tmp[1] = table_b2b_1[(qh >>  8) & 0xFF];
+        tmp[2] = table_b2b_1[(qh >> 16) & 0xFF];
+        tmp[3] = table_b2b_1[(qh >> 24)       ];
+
+        const v128_t qhl = wasm_v128_load(tmp + 0);
+        const v128_t qhh = wasm_v128_load(tmp + 2);
+
+        const v128_t v0 = wasm_v128_load(x0->qs);
+
+        // 4-bit -> 8-bit
+        const v128_t v0l = wasm_v128_and (v0, m4b);
+        const v128_t v0h = wasm_u8x16_shr(v0, 4);
+
+        // add the high bit and subtract 16 in one step: q - ((!b) << 4) == (q | (b << 4)) - 16
+        const v128_t v0lf = wasm_i8x16_sub(v0l, qhl);
+        const v128_t v0hf = wasm_i8x16_sub(v0h, qhh);
+
+        // load y
+        const v128_t v1l = wasm_v128_load(y0->qs);
+        const v128_t v1h = wasm_v128_load(y0->qs + 16);
+
+        // int8x16 -> int16x8
+        const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
+        const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
+        const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
+        const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);
+
+        const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
+        const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
+        const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
+        const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
+
+        // dot product
+        sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(
+                        wasm_i32x4_add(
+                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
+                                           wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
+                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
+                                           wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
+                    wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d))));
+    }
+
+    sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
+           wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (; ib < nb; ++ib) {
+        /* Compute combined scale for the block */
+        const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d));
+
+        __m256i qx = bytes_from_nibbles_32(x[ib].qs);
+        __m256i bxhi = bytes_from_bits_32(x[ib].qh);
+        bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0));
+        qx = _mm256_or_si256(qx, bxhi);
+
+        __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs);
+
+        const __m256 q = mul_sum_i8_pairs_float(qx, qy);
+
+        /* Multiply q with scale and accumulate */
+        acc = _mm256_fmadd_ps(d, q, acc);
+    }
+
+    sumf = hsum_float_8(acc);
+#elif defined(__AVX__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+    __m128i mask = _mm_set1_epi8((char)0xF0);
+
+    // Main loop
+    for (; ib < nb; ++ib) {
+        /* Compute combined scale for the block */
+        const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d));
+
+        __m256i bx_0 = bytes_from_nibbles_32(x[ib].qs);
+        const __m256i bxhi = bytes_from_bits_32(x[ib].qh);
+        __m128i bxhil = _mm256_castsi256_si128(bxhi);
+        __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
+        bxhil = _mm_andnot_si128(bxhil, mask);
+        bxhih = _mm_andnot_si128(bxhih, mask);
+        __m128i bxl = _mm256_castsi256_si128(bx_0);
+        __m128i bxh = _mm256_extractf128_si256(bx_0, 1);
+        bxl = _mm_or_si128(bxl, bxhil);
+        bxh = _mm_or_si128(bxh, bxhih);
+        bx_0 = MM256_SET_M128I(bxh, bxl);
+
+        const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[ib].qs);
+
+        const __m256 q = mul_sum_i8_pairs_float(bx_0, by_0);
+
+        /* Multiply q with scale and accumulate */
+        acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc);
+    }
+
+    sumf = hsum_float_8(acc);
+#elif defined(__riscv_v_intrinsic)
+    uint32_t qh;
+
+    size_t vl = __riscv_vsetvl_e8m1(qk/2);
+
+    // These temporary registers are for masking and shift operations
+    vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
+    vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl);
+
+    vuint32m2_t vt_3 = __riscv_vsll_vx_u32m2(vt_2, 16, vl);
+    vuint32m2_t vt_4 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
+
+    for (; ib < nb; ++ib) {
+        memcpy(&qh, x[ib].qh, sizeof(uint32_t));
+
+        // ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
+        vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(vt_2, qh, vl);
+        vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(xha_0, vt_1, vl);
+        vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
+
+        // ((qh & (1u << (j + 16))) >> (j + 12));
+        vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(vt_3, qh, vl);
+        vuint32m2_t xhl_1 = __riscv_vsrl_vv_u32m2(xha_1, vt_4, vl);
+
+        // narrowing
+        vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xhl_0, vl);
+        vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
+
+        vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xhl_1, vl);
+        vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
+
+        // load
+        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl);
+
+        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl);
+        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl);
+
+        vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
+        vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
+
+        vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
+        vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
+
+        vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
+        vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
+
+        vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 16, vl);
+        vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 16, vl);
+
+        vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
+        vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
+
+        vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
+
+        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
+        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
+
+        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
+
+        sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)) * sumi;
+    }
+
+#elif defined(__POWER9_VECTOR__)
+    const vector signed char lowMask = vec_splats((signed char)0xF);
+    const vector unsigned char v4 = vec_splats((unsigned char)4);
+
+    vector float vsumf0 = vec_splats(0.0f);
+
+#pragma GCC unroll 4
+    for (; ib < nb; ++ib) {
+        __builtin_prefetch(x[ib].qs, 0, 1);
+        __builtin_prefetch(y[ib].qs, 0, 1);
+
+        vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d));
+        vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d));
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector signed long long aux64x2_0 = {(uint64_t)(table_b2b_1[x[ib].qh[0]]), (uint64_t)(table_b2b_1[x[ib].qh[1]])};
+        vector signed long long aux64x2_1 = {(uint64_t)(table_b2b_1[x[ib].qh[2]]), (uint64_t)(table_b2b_1[x[ib].qh[3]])};
+
+        vector signed char qh0 = (vector signed char)aux64x2_0;
+        vector signed char qh1 = (vector signed char)aux64x2_1;
+
+        vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs);
+
+        vector signed char q5x0 = vec_sub(vec_and (qxs, lowMask), qh0);
+        vector signed char q5x1 = vec_sub(vec_sr(qxs, v4), qh1);
+
+        vector signed char q8y0 = vec_xl(  0, y[ib].qs);
+        vector signed char q8y1 = vec_xl( 16, y[ib].qs);
+
+        vector signed short qv0 = vec_add(vec_mule(q5x0, q8y0), vec_mulo(q5x0, q8y0));
+        vector signed short qv1 = vec_add(vec_mule(q5x1, q8y1), vec_mulo(q5x1, q8y1));
+
+        qv0 = vec_add(qv0, qv1);
+
+        vector signed int vsumi0 = vec_add(vec_unpackh(qv0), vec_unpackl(qv0));
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+    }
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    sumf = vec_extract(vsumf0, 0);
+
+#elif defined(__loongarch_asx)
+    // Initialize accumulator with zeros
+    __m256 acc = (__m256)__lasx_xvldi(0);
+
+    // Main loop
+    for (; ib < nb; ++ib) {
+        /* Compute combined scale for the block */
+        const __m256 d = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)); //FIXME
+
+        __m256i qx = bytes_from_nibbles_32(x[ib].qs);
+        __m256i bxhi = bytes_from_bits_32(x[ib].qh);
+        bxhi = __lasx_xvandn_v(bxhi, __lasx_xvreplgr2vr_b((char)0xF0));
+        qx = __lasx_xvor_v(qx, bxhi);
+
+        __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0);
+
+        const __m256 q = mul_sum_i8_pairs_float(qx, qy);
+
+        /* Multiply q with scale and accumulate */
+        acc = __lasx_xvfmadd_s(d, q, acc);
+    }
+
+    sumf = hsum_float_8(acc);
+#endif
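+    // scalar epilogue: the reference path, and the fallback for any blocks the
+    // SIMD loops above left over; bit j of qh is the 5th (high) bit of low-nibble
+    // element j and bit j+16 that of high-nibble element j, so it is shifted into
+    // bit 4 and OR-ed in before recentring from [0, 31] to [-16, 15]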
+    for (; ib < nb; ++ib) {
+        uint32_t qh;
+        memcpy(&qh, x[ib].qh, sizeof(qh));
+
+        int sumi0 = 0;
+        int sumi1 = 0;
+
+        for (int j = 0; j < qk/2; ++j) {
+            const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
+            const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
+
+            const int32_t x0 = (int8_t)(((x[ib].qs[j] & 0x0F) | xh_0) - 16);
+            const int32_t x1 = (int8_t)(((x[ib].qs[j] >>   4) | xh_1) - 16);
+
+            sumi0 += (x0 * y[ib].qs[j]);
+            sumi1 += (x1 * y[ib].qs[j + qk/2]);
+        }
+
+        int sumi = sumi0 + sumi1;
+        sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)) * sumi;
+    }
+
+    *s = sumf;
+}
+
+void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    const int qk = QK8_1;
+    const int nb = n / qk;
+
+    int ib = 0;
+    float sumf = 0;
+
+    assert(n % qk == 0);
+    assert(qk == QK5_1);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_q5_1 * restrict x = vx;
+    const block_q8_1 * restrict y = vy;
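+    // q5_1 carries a per-block scale d and minimum m; the q8_1 block precomputes
+    // s = d * sum(qs), so each block contributes
+    //     d_x * d_y * (q5 . q8) + m_x * s_y
+    // with the 5-bit values kept unsigned in [0, 31]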
+
+#if defined(__ARM_NEON)
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
+
+    float summs0 = 0.0f;
+    float summs1 = 0.0f;
+
+    uint32_t qh0;
+    uint32_t qh1;
+
+    uint64_t tmp0[4];
+    uint64_t tmp1[4];
+
+    for (; ib + 1 < nb; ib += 2) {
+        const block_q5_1 * restrict x0 = &x[ib];
+        const block_q5_1 * restrict x1 = &x[ib + 1];
+        const block_q8_1 * restrict y0 = &y[ib];
+        const block_q8_1 * restrict y1 = &y[ib + 1];
+
+        const uint8x16_t m4b = vdupq_n_u8(0x0F);
+
+        summs0 += GGML_FP16_TO_FP32(x0->m) * GGML_FP16_TO_FP32(y0->s);
+        summs1 += GGML_FP16_TO_FP32(x1->m) * GGML_FP16_TO_FP32(y1->s);
+
+        // extract the 5th bit via lookup table ((b) << 4)
+        memcpy(&qh0, x0->qh, sizeof(qh0));
+        memcpy(&qh1, x1->qh, sizeof(qh1));
+
+        tmp0[0] = table_b2b_0[(qh0 >>  0) & 0xFF];
+        tmp0[1] = table_b2b_0[(qh0 >>  8) & 0xFF];
+        tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF];
+        tmp0[3] = table_b2b_0[(qh0 >> 24)       ];
+
+        tmp1[0] = table_b2b_0[(qh1 >>  0) & 0xFF];
+        tmp1[1] = table_b2b_0[(qh1 >>  8) & 0xFF];
+        tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF];
+        tmp1[3] = table_b2b_0[(qh1 >> 24)       ];
+
+        const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0));
+        const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2));
+        const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0));
+        const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2));
+
+        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+
+        // 4-bit -> 8-bit
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
+        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
+        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+        // add high bit
+        const int8x16_t v0_0lf = vorrq_s8(v0_0l, qhl0);
+        const int8x16_t v0_0hf = vorrq_s8(v0_0h, qhh0);
+        const int8x16_t v0_1lf = vorrq_s8(v0_1l, qhl1);
+        const int8x16_t v0_1hf = vorrq_s8(v0_1h, qhh1);
+
+        // load y
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
+                        ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
+                        ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
+                        ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
+                        ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+    }
+
+    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;
+#elif defined(__wasm_simd128__)
+    v128_t sumv = wasm_f32x4_splat(0.0f);
+
+    float summs = 0.0f;
+
+    uint32_t qh;
+    uint64_t tmp[4];
+
+    // TODO: check if unrolling this is better
+    for (; ib < nb; ++ib) {
+        const block_q5_1 * restrict x0 = &x[ib];
+        const block_q8_1 * restrict y0 = &y[ib];
+
+        summs += GGML_FP16_TO_FP32(x0->m) * GGML_FP16_TO_FP32(y0->s);
+
+        const v128_t m4b = wasm_i8x16_splat(0x0F);
+
+        // extract the 5th bit
+        memcpy(&qh, x0->qh, sizeof(qh));
+
+        tmp[0] = table_b2b_0[(qh >>  0) & 0xFF];
+        tmp[1] = table_b2b_0[(qh >>  8) & 0xFF];
+        tmp[2] = table_b2b_0[(qh >> 16) & 0xFF];
+        tmp[3] = table_b2b_0[(qh >> 24)       ];
+
+        const v128_t qhl = wasm_v128_load(tmp + 0);
+        const v128_t qhh = wasm_v128_load(tmp + 2);
+
+        const v128_t v0 = wasm_v128_load(x0->qs);
+
+        // 4-bit -> 8-bit
+        const v128_t v0l = wasm_v128_and (v0, m4b);
+        const v128_t v0h = wasm_u8x16_shr(v0, 4);
+
+        // add high bit
+        const v128_t v0lf = wasm_v128_or(v0l, qhl);
+        const v128_t v0hf = wasm_v128_or(v0h, qhh);
+
+        // load y
+        const v128_t v1l = wasm_v128_load(y0->qs);
+        const v128_t v1h = wasm_v128_load(y0->qs + 16);
+
+        // int8x16 -> int16x8
+        const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
+        const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
+        const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
+        const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);
+
+        const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
+        const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
+        const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
+        const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
+
+        // dot product
+        sumv = wasm_f32x4_add(sumv,
+                wasm_f32x4_mul(wasm_f32x4_convert_i32x4(wasm_i32x4_add(
+                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
+                                           wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
+                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
+                                           wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
+                    wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d))));
+    }
+
+    sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
+           wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs;
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    float summs = 0.0f;
+
+    // Main loop
+    for (; ib < nb; ++ib) {
+        const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d));
+
+        summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s);
+
+        __m256i qx = bytes_from_nibbles_32(x[ib].qs);
+        __m256i bxhi = bytes_from_bits_32(x[ib].qh);
+        bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10));
+        qx = _mm256_or_si256(qx, bxhi);
+
+        const __m256 dy = _mm256_set1_ps(GGML_FP16_TO_FP32(y[ib].d));
+        const __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs);
+
+        const __m256 q = mul_sum_us8_pairs_float(qx, qy);
+
+        acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
+    }
+
+    sumf = hsum_float_8(acc) + summs;
+#elif defined(__AVX__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+    __m128i mask = _mm_set1_epi8(0x10);
+
+    float summs = 0.0f;
+
+    // Main loop
+    for (; ib < nb; ++ib) {
+        const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d));
+
+        summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s);
+
+        __m256i bx_0 = bytes_from_nibbles_32(x[ib].qs);
+        const __m256i bxhi = bytes_from_bits_32(x[ib].qh);
+        __m128i bxhil = _mm256_castsi256_si128(bxhi);
+        __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
+        bxhil = _mm_and_si128(bxhil, mask);
+        bxhih = _mm_and_si128(bxhih, mask);
+        __m128i bxl = _mm256_castsi256_si128(bx_0);
+        __m128i bxh = _mm256_extractf128_si256(bx_0, 1);
+        bxl = _mm_or_si128(bxl, bxhil);
+        bxh = _mm_or_si128(bxh, bxhih);
+        bx_0 = MM256_SET_M128I(bxh, bxl);
+
+        const __m256 dy = _mm256_set1_ps(GGML_FP16_TO_FP32(y[ib].d));
+        const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[ib].qs);
+
+        const __m256 q = mul_sum_us8_pairs_float(bx_0, by_0);
+
+        acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc);
+    }
+
+    sumf = hsum_float_8(acc) + summs;
+#elif defined(__riscv_v_intrinsic)
+    uint32_t qh;
+
+    size_t vl = __riscv_vsetvl_e8m1(qk/2);
+
+    // temporary registers for shift operations
+    vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
+    vuint32m2_t vt_2 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
+
+    for (; ib < nb; ++ib) {
+        memcpy(&qh, x[ib].qh, sizeof(uint32_t));
+
+        // load qh
+        vuint32m2_t vqh = __riscv_vmv_v_x_u32m2(qh, vl);
+
+        // ((qh >> (j +  0)) << 4) & 0x10;
+        vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(vqh, vt_1, vl);
+        vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
+        vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(xhl_0, 0x10, vl);
+
+        // ((qh >> (j + 12))     ) & 0x10;
+        vuint32m2_t xhr_1 = __riscv_vsrl_vv_u32m2(vqh, vt_2, vl);
+        vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(xhr_1, 0x10, vl);
+
+        // narrowing
+        vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xha_0, vl);
+        vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
+
+        vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xha_1, vl);
+        vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
+
+        // load
+        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl);
+
+        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl);
+        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl);
+
+        vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
+        vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
+
+        vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
+        vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
+
+        vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
+        vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
+
+        vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
+        vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
+
+        vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
+
+        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
+        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
+
+        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
+
+        sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s);
+    }
+
+#elif defined(__POWER9_VECTOR__)
+    const vector signed char lowMask = vec_splats((signed char)0xF);
+    const vector signed int v0 = vec_splats((int32_t)0);
+    const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+
+    vector float vsumf0 = vec_splats(0.0f);
+
+#pragma GCC unroll 4
+    for (; ib < nb; ++ib) {
+        __builtin_prefetch(x[ib].qs, 0, 1);
+        __builtin_prefetch(y[ib].qs, 0, 1);
+
+        vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d));
+        vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d));
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[ib].m));
+        vector float vys = {GGML_FP16_TO_FP32(y[ib].s), 0.f, 0.f, 0.f};
+        vsumf0 = vec_madd(vxmin, vys, vsumf0);
+
+        vector unsigned long long aux64x2_0 = {(uint64_t)(table_b2b_0[x[ib].qh[0]]), (uint64_t)(table_b2b_0[x[ib].qh[1]])};
+        vector unsigned long long aux64x2_1 = {(uint64_t)(table_b2b_0[x[ib].qh[2]]), (uint64_t)(table_b2b_0[x[ib].qh[3]])};
+
+        vector signed char qh0 = (vector signed char)aux64x2_0;
+        vector signed char qh1 = (vector signed char)aux64x2_1;
+
+        vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs);
+
+        vector unsigned char q5x0 = (vector unsigned char)vec_or(vec_and(qxs, lowMask), qh0);
+        vector unsigned char q5x1 = (vector unsigned char)vec_or(vec_sr(qxs, v4), qh1);
+
+        vector signed char q8y0 = vec_xl(  0, y[ib].qs);
+        vector signed char q8y1 = vec_xl( 16, y[ib].qs);
+
+        vector signed int vsumi0 = v0;
+
+        vsumi0 = vec_msum(q8y0, q5x0, vsumi0);
+        vsumi0 = vec_msum(q8y1, q5x1, vsumi0);
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+    }
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    sumf = vec_extract(vsumf0, 0);
+
+#elif defined(__loongarch_asx)
+    // Initialize accumulator with zeros
+    __m256 acc = (__m256)__lasx_xvldi(0);
+
+    float summs = 0.0f;
+
+    // Main loop
+    for (; ib < nb; ++ib) {
+        const __m256 dx = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[ib].d));
+
+        summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s);
+
+        __m256i qx = bytes_from_nibbles_32(x[ib].qs);
+        __m256i bxhi = bytes_from_bits_32(x[ib].qh);
+        bxhi = __lasx_xvand_v(bxhi, __lasx_xvreplgr2vr_b(0x10));
+        qx = __lasx_xvor_v(qx, bxhi);
+
+        const __m256 dy = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[ib].d));
+        const __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0);
+
+        const __m256 q = mul_sum_us8_pairs_float(qx, qy);
+
+        acc = __lasx_xvfmadd_s(q, __lasx_xvfmul_s(dx, dy), acc);
+    }
+
+    sumf = hsum_float_8(acc) + summs;
+#endif
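+    // scalar epilogue, mirroring q5_0: reconstruct the high bit from qh, but keep
+    // the values unsigned in [0, 31] since the block minimum m absorbs the offset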
+    for (; ib < nb; ++ib) {
+        uint32_t qh;
+        memcpy(&qh, x[ib].qh, sizeof(qh));
+
+        int sumi0 = 0;
+        int sumi1 = 0;
+
+        for (int j = 0; j < qk/2; ++j) {
+            const uint8_t xh_0 = ((qh >> (j +  0)) << 4) & 0x10;
+            const uint8_t xh_1 = ((qh >> (j + 12))     ) & 0x10;
+
+            const int32_t x0 = (x[ib].qs[j] & 0xF) | xh_0;
+            const int32_t x1 = (x[ib].qs[j] >>  4) | xh_1;
+
+            sumi0 += (x0 * y[ib].qs[j]);
+            sumi1 += (x1 * y[ib].qs[j + qk/2]);
+        }
+
+        int sumi = sumi0 + sumi1;
+        sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s);
+    }
+
+    *s = sumf;
+}
+
+void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+
+    assert(n % qk == 0);
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    assert((nrc == 2) || (nrc == 1));
+#else
+    assert(nrc == 1);
+#endif
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_q8_0 * restrict x = vx;
+    const block_q8_0 * restrict y = vy;
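+    // q8_0 . q8_0: both sides are plain int8 quants with one fp16 scale per block
+    // of 32, so each block contributes d_x * d_y * (qx . qy)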
+
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    if (nrc == 2) {
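+        // i8mm 2x2 tile path: vzip1q/vzip2q_s64 pair the quants of x0/x1 and
+        // y0/y1 so that each vmmlaq_s32 (SMMLA) accumulates a 2x8 by 8x2 int8
+        // matrix product into a 2x2 int32 tile, scaled by the four cross deltas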
+        const block_q8_0 * restrict vx0 = vx;
+        const block_q8_0 * restrict vx1 = (const block_q8_0 *) ((const uint8_t*)vx + bx);
+        const block_q8_0 * restrict vy0 = vy;
+        const block_q8_0 * restrict vy1 = (const block_q8_0 *) ((const uint8_t*)vy + by);
+
+        float32x4_t sumv0 = vdupq_n_f32(0.0f);
+
+        for (int i = 0; i < nb; i++) {
+            const block_q8_0 * restrict b_x0 = &vx0[i];
+            const block_q8_0 * restrict b_y0 = &vy0[i];
+
+            const block_q8_0 * restrict b_x1 = &vx1[i];
+            const block_q8_0 * restrict b_y1 = &vy1[i];
+
+            const int8x16_t x0_l = vld1q_s8(b_x0->qs);
+            const int8x16_t x0_h = vld1q_s8(b_x0->qs + 16);
+            const int8x16_t x1_l = vld1q_s8(b_x1->qs);
+            const int8x16_t x1_h = vld1q_s8(b_x1->qs + 16);
+
+            // load y
+            const int8x16_t y0_l = vld1q_s8(b_y0->qs);
+            const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
+            const int8x16_t y1_l = vld1q_s8(b_y1->qs);
+            const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
+
+            float32_t _scale[4] = {
+                GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
+                GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
+                GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
+                GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)
+            };
+            float32x4_t scale = vld1q_f32(_scale);
+
+            int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
+            int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
+
+            int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
+            int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
+
+            int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
+            int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
+
+            int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
+            int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
+
+            sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
+                                                l1, r1)), l2, r2)), l3, r3))), scale);
+        }
+
+        float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);
+        float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
+
+        vst1_f32(s,      vget_low_f32 (sumv2));
+        vst1_f32(s + bs, vget_high_f32(sumv2));
+
+        return;
+    }
+#endif
+
+    int ib = 0;
+    float sumf = 0;
+
+#if defined(__ARM_FEATURE_SVE)
+    svfloat32_t sumv0 = svdup_n_f32(0.0f);
+    svfloat32_t sumv1 = svdup_n_f32(0.0f);
+
+    const int vector_length = ggml_cpu_get_sve_cnt()*8;
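+    // register width in bits; ggml_cpu_get_sve_cnt() reports the SVE vector
+    // length in bytes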
+
+    // VLA (vector-length-agnostic) implementation for SVE
+    switch (vector_length) {
+        case 128:
+            {
+                // predicate for activating lanes for 16 Int8 elements
+                const svbool_t ph16 = svptrue_pat_b8 (SV_VL16);
+                const svbool_t pl16 = svptrue_pat_b32(SV_VL4);
+
+                for (; ib + 1 < nb; ib += 2) {
+                    const block_q8_0 * restrict x0 = &x[ib + 0];
+                    const block_q8_0 * restrict x1 = &x[ib + 1];
+                    const block_q8_0 * restrict y0 = &y[ib + 0];
+                    const block_q8_0 * restrict y1 = &y[ib + 1];
+
+                    // load x
+                    const svint8_t qx0_0 = svld1_s8(ph16, x0->qs);
+                    const svint8_t qx0_1 = svld1_s8(ph16, x0->qs+16);
+                    const svint8_t qx1_0 = svld1_s8(ph16, x1->qs);
+                    const svint8_t qx1_1 = svld1_s8(ph16, x1->qs+16);
+
+                    // load y
+                    const svint8_t qy0_0 = svld1_s8(ph16, y0->qs);
+                    const svint8_t qy0_1 = svld1_s8(ph16, y0->qs+16);
+                    const svint8_t qy1_0 = svld1_s8(ph16, y1->qs);
+                    const svint8_t qy1_1 = svld1_s8(ph16, y1->qs+16);
+
+                    sumv0 = svmla_n_f32_x(pl16, sumv0, svcvt_f32_s32_x(pl16, svadd_x(pl16,
+                                    svdot_s32(svdup_n_s32(0), qx0_0, qy0_0),
+                                    svdot_s32(svdup_n_s32(0), qx0_1, qy0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+                    sumv1 = svmla_n_f32_x(pl16, sumv1, svcvt_f32_s32_x(pl16, svadd_x(pl16,
+                                    svdot_s32(svdup_n_s32(0), qx1_0, qy1_0),
+                                    svdot_s32(svdup_n_s32(0), qx1_1, qy1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+                }
+
+                sumf = svaddv_f32(pl16, svadd_f32_x(pl16, sumv0, sumv1));
+            } break;
+        case 256:
+            {
+                for (; ib + 1 < nb; ib += 2) {
+                    const block_q8_0 * restrict x0 = &x[ib + 0];
+                    const block_q8_0 * restrict x1 = &x[ib + 1];
+                    const block_q8_0 * restrict y0 = &y[ib + 0];
+                    const block_q8_0 * restrict y1 = &y[ib + 1];
+
+                    // load x
+                    const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
+                    const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs);
+
+                    // load y
+                    const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
+                    const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
+
+                    sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(),
+                                svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+                    sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(),
+                                svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+                }
+
+                sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
+            } break;
+        case 512:
+            {
+                // predicate for the low 256 bits (first 32 byte lanes)
+                const svbool_t ph32 = svptrue_pat_b8(SV_VL32);
+                // predicate for the high 256 bits
+                const svbool_t pl32 = svnot_b_z(svptrue_b8(), ph32);
+
+                // predicate for the low 8 float32 lanes
+                const svbool_t ph8 = svptrue_pat_b32(SV_VL8);
+                // predicate for the high 8 float32 lanes
+                const svbool_t pl8 = svnot_b_z(svptrue_b32(), ph8);
+
+                svfloat32_t sumv00 = svdup_n_f32(0.0f);
+
+                for (; ib + 1 < nb; ib += 2) {
+                    const block_q8_0 * restrict x0 = &x[ib + 0];
+                    const block_q8_0 * restrict x1 = &x[ib + 1];
+                    const block_q8_0 * restrict y0 = &y[ib + 0];
+                    const block_q8_0 * restrict y1 = &y[ib + 1];
+
+                    // load 32 int8_t from x0 into the low half (ph32) and 32 int8_t
+                    // from x1 into the high half (pl32): the +2 offset skips x1's
+                    // 2-byte fp16 delta so lanes 32..63 land exactly on x1->qs;
+                    // adding the zero-predicated halves yields one 64-element vector
+                    // (the y loads below use the same trick)
+                    // load x
+                    const svint8_t qx_32 = svld1_s8(ph32, x0->qs);
+                          svint8_t qx_64 = svld1_s8(pl32, x0->qs + 2);
+
+                    qx_64 = svadd_s8_x(svptrue_b8(), qx_32, qx_64);
+
+                    // load y
+                    const svint8_t qy_32 = svld1_s8(ph32, y0->qs);
+                          svint8_t qy_64 = svld1_s8(pl32, y0->qs + 2);
+
+                    qy_64 = svadd_s8_x(svptrue_b8(), qy_32, qy_64);
+
+                    // scale creation
+                    const float32_t deq1 = GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d);
+                    const float32_t deq2 = GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d);
+
+                    // duplicate deq1 in first half of vector and deq2 in second half of vector
+                    const svfloat32_t temp = svdup_f32_m(svdup_f32_z(ph8, deq1), pl8, deq2);
+
+                    const svfloat32_t sumvt = svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx_64, qy_64));
+
+                    sumv00 = svmla_f32_m(svptrue_b32(), sumv00, sumvt, temp);
+                }
+
+                sumf = svaddv_f32(svptrue_b32(), sumv00);
+                break;
+            }
+        default:
+            assert(false && "Unsupported vector length");
+            break;
+    }
+#elif defined(__ARM_NEON)
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
+
+    for (; ib + 1 < nb; ib += 2) {
+        const block_q8_0 * restrict x0 = &x[ib + 0];
+        const block_q8_0 * restrict x1 = &x[ib + 1];
+        const block_q8_0 * restrict y0 = &y[ib + 0];
+        const block_q8_0 * restrict y1 = &y[ib + 1];
+
+        const int8x16_t x0_0 = vld1q_s8(x0->qs);
+        const int8x16_t x0_1 = vld1q_s8(x0->qs + 16);
+        const int8x16_t x1_0 = vld1q_s8(x1->qs);
+        const int8x16_t x1_1 = vld1q_s8(x1->qs + 16);
+
+        // load y
+        const int8x16_t y0_0 = vld1q_s8(y0->qs);
+        const int8x16_t y0_1 = vld1q_s8(y0->qs + 16);
+        const int8x16_t y1_0 = vld1q_s8(y1->qs);
+        const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
+                        ggml_vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
+                        ggml_vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
+                        ggml_vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
+                        ggml_vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+    }
+
+    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (; ib < nb; ++ib) {
+        // Compute combined scale for the block
+        const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d));
+        __m256i qx = _mm256_loadu_si256((const __m256i *)x[ib].qs);
+        __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs);
+
+        const __m256 q = mul_sum_i8_pairs_float(qx, qy);
+
+        // Multiply q with scale and accumulate
+        acc = _mm256_fmadd_ps( d, q, acc );
+    }
+
+    sumf = hsum_float_8(acc);
+#elif defined(__AVX__)
+    __m256 accum = _mm256_setzero_ps();
+
+    for (; ib + 1 < nb; ib += 2) {
+        const __m128i qx_1_0 = _mm_loadu_si128((const __m128i *)x[ib].qs);
+        const __m128i qx_1_1 = _mm_loadu_si128((const __m128i *)x[ib].qs + 1);
+        const __m128i qx_2_0 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
+        const __m128i qx_2_1 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs + 1);
+        const __m128i qy_1_0 = _mm_loadu_si128((const __m128i *)y[ib].qs);
+        const __m128i qy_1_1 = _mm_loadu_si128((const __m128i *)y[ib].qs + 1);
+        const __m128i qy_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
+        const __m128i qy_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);
+
+        const __m256 p = mul_sum_i8_quad_float(qx_1_0, qx_1_1, qx_2_0, qx_2_1, qy_1_0, qy_1_1, qy_2_0, qy_2_1);
+        const __m256 deltas = quad_fp16_delta_float(x[ib].d, y[ib].d, x[ib + 1].d, y[ib + 1].d);
+        accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum);
+    }
+
+    sumf = hsum_float_8(accum);
+#elif defined(__riscv_v_intrinsic)
+    size_t vl = __riscv_vsetvl_e8m1(qk);
+
+    for (; ib < nb; ++ib) {
+        // load elements
+        vint8m1_t bx_0 = __riscv_vle8_v_i8m1(x[ib].qs, vl);
+        vint8m1_t by_0 = __riscv_vle8_v_i8m1(y[ib].qs, vl);
+
+        vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx_0, by_0, vl);
+
+        vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl);
+        vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl);
+
+        int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum);
+
+        sumf += sumi*(GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d));
+    }
+#elif defined(__POWER9_VECTOR__)
+    const vector signed int v0 = vec_splats((int32_t)0);
+    vector float vsumf0 = vec_splats(0.0f);
+
+#pragma GCC unroll 8
+    for (; ib < nb; ++ib) {
+        __builtin_prefetch(x[ib].qs, 0, 1);
+        __builtin_prefetch(y[ib].qs, 0, 1);
+
+        vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d));
+        vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d));
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector signed char q8x0 = vec_xl( 0, x[ib].qs);
+        vector signed char q8x1 = vec_xl(16, x[ib].qs);
+        vector signed char q8y0 = vec_xl( 0, y[ib].qs);
+        vector signed char q8y1 = vec_xl(16, y[ib].qs);
+
+        vector signed short qv0 = vec_mule(q8x0, q8y0);
+        vector signed short qv1 = vec_mulo(q8x0, q8y0);
+        vector signed short qv2 = vec_mule(q8x1, q8y1);
+        vector signed short qv3 = vec_mulo(q8x1, q8y1);
+
+        vector signed int vsumi0 = v0;
+        vector signed int vsumi1 = v0;
+
+        vsumi0 = vec_sum4s(qv0, vsumi0);
+        vsumi1 = vec_sum4s(qv1, vsumi1);
+        vsumi0 = vec_sum4s(qv2, vsumi0);
+        vsumi1 = vec_sum4s(qv3, vsumi1);
+
+        vsumi0 = vec_add(vsumi0, vsumi1);
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+    }
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    sumf = vec_extract(vsumf0, 0);
+
+#elif defined(__loongarch_asx)
+    // Initialize accumulator with zeros
+    __m256 acc = (__m256)__lasx_xvldi(0);
+
+    // Main loop
+    for (; ib < nb; ++ib) {
+        // Compute combined scale for the block
+        const __m256 d = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d));
+        __m256i qx = __lasx_xvld((const __m256i *)x[ib].qs, 0);
+        __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0);
+
+        const __m256 q = mul_sum_i8_pairs_float(qx, qy);
+
+        // Multiply q with scale and accumulate
+        acc = __lasx_xvfmadd_s( d, q, acc );
+    }
+
+    sumf = hsum_float_8(acc);
+#endif
+    for (; ib < nb; ++ib) {
+        int sumi = 0;
+
+        for (int j = 0; j < qk; j++) {
+            sumi += x[ib].qs[j]*y[ib].qs[j];
+        }
+
+        sumf += sumi*(GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d));
+    }
+
+    *s = sumf;
+}
+
+void ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_tq1_0 * restrict x = vx;
+    const block_q8_K  * restrict y = vy;
+
+    const int nb = n / QK_K;
+
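+    // tq1_0 packs ternary weights in base 3: each qs byte carries 5 trits and
+    // each qh byte 4, stored as a base-3 fixed-point fraction of a byte.
+    // Multiplying a byte q by 3^l (mod 256) shifts trit l to the top, and
+    // ((uint16_t) q * 3) >> 8 reads it out as 0/1/2; subtracting 1 then gives the
+    // weight in {-1, 0, 1} (see the scalar reference at the end of this function).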
+#if defined(__ARM_NEON)
+    float sumf = 0.0f;
+
+    uint8_t k_shift[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27};
+
+    const uint8x16_t shift = vld1q_u8(k_shift);
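+    // broadcasting the 4 qh bytes across the vector and multiplying each 4-byte
+    // group by 1/3/9/27 lines up a different trit for the (3*q) >> 8 extraction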
+
+    for (int i = 0; i < nb; ++i) {
+#if defined(__ARM_FEATURE_DOTPROD)
+        int32x4_t sumi0 = vdupq_n_s32(0);
+        int32x4_t sumi1 = vdupq_n_s32(0);
+#else
+        int16x8_t sumi0 = vdupq_n_s16(0);
+        int16x8_t sumi1 = vdupq_n_s16(0);
+#endif
+
+        // first 32 bytes of qs: each byte packs 5 ternary elements
+        {
+            uint8x16_t qx0 = vld1q_u8(x[i].qs + 0);
+            uint8x16_t qx1 = vld1q_u8(x[i].qs + 16);
+            uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(3));
+            uint8x16_t qx3 = vmulq_u8(qx1, vdupq_n_u8(3));
+            uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(9));
+            uint8x16_t qx5 = vmulq_u8(qx1, vdupq_n_u8(9));
+            uint8x16_t qx6 = vmulq_u8(qx0, vdupq_n_u8(27));
+            uint8x16_t qx7 = vmulq_u8(qx1, vdupq_n_u8(27));
+            uint8x16_t qx8 = vmulq_u8(qx0, vdupq_n_u8(81));
+            uint8x16_t qx9 = vmulq_u8(qx1, vdupq_n_u8(81));
+
+            // multiply by 3 and keep the 2 bits above 8 bits
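+            // vhadd(q, q >> 1) = (q + (q >> 1)) >> 1 = floor(3*q/4) without
+            // overflow, so the extra >> 6 leaves (3*q) >> 8: the current top trit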
+            int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6));
+            int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6));
+            int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6));
+            int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6));
+            int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6));
+            int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6));
+            int8x16_t sqx6 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx6, vshrq_n_u8(qx6, 1)), 6));
+            int8x16_t sqx7 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx7, vshrq_n_u8(qx7, 1)), 6));
+            int8x16_t sqx8 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx8, vshrq_n_u8(qx8, 1)), 6));
+            int8x16_t sqx9 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx9, vshrq_n_u8(qx9, 1)), 6));
+
+            const int8x16_t qy0 = vld1q_s8(y[i].qs +   0);
+            const int8x16_t qy1 = vld1q_s8(y[i].qs +  16);
+            const int8x16_t qy2 = vld1q_s8(y[i].qs +  32);
+            const int8x16_t qy3 = vld1q_s8(y[i].qs +  48);
+            const int8x16_t qy4 = vld1q_s8(y[i].qs +  64);
+            const int8x16_t qy5 = vld1q_s8(y[i].qs +  80);
+            const int8x16_t qy6 = vld1q_s8(y[i].qs +  96);
+            const int8x16_t qy7 = vld1q_s8(y[i].qs + 112);
+            const int8x16_t qy8 = vld1q_s8(y[i].qs + 128);
+            const int8x16_t qy9 = vld1q_s8(y[i].qs + 144);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+            sumi0 = vdotq_s32(sumi0, sqx0, qy0);
+            sumi1 = vdotq_s32(sumi1, sqx1, qy1);
+            sumi0 = vdotq_s32(sumi0, sqx2, qy2);
+            sumi1 = vdotq_s32(sumi1, sqx3, qy3);
+            sumi0 = vdotq_s32(sumi0, sqx4, qy4);
+            sumi1 = vdotq_s32(sumi1, sqx5, qy5);
+            sumi0 = vdotq_s32(sumi0, sqx6, qy6);
+            sumi1 = vdotq_s32(sumi1, sqx7, qy7);
+            sumi0 = vdotq_s32(sumi0, sqx8, qy8);
+            sumi1 = vdotq_s32(sumi1, sqx9, qy9);
+#else
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx6), vget_low_s8(qy6));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx8), vget_low_s8(qy8));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx8), vget_high_s8(qy8));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx9), vget_low_s8(qy9));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx9), vget_high_s8(qy9));
+#endif
+        }
+
+        // last 16 bytes of qs (5 elements per byte), along with the 4 qh bytes (4 elements per byte)
+        {
+            uint8x16_t qx0 = vld1q_u8(x[i].qs + 32);
+            uint8x16_t qx1 = vmulq_u8(qx0, vdupq_n_u8(3));
+            uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(9));
+            uint8x16_t qx3 = vmulq_u8(qx0, vdupq_n_u8(27));
+            uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(81));
+            uint32_t qh;
+            memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned
+            uint8x16_t qx5 = vreinterpretq_u8_u32(vdupq_n_u32(qh));
+            qx5 = vmulq_u8(qx5, shift);
+
+            // multiply by 3 and keep the 2 bits above 8 bits
+            int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6));
+            int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6));
+            int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6));
+            int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6));
+            int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6));
+            int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6));
+
+            const int8x16_t qy0 = vld1q_s8(y[i].qs + 160);
+            const int8x16_t qy1 = vld1q_s8(y[i].qs + 176);
+            const int8x16_t qy2 = vld1q_s8(y[i].qs + 192);
+            const int8x16_t qy3 = vld1q_s8(y[i].qs + 208);
+            const int8x16_t qy4 = vld1q_s8(y[i].qs + 224);
+            const int8x16_t qy5 = vld1q_s8(y[i].qs + 240);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+            sumi0 = vdotq_s32(sumi0, sqx0, qy0);
+            sumi1 = vdotq_s32(sumi1, sqx1, qy1);
+            sumi0 = vdotq_s32(sumi0, sqx2, qy2);
+            sumi1 = vdotq_s32(sumi1, sqx3, qy3);
+            sumi0 = vdotq_s32(sumi0, sqx4, qy4);
+            sumi1 = vdotq_s32(sumi1, sqx5, qy5);
+#else
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5));
+#endif
+        }
+
+        const int16x8_t ysum0 = vld1q_s16(y[i].bsums);
+        const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8);
+
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        sumi0 = vaddq_s32(sumi0, sumi1);
+        sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1)));
+
+        sumf += d * (float) vaddvq_s32(sumi0);
+#else
+        sumi0 = vaddq_s16(sumi0, sumi1);
+        sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1));
+
+        sumf += d * (float) vaddlvq_s16(sumi0);
+#endif
+    }
+
+    *s = sumf;
+
+#elif defined(__AVX2__)
+    __m256 sumf = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+        // 16-bit sums
+        __m256i sumi0 = _mm256_setzero_si256();
+        __m256i sumi1 = _mm256_setzero_si256();
+        __m256i sumi2 = _mm256_setzero_si256();
+
+        // first 32 bytes of qs: each byte packs 5 ternary elements
+        {
+            __m256i qx0 = _mm256_loadu_si256((const __m256i *) (x[i].qs));
+            // 8-bit multiplies with shifts, masks and adds
+            __m256i qx1 = _mm256_add_epi8(qx0, _mm256_add_epi8(qx0, qx0)); // 1 * 3
+            __m256i qx2 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx0, 3), _mm256_set1_epi8(-8)), qx0); // 1 * 9
+            __m256i qx3 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx1, 3), _mm256_set1_epi8(-8)), qx1); // 3 * 9
+            __m256i qx4 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx2, 3), _mm256_set1_epi8(-8)), qx2); // 9 * 9
+
+            // TODO: can _mm256_mulhi_epu16 be faster even if 16-bits?
+
+            // Cancel the +1 from avg so that it behaves like a halving add
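+            // _mm256_avg_epu8(a, b) rounds up: (a + b + 1) >> 1; with q' = q - 1
+            // (saturating), avg(q', avg(q', 0)) = (q + (q >> 1)) >> 1 = floor(3*q/4),
+            // so the >> 6 below again extracts (3*q) >> 8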
+            qx0 = _mm256_subs_epu8(qx0, _mm256_set1_epi8(1));
+            qx1 = _mm256_subs_epu8(qx1, _mm256_set1_epi8(1));
+            qx2 = _mm256_subs_epu8(qx2, _mm256_set1_epi8(1));
+            qx3 = _mm256_subs_epu8(qx3, _mm256_set1_epi8(1));
+            qx4 = _mm256_subs_epu8(qx4, _mm256_set1_epi8(1));
+            // Multiply by 3 and get the top 2 bits
+            qx0 = _mm256_avg_epu8(qx0, _mm256_avg_epu8(qx0, _mm256_setzero_si256()));
+            qx1 = _mm256_avg_epu8(qx1, _mm256_avg_epu8(qx1, _mm256_setzero_si256()));
+            qx2 = _mm256_avg_epu8(qx2, _mm256_avg_epu8(qx2, _mm256_setzero_si256()));
+            qx3 = _mm256_avg_epu8(qx3, _mm256_avg_epu8(qx3, _mm256_setzero_si256()));
+            qx4 = _mm256_avg_epu8(qx4, _mm256_avg_epu8(qx4, _mm256_setzero_si256()));
+            qx0 = _mm256_and_si256(_mm256_srli_epi16(qx0, 6), _mm256_set1_epi8(3));
+            qx1 = _mm256_and_si256(_mm256_srli_epi16(qx1, 6), _mm256_set1_epi8(3));
+            qx2 = _mm256_and_si256(_mm256_srli_epi16(qx2, 6), _mm256_set1_epi8(3));
+            qx3 = _mm256_and_si256(_mm256_srli_epi16(qx3, 6), _mm256_set1_epi8(3));
+            qx4 = _mm256_and_si256(_mm256_srli_epi16(qx4, 6), _mm256_set1_epi8(3));
+
+            const __m256i qy0 = _mm256_loadu_si256((const __m256i *) (y[i].qs +   0));
+            const __m256i qy1 = _mm256_loadu_si256((const __m256i *) (y[i].qs +  32));
+            const __m256i qy2 = _mm256_loadu_si256((const __m256i *) (y[i].qs +  64));
+            const __m256i qy3 = _mm256_loadu_si256((const __m256i *) (y[i].qs +  96));
+            const __m256i qy4 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 128));
+
+            qx0 = _mm256_maddubs_epi16(qx0, qy0);
+            qx1 = _mm256_maddubs_epi16(qx1, qy1);
+            qx2 = _mm256_maddubs_epi16(qx2, qy2);
+            qx3 = _mm256_maddubs_epi16(qx3, qy3);
+            qx4 = _mm256_maddubs_epi16(qx4, qy4);
+
+            sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(qx0, qx1));
+            sumi1 = _mm256_add_epi16(sumi1, _mm256_add_epi16(qx2, qx3));
+            sumi2 = _mm256_add_epi16(sumi2, qx4);
+        }
+
+        // last 16 bytes of qs (5 elements per byte), along with the 4 qh bytes (4 elements per byte)
+        {
+            __m128i qx0 = _mm_loadu_si128((const __m128i *) (x[i].qs + 32));
+            uint32_t qh;
+            memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned
+            __m256i qx5_l = _mm256_cvtepu8_epi16(_mm_set1_epi32(qh));
+            __m128i qx1 = _mm_add_epi8(qx0, _mm_add_epi8(qx0, qx0)); // 1 * 3
+            __m128i qx2 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx0, 3), _mm_set1_epi8(-8)), qx0); // 1 * 9
+            __m128i qx3 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx1, 3), _mm_set1_epi8(-8)), qx1); // 3 * 9
+            __m128i qx4 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx2, 3), _mm_set1_epi8(-8)), qx2); // 9 * 9
+            __m256i qx01 = MM256_SET_M128I(qx1, qx0);
+            __m256i qx23 = MM256_SET_M128I(qx3, qx2);
+
+            // avx2 does not have 8-bit multiplies, so 16-bit it is.
+            qx5_l = _mm256_mullo_epi16(qx5_l, _mm256_set_epi16(27, 27, 27, 27, 9, 9, 9, 9, 3, 3, 3, 3, 1, 1, 1, 1));
+            qx5_l = _mm256_and_si256(qx5_l, _mm256_set1_epi16(0xFF));
+            __m128i qx5 = _mm_packus_epi16(_mm256_castsi256_si128(qx5_l), _mm256_extracti128_si256(qx5_l, 1));
+
+            __m256i qx45 = MM256_SET_M128I(qx5, qx4);
+
+            // Cancel the +1 from avg so that it behaves like a halving add
+            qx01 = _mm256_subs_epu8(qx01, _mm256_set1_epi8(1));
+            qx23 = _mm256_subs_epu8(qx23, _mm256_set1_epi8(1));
+            qx45 = _mm256_subs_epu8(qx45, _mm256_set1_epi8(1));
+            // Multiply by 3 and get the top 2 bits
+            qx01 = _mm256_avg_epu8(qx01, _mm256_avg_epu8(qx01, _mm256_setzero_si256()));
+            qx23 = _mm256_avg_epu8(qx23, _mm256_avg_epu8(qx23, _mm256_setzero_si256()));
+            qx45 = _mm256_avg_epu8(qx45, _mm256_avg_epu8(qx45, _mm256_setzero_si256()));
+            qx01 = _mm256_and_si256(_mm256_srli_epi16(qx01, 6), _mm256_set1_epi8(3));
+            qx23 = _mm256_and_si256(_mm256_srli_epi16(qx23, 6), _mm256_set1_epi8(3));
+            qx45 = _mm256_and_si256(_mm256_srli_epi16(qx45, 6), _mm256_set1_epi8(3));
+
+            const __m256i qy01 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 160));
+            const __m256i qy23 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 192));
+            const __m256i qy45 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 224));
+
+            qx01 = _mm256_maddubs_epi16(qx01, qy01);
+            qx23 = _mm256_maddubs_epi16(qx23, qy23);
+            qx45 = _mm256_maddubs_epi16(qx45, qy45);
+
+            sumi0 = _mm256_add_epi16(sumi0, qx01);
+            sumi1 = _mm256_add_epi16(sumi1, qx23);
+            sumi2 = _mm256_add_epi16(sumi2, qx45);
+        }
+
+        const __m256i ysum = _mm256_loadu_si256((const __m256i *) y[i].bsums);
+        const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d));
+
+        sumi0 = _mm256_sub_epi16(sumi0, ysum);
+        sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(sumi1, sumi2));
+        sumi0 = _mm256_madd_epi16(sumi0, _mm256_set1_epi16(1));
+
+        sumf = _mm256_add_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(sumi0), d), sumf);
+    }
+
+    *s = hsum_float_8(sumf);
+
+#else
+    const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243};
+
+    float sumf = 0.0f;
+
+    for (int i = 0; i < nb; ++i) {
+        int sum = 0;
+
+        for (size_t j = 0; j < sizeof(x->qs) - sizeof(x->qs) % 32; j += 32) {
+            for (size_t l = 0; l < 5; ++l) {
+                for (size_t m = 0; m < 32; ++m) {
+                    uint8_t q = x[i].qs[j + m] * pow3[l];
+                    uint16_t xi = ((uint16_t) q * 3) >> 8;
+                    sum += (xi - 1) * y[i].qs[j*5 + l*32 + m];
+                }
+            }
+        }
+        for (size_t j = sizeof(x->qs) - sizeof(x->qs) % 32; j < sizeof(x->qs); j += 16) {
+            for (size_t l = 0; l < 5; ++l) {
+                for (size_t m = 0; m < 16; ++m) {
+                    uint8_t q = x[i].qs[j + m] * pow3[l];
+                    uint16_t xi = ((uint16_t) q * 3) >> 8;
+                    sum += (xi - 1) * y[i].qs[j*5 + l*16 + m];
+                }
+            }
+        }
+
+        for (size_t l = 0; l < 4; ++l) {
+            for (size_t j = 0; j < sizeof(x->qh); ++j) {
+                uint8_t q = x[i].qh[j] * pow3[l];
+                uint16_t xi = ((uint16_t) q * 3) >> 8;
+                sum += (xi - 1) * y[i].qs[sizeof(x->qs)*5 + l*sizeof(x->qh) + j];
+            }
+        }
+
+        sumf += (float) sum * (GGML_FP16_TO_FP32(x[i].d) * y[i].d);
+    }
+
+    *s = sumf;
+#endif
+}
+
+void ggml_vec_dot_tq2_0_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_tq2_0 * restrict x = vx;
+    const block_q8_K  * restrict y = vy;
+
+    const int nb = n / QK_K;
+
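+    // tq2_0 packs 4 ternary values per byte, 2 bits each, stored as 0/1/2 with an
+    // implicit -1 offset; since (q - 1) . y = q . y - sum(y), the kernels dot the
+    // raw 2-bit values against q8 and subtract the precomputed bsums once per block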
+#if defined(__ARM_NEON)
+    float sumf = 0.0f;
+
+    const uint8x16_t m3 = vdupq_n_u8(3);
+
+    for (int i = 0; i < nb; ++i) {
+#if defined(__ARM_FEATURE_DOTPROD)
+        int32x4_t sumi0 = vdupq_n_s32(0);
+        int32x4_t sumi1 = vdupq_n_s32(0);
+#else
+        int16x8_t sumi0 = vdupq_n_s16(0);
+        int16x8_t sumi1 = vdupq_n_s16(0);
+#endif
+
+        for (size_t j = 0; j < sizeof(x->qs); j += 32) {
+            uint8x16_t qx0 = vld1q_u8(x[i].qs + j);
+            uint8x16_t qx1 = vld1q_u8(x[i].qs + j + 16);
+            uint8x16_t qx2 = vshrq_n_u8(qx0, 2);
+            uint8x16_t qx3 = vshrq_n_u8(qx1, 2);
+            uint8x16_t qx4 = vshrq_n_u8(qx0, 4);
+            uint8x16_t qx5 = vshrq_n_u8(qx1, 4);
+            uint8x16_t qx6 = vshrq_n_u8(qx0, 6);
+            uint8x16_t qx7 = vshrq_n_u8(qx1, 6);
+
+            int8x16_t sqx0 = vreinterpretq_s8_u8(vandq_u8(qx0, m3));
+            int8x16_t sqx1 = vreinterpretq_s8_u8(vandq_u8(qx1, m3));
+            int8x16_t sqx2 = vreinterpretq_s8_u8(vandq_u8(qx2, m3));
+            int8x16_t sqx3 = vreinterpretq_s8_u8(vandq_u8(qx3, m3));
+            int8x16_t sqx4 = vreinterpretq_s8_u8(vandq_u8(qx4, m3));
+            int8x16_t sqx5 = vreinterpretq_s8_u8(vandq_u8(qx5, m3));
+            int8x16_t sqx6 = vreinterpretq_s8_u8(vandq_u8(qx6, m3));
+            int8x16_t sqx7 = vreinterpretq_s8_u8(vandq_u8(qx7, m3));
+
+            const int8x16_t qy0 = vld1q_s8(y[i].qs + j*4 +   0);
+            const int8x16_t qy1 = vld1q_s8(y[i].qs + j*4 +  16);
+            const int8x16_t qy2 = vld1q_s8(y[i].qs + j*4 +  32);
+            const int8x16_t qy3 = vld1q_s8(y[i].qs + j*4 +  48);
+            const int8x16_t qy4 = vld1q_s8(y[i].qs + j*4 +  64);
+            const int8x16_t qy5 = vld1q_s8(y[i].qs + j*4 +  80);
+            const int8x16_t qy6 = vld1q_s8(y[i].qs + j*4 +  96);
+            const int8x16_t qy7 = vld1q_s8(y[i].qs + j*4 + 112);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+            sumi0 = vdotq_s32(sumi0, sqx0, qy0);
+            sumi1 = vdotq_s32(sumi1, sqx1, qy1);
+            sumi0 = vdotq_s32(sumi0, sqx2, qy2);
+            sumi1 = vdotq_s32(sumi1, sqx3, qy3);
+            sumi0 = vdotq_s32(sumi0, sqx4, qy4);
+            sumi1 = vdotq_s32(sumi1, sqx5, qy5);
+            sumi0 = vdotq_s32(sumi0, sqx6, qy6);
+            sumi1 = vdotq_s32(sumi1, sqx7, qy7);
+#else
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx6), vget_low_s8(qy6));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7));
+#endif
+        }
+
+        const int16x8_t ysum0 = vld1q_s16(y[i].bsums);
+        const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8);
+
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        sumi0 = vaddq_s32(sumi0, sumi1);
+        sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1)));
+
+        sumf += d * (float) vaddvq_s32(sumi0);
+#else
+        sumi0 = vaddq_s16(sumi0, sumi1);
+        sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1));
+
+        sumf += d * (float) vaddlvq_s16(sumi0);
+#endif
+    }
+
+    *s = sumf;
+
+#elif defined(__AVX2__)
+    __m256 sumf = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+        // 16-bit sums, because 256*127 still fits
+        __m256i sumi0 = _mm256_setzero_si256();
+        __m256i sumi1 = _mm256_setzero_si256();
+
+        for (size_t j = 0; j < sizeof(x->qs); j += 32) {
+            __m256i qx0 = _mm256_loadu_si256((const __m256i *) (x[i].qs + j));
+            __m256i qx1 = _mm256_srli_epi16(qx0, 2);
+            __m256i qx2 = _mm256_srli_epi16(qx0, 4);
+            __m256i qx3 = _mm256_srli_epi16(qx0, 6);
+
+            // 0, 1, 2 (should not be 3)
+            qx0 = _mm256_and_si256(qx0, _mm256_set1_epi8(3));
+            qx1 = _mm256_and_si256(qx1, _mm256_set1_epi8(3));
+            qx2 = _mm256_and_si256(qx2, _mm256_set1_epi8(3));
+            qx3 = _mm256_and_si256(qx3, _mm256_set1_epi8(3));
+
+            const __m256i qy0 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 +  0));
+            const __m256i qy1 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 32));
+            const __m256i qy2 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 64));
+            const __m256i qy3 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 96));
+
+            qx0 = _mm256_maddubs_epi16(qx0, qy0);
+            qx1 = _mm256_maddubs_epi16(qx1, qy1);
+            qx2 = _mm256_maddubs_epi16(qx2, qy2);
+            qx3 = _mm256_maddubs_epi16(qx3, qy3);
+
+            sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(qx0, qx1));
+            sumi1 = _mm256_add_epi16(sumi1, _mm256_add_epi16(qx2, qx3));
+        }
+
+        const __m256i ysum = _mm256_loadu_si256((const __m256i *) y[i].bsums);
+        const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d));
+
+        sumi0 = _mm256_add_epi16(sumi0, sumi1);
+        sumi0 = _mm256_sub_epi16(sumi0, ysum);
+        sumi0 = _mm256_madd_epi16(sumi0, _mm256_set1_epi16(1));
+
+        sumf = _mm256_add_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(sumi0), d), sumf);
+    }
+
+    *s = hsum_float_8(sumf);
+
+#else
+    float sumf = 0.0f;
+
+    for (int i = 0; i < nb; ++i) {
+        int32_t sumi = 0;
+
+        for (size_t j = 0; j < sizeof(x->qs); j += 32) {
+            for (size_t l = 0; l < 4; ++l) {
+                for (size_t k = 0; k < 32; ++k) {
+                    sumi += y[i].qs[j*4 + l*32 + k] * (((x[i].qs[j + k] >> (l*2)) & 3) - 1);
+                }
+            }
+        }
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+        sumf += (float) sumi * d;
+    }
+
+    *s = sumf;
+#endif
+}
+
+void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_q2_K * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
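+    // q2_K: 256-element super-blocks with fp16 super-scales d and dmin; each of
+    // the 16 sub-blocks of 16 stores a 4-bit scale (low nibble) and a 4-bit min
+    // (high nibble) in x->scales, giving
+    //     d * sum_j scale_j * (q2_j . q8_j) - dmin * sum_j min_j * bsum_j
+    // where bsum_j are the per-16 sums precomputed in the q8_K block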
+#ifdef __ARM_NEON
+    const uint8x16_t m3 = vdupq_n_u8(0x3);
+    const uint8x16_t m4 = vdupq_n_u8(0xF);
+
+    const int32x4_t vzero = vdupq_n_s32(0);
+
+    ggml_int8x16x2_t q2bytes;
+    uint8_t aux[16];
+
+    float sum = 0;
+
+    for (int i = 0; i < nb; ++i) {
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const uint8_t * restrict q2 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * restrict sc = x[i].scales;
+
+        const uint8x16_t mins_and_scales = vld1q_u8(sc);
+        const uint8x16_t scales = vandq_u8(mins_and_scales, m4);
+        vst1q_u8(aux, scales);
+
+        const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4);
+        const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);
+        const ggml_int16x8x2_t mins16 = {{vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))}};
+        const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])),
+                                       vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0])));
+        const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])),
+                                       vmull_s16(vget_high_s16(mins16.val[1]), vget_high_s16(q8sums.val[1])));
+        sum += dmin * vaddvq_s32(vaddq_s32(s0, s1));
+
+        int isum = 0;
+        int is = 0;
+
+// We use this macro instead of a function call because for some reason
+// the code runs 2-3% slower, even if the function is declared inline
+#define MULTIPLY_ACCUM_WITH_SCALE(index)\
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)];
+
+#define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\
+        q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;\
+        q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\
+        q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\
+        MULTIPLY_ACCUM_WITH_SCALE((index));
+
+        for (int j = 0; j < QK_K/128; ++j) {
+            const ggml_uint8x16x2_t q2bits = ggml_vld1q_u8_x2(q2); q2 += 32;
+
+            ggml_int8x16x2_t q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
+            q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3));
+            q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3));
+
+            MULTIPLY_ACCUM_WITH_SCALE(0);
+
+            SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2);
+            SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4);
+            SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6);
+
+            is += 8;
+        }
+
+        sum += d * isum;
+    }
+
+    *s = sum;
+
+#elif defined __AVX2__
+
+    const __m256i m3 = _mm256_set1_epi8(3);
+    const __m128i m4 = _mm_set1_epi8(0xF);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const uint8_t * restrict q2 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);
+        const __m128i scales8 = _mm_and_si128(mins_and_scales, m4);
+        const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);
+        const __m256i mins = _mm256_cvtepi8_epi16(mins8);
+        const __m256i prod = _mm256_madd_epi16(mins, _mm256_loadu_si256((const __m256i*)y[i].bsums));
+
+        acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(prod), acc);
+
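+        // duplicate the eight low and eight high 16-bit sub-block scales into both
+        // 128-bit lanes so the byte shuffles below can broadcast a scale to every
+        // 16-bit product in its 32-value group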
+        const __m256i all_scales = _mm256_cvtepi8_epi16(scales8);
+        const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
+        const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
+        const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)};
+
+        __m256i sumi = _mm256_setzero_si256();
+
+        for (int j = 0; j < QK_K/128; ++j) {
+
+            const __m256i q2bits = _mm256_loadu_si256((const __m256i*)q2); q2 += 32;
+
+            const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+
+            const __m256i q2_0 = _mm256_and_si256(q2bits, m3);
+            const __m256i q2_1 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), m3);
+            const __m256i q2_2 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), m3);
+            const __m256i q2_3 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), m3);
+
+            __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0);
+            __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1);
+            __m256i p2 = _mm256_maddubs_epi16(q2_2, q8_2);
+            __m256i p3 = _mm256_maddubs_epi16(q2_3, q8_3);
+
+            p0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(0)), p0);
+            p1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(1)), p1);
+            p2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(2)), p2);
+            p3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(3)), p3);
+
+            p0 = _mm256_add_epi32(p0, p1);
+            p2 = _mm256_add_epi32(p2, p3);
+
+            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p0, p2));
+        }
+
+        acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
+
+    }
+
+    *s = hsum_float_8(acc);
+
+#elif defined __AVX__
+
+    const __m128i m3 = _mm_set1_epi8(0x3);
+    const __m128i m4 = _mm_set1_epi8(0xF);
+    const __m128i m2 = _mm_set1_epi8(0x2);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const uint8_t * restrict q2 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        // load mins and scales from block_q2_K.scales[QK_K/16]
+        const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);
+        const __m128i scales16 = _mm_and_si128(mins_and_scales, m4);
+        const __m128i mins16 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);
+        const __m128i mins_0 = _mm_cvtepi8_epi16(mins16);
+        const __m128i mins_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(mins16, mins16));
+
+        // summs = y[i].bsums * (x[i].scales >> 4) in 16bits*8*2 to 32bits*4*2
+        const __m128i summs_0 = _mm_madd_epi16(mins_0, _mm_loadu_si128((const __m128i*)&y[i].bsums[0]));
+        const __m128i summs_1 = _mm_madd_epi16(mins_1, _mm_loadu_si128((const __m128i*)&y[i].bsums[8]));
+
+        // sumf += -dmin * summs in 32bits*8
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(MM256_SET_M128I(summs_1, summs_0))), acc);
+
+        const __m128i scales_0 = _mm_cvtepi8_epi16(scales16);
+        const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales16, scales16));
+        const __m128i scales[2] = { scales_0, scales_1 };
+
+        __m128i sumi_0 = _mm_setzero_si128();
+        __m128i sumi_1 = _mm_setzero_si128();
+
+        for (int j = 0; j < QK_K/128; ++j) {
+
+            // load Q8 quants int8*16*8 from block_q8_K.qs[QK_K]
+            const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+
+            // load 2bits*16*8 from block_q2_K.qs[QK_K/4]
+            __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16;
+            const __m128i q2_0 = _mm_and_si128(q2bits, m3);
+            const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3);
+            const __m128i q2_4 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3);
+            const __m128i q2_6 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3);
+            q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16;
+            const __m128i q2_1 = _mm_and_si128(q2bits, m3);
+            const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3);
+            const __m128i q2_5 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3);
+            const __m128i q2_7 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3);
+
+            // isuml = q8[l] * ((q2[l] >> shift) & 3) in 8bits*16*8 to 16bits*8*8
+            __m128i p0 = _mm_maddubs_epi16(q2_0, q8_0);
+            __m128i p1 = _mm_maddubs_epi16(q2_1, q8_1);
+            __m128i p2 = _mm_maddubs_epi16(q2_2, q8_2);
+            __m128i p3 = _mm_maddubs_epi16(q2_3, q8_3);
+            __m128i p4 = _mm_maddubs_epi16(q2_4, q8_4);
+            __m128i p5 = _mm_maddubs_epi16(q2_5, q8_5);
+            __m128i p6 = _mm_maddubs_epi16(q2_6, q8_6);
+            __m128i p7 = _mm_maddubs_epi16(q2_7, q8_7);
+
+            // isum += (x[i].scales[is++] & 0xF) * isuml in 16bits*8*8 to 32bits*4*8
+            __m128i shuffle = _mm_set1_epi16(0x0100);
+            p0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p0);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p1);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p2);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p3);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p4);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p5);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p6);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p7);
+
+            p0 = _mm_add_epi32(p0, p1);
+            p2 = _mm_add_epi32(p2, p3);
+            p4 = _mm_add_epi32(p4, p5);
+            p6 = _mm_add_epi32(p6, p7);
+
+            // isum in 32bits*4*2
+            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p0, p2));
+            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p4, p6));
+        }
+
+        // sumf += dall * isum - dmin * summs in 32bits
+        __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dall), _mm256_cvtepi32_ps(sumi)), acc);
+    }
+
+    *s = hsum_float_8(acc);
+
+#elif defined __riscv_v_intrinsic
+
+    float sumf = 0;
+    uint8_t temp_01[32] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                           1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
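+    // temp_01 is a vrgather index pattern: lanes 0..15 are 0 and lanes 16..31
+    // are 1, so gathering with v_b + (k + is) duplicates scale aux[is+k] over
+    // the first 16 lanes and aux[is+k+1] over the last 16, one per sub-block.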
+
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * q2 = x[i].qs;
+        const  int8_t * q8 = y[i].qs;
+        const uint8_t * sc = x[i].scales;
+
+        const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        size_t vl = 16;
+
+        vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl);
+        vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl);
+
+        vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl);
+
+        vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl);
+        vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl);
+        vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl));
+        vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl);
+        vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
+
+        sumf  += dmin * __riscv_vmv_x_s_i32m1_i32(vsums);
+
+        vl = 32;
+
+        vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+        vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl);
+
+        uint8_t is = 0;
+        int isum = 0;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+            // load Q2
+            vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl);
+
+            vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl);
+            vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03, vl);
+            vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03, vl);
+            vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03, vl);
+
+            // duplicate scale elements for product
+            vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0+is, vl), vl);
+            vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2+is, vl), vl);
+            vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4+is, vl), vl);
+            vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6+is, vl), vl);
+
+            vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl));
+            vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl));
+            vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl));
+            vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl));
+
+            // load Q8
+            vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl);
+            vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl);
+            vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8+64, vl);
+            vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8+96, vl);
+
+            vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl);
+            vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl);
+            vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl);
+            vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl);
+
+            vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl);
+            vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl);
+
+            isum += __riscv_vmv_x_s_i32m1_i32(isum1);
+
+            q2 += 32; q8 += 128; is = 8;
+
+        }
+
+        sumf += dall * isum;
+
+    }
+
+    *s = sumf;
+
+#elif defined(__POWER9_VECTOR__)
+    const vector signed char lowMask = vec_splats((signed char)0x3);
+    const vector signed char lowScaleMask = vec_splats((signed char)0xF);
+    const vector int v0 = vec_splats((int32_t)0);
+    const vector unsigned char v2 = vec_splats((unsigned char)0x2);
+    const vector unsigned char v6 = vec_splats((unsigned char)0x6);
+    const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+
+    vector float vsumf0 = vec_splats(0.0f);
+    vector float vsumf1 = vec_splats(0.0f);
+    vector float vsumf2 = vec_splats(0.0f);
+    vector float vsumf3 = vec_splats(0.0f);
+
+    for (int i = 0; i < nb; ++i) {
+        vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
+        vector float vyd = vec_splats(y[i].d);
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[i].dmin));
+        vector float vdmin = vec_mul(vxmin, vyd);
+
+        vector signed short q8ysums0 = vec_xl( 0, y[i].bsums);
+        vector signed short q8ysums1 = vec_xl(16, y[i].bsums);
+
+        vector signed char q2xmins = (vector signed char)vec_xl( 0, x[i].scales);
+        vector signed char vscales = vec_and(q2xmins, lowScaleMask);
+
+        q2xmins = vec_sr(q2xmins, v4);
+        vector signed short q2xmins0 = vec_unpackh(q2xmins);
+        vector signed short q2xmins1 = vec_unpackl(q2xmins);
+
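+        // vec_mule/vec_mulo form 32-bit products of the even/odd 16-bit lanes
+        // (mins * bsums); vec_nmsub below folds them in as the -dmin correction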
+        vector signed int prod0 = vec_mule(q2xmins0, q8ysums0);
+        vector signed int prod1 = vec_mulo(q2xmins0, q8ysums0);
+        vector signed int prod2 = vec_mule(q2xmins1, q8ysums1);
+        vector signed int prod3 = vec_mulo(q2xmins1, q8ysums1);
+
+        vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0);
+        vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1);
+        vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2);
+        vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3);
+
+        vector signed int vsumi0 = v0;
+        vector signed int vsumi1 = v0;
+        vector signed int vsumi2 = v0;
+        vector signed int vsumi3 = v0;
+        vector signed int vsumi4 = v0;
+        vector signed int vsumi5 = v0;
+        vector signed int vsumi6 = v0;
+        vector signed int vsumi7 = v0;
+
+        const uint8_t * restrict q2 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+            __builtin_prefetch(q2, 0, 1);
+            __builtin_prefetch(q8, 0, 1);
+
+            vector signed char qxs0 = (vector signed char)vec_xl( 0, q2);
+            vector signed char qxs1 = (vector signed char)vec_xl(16, q2);
+            q2 += 32;
+
+            vector unsigned char q2x00 = (vector unsigned char)vec_and(qxs0, lowMask);
+            vector unsigned char q2x01 = (vector unsigned char)vec_and(vec_sr(qxs0, v2), lowMask);
+            vector unsigned char q2x02 = (vector unsigned char)vec_and(vec_sr(qxs0, v4), lowMask);
+            vector unsigned char q2x03 = (vector unsigned char)vec_and(vec_sr(qxs0, v6), lowMask);
+            vector unsigned char q2x10 = (vector unsigned char)vec_and(qxs1, lowMask);
+            vector unsigned char q2x11 = (vector unsigned char)vec_and(vec_sr(qxs1, v2), lowMask);
+            vector unsigned char q2x12 = (vector unsigned char)vec_and(vec_sr(qxs1, v4), lowMask);
+            vector unsigned char q2x13 = (vector unsigned char)vec_and(vec_sr(qxs1, v6), lowMask);
+
+            vector signed char q8y00 = vec_xl(  0, q8);
+            vector signed char q8y10 = vec_xl( 16, q8);
+            vector signed char q8y01 = vec_xl( 32, q8);
+            vector signed char q8y11 = vec_xl( 48, q8);
+            vector signed char q8y02 = vec_xl( 64, q8);
+            vector signed char q8y12 = vec_xl( 80, q8);
+            vector signed char q8y03 = vec_xl( 96, q8);
+            vector signed char q8y13 = vec_xl(112, q8);
+            q8 += 128;
+
+            vector signed int qv0 = vec_msum(q8y00, q2x00, v0);
+            vector signed int qv1 = vec_msum(q8y01, q2x01, v0);
+            vector signed int qv2 = vec_msum(q8y02, q2x02, v0);
+            vector signed int qv3 = vec_msum(q8y03, q2x03, v0);
+            vector signed int qv4 = vec_msum(q8y10, q2x10, v0);
+            vector signed int qv5 = vec_msum(q8y11, q2x11, v0);
+            vector signed int qv6 = vec_msum(q8y12, q2x12, v0);
+            vector signed int qv7 = vec_msum(q8y13, q2x13, v0);
+
+            vector signed short vscales_07 = vec_unpackh(vscales);
+            vector signed int vscales_03 = vec_unpackh(vscales_07);
+            vector signed int vscales_47 = vec_unpackl(vscales_07);
+            vector signed int vs0 = vec_splat(vscales_03, 0);
+            vector signed int vs1 = vec_splat(vscales_03, 1);
+            vector signed int vs2 = vec_splat(vscales_03, 2);
+            vector signed int vs3 = vec_splat(vscales_03, 3);
+            vector signed int vs4 = vec_splat(vscales_47, 0);
+            vector signed int vs5 = vec_splat(vscales_47, 1);
+            vector signed int vs6 = vec_splat(vscales_47, 2);
+            vector signed int vs7 = vec_splat(vscales_47, 3);
+            vscales = vec_sld(vscales, vscales, 8);
+
+            vsumi0 = vec_add(vec_mul(qv0, vs0), vsumi0);
+            vsumi1 = vec_add(vec_mul(qv1, vs2), vsumi1);
+            vsumi2 = vec_add(vec_mul(qv2, vs4), vsumi2);
+            vsumi3 = vec_add(vec_mul(qv3, vs6), vsumi3);
+            vsumi4 = vec_add(vec_mul(qv4, vs1), vsumi4);
+            vsumi5 = vec_add(vec_mul(qv5, vs3), vsumi5);
+            vsumi6 = vec_add(vec_mul(qv6, vs5), vsumi6);
+            vsumi7 = vec_add(vec_mul(qv7, vs7), vsumi7);
+        }
+
+        vsumi0 = vec_add(vsumi0, vsumi4);
+        vsumi1 = vec_add(vsumi1, vsumi5);
+        vsumi2 = vec_add(vsumi2, vsumi6);
+        vsumi3 = vec_add(vsumi3, vsumi7);
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+    }
+
+    vsumf0 = vec_add(vsumf0, vsumf2);
+    vsumf1 = vec_add(vsumf1, vsumf3);
+
+    vsumf0 = vec_add(vsumf0, vsumf1);
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    *s = vec_extract(vsumf0, 0);
+
+#elif defined __loongarch_asx
+
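+    // The LoongArch LASX path below mirrors the AVX2 kernel above, using the
+    // 256-bit LASX equivalents of the same intrinsics.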
+    const __m256i m3 = __lasx_xvreplgr2vr_b(3);
+    const __m128i m4 = __lsx_vreplgr2vr_b(0xF);
+
+    __m256 acc = (__m256)__lasx_xvldi(0);
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const uint8_t * restrict q2 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const __m128i mins_and_scales = __lsx_vld((const __m128i*)x[i].scales, 0);
+        const __m128i scales8 = __lsx_vand_v(mins_and_scales, m4);
+        const __m128i mins8 = __lsx_vand_v(__lsx_vsrli_h(mins_and_scales, 4), m4);
+        const __m256i mins = lasx_ext8_16(mins8);
+        const __m256i prod = lasx_madd_h(mins, __lasx_xvld((const __m256i*)y[i].bsums, 0));
+
+        acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(dmin), __lasx_xvffint_s_w(prod), acc);
+
+        const __m256i all_scales = lasx_ext8_16(scales8);
+        const __m128i l_scales = lasx_extracti128(all_scales, 0);
+        const __m128i h_scales = lasx_extracti128(all_scales, 1);
+        const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)};
+
+        __m256i sumi = __lasx_xvldi(0);
+
+        for (int j = 0; j < QK_K/128; ++j) {
+
+            const __m256i q2bits = __lasx_xvld((const __m256i*)q2, 0); q2 += 32;
+
+            const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+            const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+            const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+            const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+
+            const __m256i q2_0 = __lasx_xvand_v(q2bits, m3);
+            const __m256i q2_1 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 2), m3);
+            const __m256i q2_2 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 4), m3);
+            const __m256i q2_3 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 6), m3);
+
+            __m256i p0 = lasx_maddubs_h(q2_0, q8_0);
+            __m256i p1 = lasx_maddubs_h(q2_1, q8_1);
+            __m256i p2 = lasx_maddubs_h(q2_2, q8_2);
+            __m256i p3 = lasx_maddubs_h(q2_3, q8_3);
+
+            p0 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(0)), p0);
+            p1 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(1)), p1);
+            p2 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(2)), p2);
+            p3 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(3)), p3);
+
+            p0 = __lasx_xvadd_w(p0, p1);
+            p2 = __lasx_xvadd_w(p2, p3);
+
+            sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p0, p2));
+        }
+
+        acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);
+
+    }
+
+    *s = hsum_float_8(acc);
+
+#else
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * q2 = x[i].qs;
+        const  int8_t * q8 = y[i].qs;
+        const uint8_t * sc = x[i].scales;
+
+        int summs = 0;
+        for (int j = 0; j < 16; ++j) {
+            summs += y[i].bsums[j] * (sc[j] >> 4);
+        }
+
+        const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        int isum = 0;
+        int is = 0;
+        int d;
+        for (int k = 0; k < QK_K/128; ++k) {
+            int shift = 0;
+            for (int j = 0; j < 4; ++j) {
+                d = sc[is++] & 0xF;
+                int isuml = 0;
+                for (int l =  0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
+                isum += d * isuml;
+                d = sc[is++] & 0xF;
+                isuml = 0;
+                for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
+                isum += d * isuml;
+                shift += 2;
+                q8 += 32;
+            }
+            q2 += 32;
+        }
+        sumf += dall * isum - dmin * summs;
+    }
+    *s = sumf;
+#endif
+}
+
+void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const uint32_t kmask1 = 0x03030303;
+    const uint32_t kmask2 = 0x0f0f0f0f;
+
+    const block_q3_K * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
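+    // block_q3_K stores 256 3-bit quants split into 2-bit low parts (qs) and a
+    // high-bit plane (hmask): a value is (low2 - 4) when its hmask bit is clear,
+    // giving the signed range -4..3. The 16 6-bit sub-block scales are packed
+    // into 12 bytes: bytes 0..7 hold the 4-bit low parts (low nibbles -> scales
+    // 0..7, high nibbles -> scales 8..15) and bytes 8..11 hold the 2-bit high
+    // parts. kmask2 extracts the 4-bit parts, kmask1 the 2-bit parts, and each
+    // reassembled scale is centered by subtracting 32.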
+#ifdef __ARM_NEON
+
+    uint32_t aux[3];
+    uint32_t utmp[4];
+
+    const uint8x16_t m3b = vdupq_n_u8(0x3);
+    const int32x4_t  vzero = vdupq_n_s32(0);
+
+    const uint8x16_t m0 = vdupq_n_u8(1);
+    const uint8x16_t m1 = vshlq_n_u8(m0, 1);
+    const uint8x16_t m2 = vshlq_n_u8(m0, 2);
+    const uint8x16_t m3 = vshlq_n_u8(m0, 3);
+    const int8_t m32 = 32;
+
+    ggml_int8x16x4_t q3bytes;
+
+    float sum = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict q3 = x[i].qs;
+        const uint8_t * restrict qh = x[i].hmask;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh);
+
+        ggml_uint8x16x4_t q3h;
+
+        int32_t isum = 0;
+
+        // Set up scales
+        memcpy(aux, x[i].scales, 12);
+        utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
+        utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
+        utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
+        utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
+
+        int8_t * scale = (int8_t *)utmp;
+        for (int j = 0; j < 16; ++j) scale[j] -= m32;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+
+            const ggml_uint8x16x2_t q3bits = ggml_vld1q_u8_x2(q3); q3 += 32;
+            const ggml_int8x16x4_t q8bytes_1 = ggml_vld1q_s8_x4(q8); q8 += 64;
+            const ggml_int8x16x4_t q8bytes_2 = ggml_vld1q_s8_x4(q8); q8 += 64;
+
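+            // vbicq (AND NOT) leaves a 1 where the hmask bit is clear; shifted
+            // left by 2 it becomes the 4 subtracted below to re-center the quants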
+            q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2);
+            q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2);
+            q3h.val[2] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[0]), 1);
+            q3h.val[3] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[1]), 1);
+
+            q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[0], m3b)), vreinterpretq_s8_u8(q3h.val[0]));
+            q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[1], m3b)), vreinterpretq_s8_u8(q3h.val[1]));
+            q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
+            q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
+
+            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0];
+            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1];
+            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2];
+            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3];
+
+            scale += 4;
+
+            q3h.val[0] = vbicq_u8(m2, qhbits.val[0]);
+            q3h.val[1] = vbicq_u8(m2, qhbits.val[1]);
+            q3h.val[2] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[0]), 1);
+            q3h.val[3] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[1]), 1);
+
+            q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 4), m3b)), vreinterpretq_s8_u8(q3h.val[0]));
+            q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 4), m3b)), vreinterpretq_s8_u8(q3h.val[1]));
+            q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
+            q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
+
+            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0];
+            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1];
+            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2];
+            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3];
+
+            scale += 4;
+
+            if (j == 0) {
+                qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 4);
+                qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 4);
+            }
+
+        }
+        sum += d * isum;
+
+    }
+
+    *s = sum;
+
+#elif defined __AVX2__
+
+    const __m256i m3 = _mm256_set1_epi8(3);
+    const __m256i mone = _mm256_set1_epi8(1);
+    const __m128i m32 = _mm_set1_epi8(32);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    uint32_t aux[3];
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict q3 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        // Set up scales
+        memcpy(aux, x[i].scales, 12);
+        __m128i scales128 = _mm_set_epi32(
+                ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4),
+                ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4),
+                (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
+                (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
+        scales128 = _mm_sub_epi8(scales128, m32);
+        const __m256i all_scales = _mm256_cvtepi8_epi16(scales128);
+        const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
+        const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
+        const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)};
+
+        // high bit
+        const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].hmask);
+
+        // integer accumulator
+        __m256i sumi = _mm256_setzero_si256();
+
+        int bit = 0;
+        int is  = 0;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+            // load low 2 bits
+            const __m256i q3bits = _mm256_loadu_si256((const __m256i*)q3); q3 += 32;
+
+            // prepare low and high bits
+            const __m256i q3l_0 = _mm256_and_si256(q3bits, m3);
+            const __m256i q3h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
+            ++bit;
+
+            const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3);
+            const __m256i q3h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
+            ++bit;
+
+            const __m256i q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3);
+            const __m256i q3h_2 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
+            ++bit;
+
+            const __m256i q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3);
+            const __m256i q3h_3 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
+            ++bit;
+
+            // load Q8 quants
+            const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+
+            // Dot product: we multiply the 2-bit low parts and the high-bit part
+            // separately, so we can use _mm256_maddubs_epi16, and then subtract.
+            // The high-bit part already carries the offset (it is 4 if the high
+            // bit was not set and 0 if it was set), which re-centers the quants.
+            __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0);
+            __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1);
+            __m256i q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2);
+            __m256i q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3);
+
+            __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0);
+            __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1);
+            __m256i p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2);
+            __m256i p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3);
+
+            p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
+            p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
+            p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
+            p16_3 = _mm256_sub_epi16(p16_3, q8s_3);
+
+            // multiply with scales
+            p16_0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0);
+            p16_1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1);
+            p16_2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2);
+            p16_3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3);
+
+            // accumulate
+            p16_0 = _mm256_add_epi32(p16_0, p16_1);
+            p16_2 = _mm256_add_epi32(p16_2, p16_3);
+            sumi  = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2));
+
+        }
+
+        // multiply with block scale and accumulate
+        acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
+
+    }
+
+    *s = hsum_float_8(acc);
+
+#elif defined __AVX__
+
+    const __m128i m3 = _mm_set1_epi8(3);
+    const __m128i mone = _mm_set1_epi8(1);
+    const __m128i m32 = _mm_set1_epi8(32);
+    const __m128i m2 = _mm_set1_epi8(2);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    const uint32_t *aux;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict q3 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        // Set up scales
+        aux = (const uint32_t *)x[i].scales;
+        __m128i scales128 = _mm_set_epi32(
+                ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4),
+                ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4),
+                (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
+                (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
+        scales128 = _mm_sub_epi8(scales128, m32);
+        const __m128i scales_0 = _mm_cvtepi8_epi16(scales128);
+        const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales128, scales128));
+        const __m128i scales[2] = { scales_0, scales_1 };
+
+        // high bit *128*2 from block_q3_K.hmask[QK_K/8]
+        const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].hmask[0]);
+        const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].hmask[16]);
+
+        // integer accumulator
+        __m128i sumi_0 = _mm_setzero_si128();
+        __m128i sumi_1 = _mm_setzero_si128();
+
+        for (int j = 0; j < QK_K/128; ++j) {
+            // load low 2 bits *64*2 from block_q3_K.qs[QK_K/4]
+            const __m128i q3bits_0 = _mm_loadu_si128((const __m128i*)q3); q3 += 16;
+            const __m128i q3bits_1 = _mm_loadu_si128((const __m128i*)q3); q3 += 16;
+
+            // prepare low and high bits
+            const int bit = j << 2;
+
+            const __m128i q3l_0 = _mm_and_si128(q3bits_0, m3);
+            const __m128i q3l_1 = _mm_and_si128(q3bits_1, m3);
+            const __m128i q3h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit)), bit), 2);
+            const __m128i q3h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit)), bit), 2);
+
+            const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 2), m3);
+            const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 2), m3);
+            const __m128i q3h_2 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+1)), bit+1), 2);
+            const __m128i q3h_3 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+1)), bit+1), 2);
+
+            const __m128i q3l_4 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 4), m3);
+            const __m128i q3l_5 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 4), m3);
+            const __m128i q3h_4 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+2)), bit+2), 2);
+            const __m128i q3h_5 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+2)), bit+2), 2);
+
+            const __m128i q3l_6 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 6), m3);
+            const __m128i q3l_7 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 6), m3);
+            const __m128i q3h_6 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+3)), bit+3), 2);
+            const __m128i q3h_7 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+3)), bit+3), 2);
+
+            // load Q8 quants from block_q8_K.qs[QK_K]
+            const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+
+            // Dot product: we multiply the 2-bit low parts and the high-bit part
+            // separately, so we can use _mm_maddubs_epi16, and then subtract.
+            // The high-bit part already carries the offset (it is 4 if the high
+            // bit was not set and 0 if it was set), which re-centers the quants.
+            __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, q8_0);
+            __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, q8_1);
+            __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, q8_2);
+            __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, q8_3);
+            __m128i q8s_4 = _mm_maddubs_epi16(q3h_4, q8_4);
+            __m128i q8s_5 = _mm_maddubs_epi16(q3h_5, q8_5);
+            __m128i q8s_6 = _mm_maddubs_epi16(q3h_6, q8_6);
+            __m128i q8s_7 = _mm_maddubs_epi16(q3h_7, q8_7);
+
+            __m128i p16_0 = _mm_maddubs_epi16(q3l_0, q8_0);
+            __m128i p16_1 = _mm_maddubs_epi16(q3l_1, q8_1);
+            __m128i p16_2 = _mm_maddubs_epi16(q3l_2, q8_2);
+            __m128i p16_3 = _mm_maddubs_epi16(q3l_3, q8_3);
+            __m128i p16_4 = _mm_maddubs_epi16(q3l_4, q8_4);
+            __m128i p16_5 = _mm_maddubs_epi16(q3l_5, q8_5);
+            __m128i p16_6 = _mm_maddubs_epi16(q3l_6, q8_6);
+            __m128i p16_7 = _mm_maddubs_epi16(q3l_7, q8_7);
+
+            p16_0 = _mm_sub_epi16(p16_0, q8s_0);
+            p16_1 = _mm_sub_epi16(p16_1, q8s_1);
+            p16_2 = _mm_sub_epi16(p16_2, q8s_2);
+            p16_3 = _mm_sub_epi16(p16_3, q8s_3);
+            p16_4 = _mm_sub_epi16(p16_4, q8s_4);
+            p16_5 = _mm_sub_epi16(p16_5, q8s_5);
+            p16_6 = _mm_sub_epi16(p16_6, q8s_6);
+            p16_7 = _mm_sub_epi16(p16_7, q8s_7);
+
+            // multiply with scales
+            __m128i shuffle = _mm_set1_epi16(0x0100);
+            p16_0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_0);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p16_1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_1);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p16_2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_2);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p16_3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_3);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p16_4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_4);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p16_5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_5);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p16_6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_6);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p16_7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_7);
+
+            // accumulate
+            p16_0 = _mm_add_epi32(p16_0, p16_1);
+            p16_2 = _mm_add_epi32(p16_2, p16_3);
+            p16_4 = _mm_add_epi32(p16_4, p16_5);
+            p16_6 = _mm_add_epi32(p16_6, p16_7);
+            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
+            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_4, p16_6));
+
+        }
+
+        // multiply with block scale and accumulate
+        __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
+
+    }
+
+    *s = hsum_float_8(acc);
+
+#elif defined __riscv_v_intrinsic
+
+    uint32_t aux[3];
+    uint32_t utmp[4];
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * restrict q3 = x[i].qs;
+        const uint8_t * restrict qh = x[i].hmask;
+        const  int8_t * restrict q8 = y[i].qs;
+
+        memcpy(aux, x[i].scales, 12);
+        utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
+        utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
+        utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
+        utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
+
+        int8_t * scale = (int8_t *)utmp;
+        for (int j = 0; j < 16; ++j) scale[j] -= 32;
+
+        size_t vl = 32;
+        uint8_t m = 1;
+
+        vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+        vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl);
+
+        int sum_t = 0;
+
+        for (int j = 0; j < QK_K; j += 128) {
+
+            vl = 32;
+
+            // load Q3
+            vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl);
+
+            vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl));
+            vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03, vl));
+            vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03, vl));
+            vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03, vl));
+
+            // compute mask for subtraction
+            vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl);
+            vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl);
+            vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_mu(vmask_0, q3_0, q3_0, 0x4, vl);
+            m <<= 1;
+
+            vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl);
+            vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl);
+            vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_mu(vmask_1, q3_1, q3_1, 0x4, vl);
+            m <<= 1;
+
+            vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl);
+            vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl);
+            vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_mu(vmask_2, q3_2, q3_2, 0x4, vl);
+            m <<= 1;
+
+            vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl);
+            vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl);
+            vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_mu(vmask_3, q3_3, q3_3, 0x4, vl);
+            m <<= 1;
+
+            // load Q8 and take product with Q3
+            vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl);
+            vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
+            vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
+            vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl);
+
+            vl = 16;
+
+            // retrieve lane to multiply with scale
+            vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl);
+            vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl);
+            vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl);
+            vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl);
+            vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl);
+            vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl);
+            vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl);
+            vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl);
+
+            vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl);
+            vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl);
+            vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl);
+            vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl);
+
+            sum_t +=  __riscv_vmv_x_s_i32m1_i32(isum3);
+
+            q3 += 32; q8 += 128; scale += 8;
+
+        }
+
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+
+        sumf += d*sum_t;
+
+    }
+
+    *s = sumf;
+
+#elif defined(__POWER9_VECTOR__)
+    const vector signed char lowMask = vec_splats((signed char)0x3);
+    const vector signed char lowMask1 = vec_splats((int8_t)0xf);
+    const vector signed char lowMask2 = vec_splats((int8_t)0x30);
+    const vector int v0 = vec_splats((int32_t)0);
+    const vector signed char v1 = vec_splats((signed char)0x1);
+    const vector unsigned char v2 = vec_splats((unsigned char)0x2);
+    const vector unsigned char v3 = vec_splats((unsigned char)0x3);
+    const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+    const vector unsigned char v6 = vec_splats((unsigned char)0x6);
+    const vector signed char off = vec_splats((signed char)0x20);
+
+    vector float vsumf0 = vec_splats(0.0f);
+    vector float vsumf1 = vec_splats(0.0f);
+    vector float vsumf2 = vec_splats(0.0f);
+    vector float vsumf3 = vec_splats(0.0f);
+
+    for (int i = 0; i < nb; ++i) {
+        vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
+        vector float vyd = vec_splats(y[i].d);
+        vector float vd = vec_mul(vxd, vyd);
+
+        UNUSED(kmask1);
+        UNUSED(kmask2);
+
+        vector signed char u0 = (vector signed char)vec_xl_len(x[i].scales, 8);
+        vector signed char u1 = vec_and(u0, lowMask1);
+        vector signed char u2 = (vector signed char)vec_xl_len(x[i].scales + 8, 4);
+        vector signed char u3 = (vector signed char)vec_mergeh((vector signed int)u2, (vector signed int)vec_sr(u2, v2));
+        vector signed char u30 = vec_sl(vec_and(u3, lowMask), v4);
+        vector signed char u31 = vec_and(u3, lowMask2);
+
+        u1 = vec_or(u1, u30);
+        u2 = vec_or(vec_sr(u0, v4), u31);
+
+        vector signed char vscales = (vector signed char)vec_mergeh((vector signed long long)u1, (vector signed long long)u2);
+        vector signed char qxhs0 = (vector signed char)vec_xl( 0, x[i].hmask);
+        vector signed char qxhs1 = (vector signed char)vec_xl(16, x[i].hmask);
+
+        vscales = vec_sub(vscales, off);
+
+        vector signed int vsumi0 = v0;
+        vector signed int vsumi1 = v0;
+        vector signed int vsumi2 = v0;
+        vector signed int vsumi3 = v0;
+        vector signed int vsumi4 = v0;
+        vector signed int vsumi5 = v0;
+        vector signed int vsumi6 = v0;
+        vector signed int vsumi7 = v0;
+
+        const uint8_t * restrict q3 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+            __builtin_prefetch(q3, 0, 1);
+            __builtin_prefetch(q8, 0, 1);
+
+            vector signed char qxs0 = (vector signed char)vec_xl( 0, q3);
+            vector signed char qxs1 = (vector signed char)vec_xl(16, q3);
+            q3 += 32;
+
+            // the low 2 bits
+            vector signed char qxs00 = vec_and(qxs0, lowMask);
+            vector signed char qxs01 = vec_and(vec_sr(qxs0, v2), lowMask);
+            vector signed char qxs02 = vec_and(vec_sr(qxs0, v4), lowMask);
+            vector signed char qxs03 = vec_and(vec_sr(qxs0, v6), lowMask);
+            vector signed char qxs10 = vec_and(qxs1, lowMask);
+            vector signed char qxs11 = vec_and(vec_sr(qxs1, v2), lowMask);
+            vector signed char qxs12 = vec_and(vec_sr(qxs1, v4), lowMask);
+            vector signed char qxs13 = vec_and(vec_sr(qxs1, v6), lowMask);
+
+            // the 3rd bit
+            vector signed char qxh00 = vec_sl(vec_andc(v1, qxhs0), v2);
+            vector signed char qxh01 = vec_sl(vec_andc(v1, vec_sr(qxhs0, (vector unsigned char)v1)), v2);
+            vector signed char qxh02 = vec_sl(vec_andc(v1, vec_sr(qxhs0, v2)), v2);
+            vector signed char qxh03 = vec_sl(vec_andc(v1, vec_sr(qxhs0, v3)), v2);
+            vector signed char qxh10 = vec_sl(vec_andc(v1, qxhs1), v2);
+            vector signed char qxh11 = vec_sl(vec_andc(v1, vec_sr(qxhs1, (vector unsigned char)v1)), v2);
+            vector signed char qxh12 = vec_sl(vec_andc(v1, vec_sr(qxhs1, v2)), v2);
+            vector signed char qxh13 = vec_sl(vec_andc(v1, vec_sr(qxhs1, v3)), v2);
+            qxhs0 = vec_sr(qxhs0, v4);
+            qxhs1 = vec_sr(qxhs1, v4);
+
+            vector signed char q3x00 = vec_sub(qxs00, qxh00);
+            vector signed char q3x01 = vec_sub(qxs01, qxh01);
+            vector signed char q3x02 = vec_sub(qxs02, qxh02);
+            vector signed char q3x03 = vec_sub(qxs03, qxh03);
+            vector signed char q3x10 = vec_sub(qxs10, qxh10);
+            vector signed char q3x11 = vec_sub(qxs11, qxh11);
+            vector signed char q3x12 = vec_sub(qxs12, qxh12);
+            vector signed char q3x13 = vec_sub(qxs13, qxh13);
+
+            vector signed char q8y00 = vec_xl(  0, q8);
+            vector signed char q8y10 = vec_xl( 16, q8);
+            vector signed char q8y01 = vec_xl( 32, q8);
+            vector signed char q8y11 = vec_xl( 48, q8);
+            vector signed char q8y02 = vec_xl( 64, q8);
+            vector signed char q8y12 = vec_xl( 80, q8);
+            vector signed char q8y03 = vec_xl( 96, q8);
+            vector signed char q8y13 = vec_xl(112, q8);
+            q8 += 128;
+
+            vector signed short vscales_h = vec_unpackh(vscales);
+            vector signed short vs0 = vec_splat(vscales_h, 0);
+            vector signed short vs1 = vec_splat(vscales_h, 1);
+            vector signed short vs2 = vec_splat(vscales_h, 2);
+            vector signed short vs3 = vec_splat(vscales_h, 3);
+            vector signed short vs4 = vec_splat(vscales_h, 4);
+            vector signed short vs5 = vec_splat(vscales_h, 5);
+            vector signed short vs6 = vec_splat(vscales_h, 6);
+            vector signed short vs7 = vec_splat(vscales_h, 7);
+            vscales = vec_sld(vscales, vscales, 8);
+
+            vector signed short qv00 = vec_add(vec_mule(q3x00, q8y00), vec_mulo(q3x00, q8y00));
+            vector signed short qv01 = vec_add(vec_mule(q3x01, q8y01), vec_mulo(q3x01, q8y01));
+            vector signed short qv02 = vec_add(vec_mule(q3x02, q8y02), vec_mulo(q3x02, q8y02));
+            vector signed short qv03 = vec_add(vec_mule(q3x03, q8y03), vec_mulo(q3x03, q8y03));
+            vector signed short qv10 = vec_add(vec_mule(q3x10, q8y10), vec_mulo(q3x10, q8y10));
+            vector signed short qv11 = vec_add(vec_mule(q3x11, q8y11), vec_mulo(q3x11, q8y11));
+            vector signed short qv12 = vec_add(vec_mule(q3x12, q8y12), vec_mulo(q3x12, q8y12));
+            vector signed short qv13 = vec_add(vec_mule(q3x13, q8y13), vec_mulo(q3x13, q8y13));
+
+            vsumi0 = vec_msum(qv00, vs0, vsumi0);
+            vsumi1 = vec_msum(qv01, vs2, vsumi1);
+            vsumi2 = vec_msum(qv02, vs4, vsumi2);
+            vsumi3 = vec_msum(qv03, vs6, vsumi3);
+            vsumi4 = vec_msum(qv10, vs1, vsumi4);
+            vsumi5 = vec_msum(qv11, vs3, vsumi5);
+            vsumi6 = vec_msum(qv12, vs5, vsumi6);
+            vsumi7 = vec_msum(qv13, vs7, vsumi7);
+        }
+
+        vsumi0 = vec_add(vsumi0, vsumi4);
+        vsumi1 = vec_add(vsumi1, vsumi5);
+        vsumi2 = vec_add(vsumi2, vsumi6);
+        vsumi3 = vec_add(vsumi3, vsumi7);
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+    }
+
+    vsumf0 = vec_add(vsumf0, vsumf2);
+    vsumf1 = vec_add(vsumf1, vsumf3);
+
+    vsumf0 = vec_add(vsumf0, vsumf1);
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    *s = vec_extract(vsumf0, 0);
+
+#elif defined __loongarch_asx
+
+    const __m256i m3 = __lasx_xvreplgr2vr_b(3);
+    const __m256i mone = __lasx_xvreplgr2vr_b(1);
+    const __m128i m32 = __lsx_vreplgr2vr_b(32);
+
+    __m256 acc = (__m256)__lasx_xvldi(0);
+
+    uint32_t aux[3];
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const uint8_t * restrict q3 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+        // Set up scales
+        memcpy(aux, x[i].scales, 12);
+        __m128i scales128 = lsx_set_w(
+                ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4),
+                ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4),
+                (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
+                (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
+        scales128 = __lsx_vsub_b(scales128, m32);
+        const __m256i all_scales = lasx_ext8_16(scales128);
+        const __m128i l_scales = lasx_extracti128(all_scales, 0);
+        const __m128i h_scales = lasx_extracti128(all_scales, 1);
+        const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)};
+
+        // high bit
+        const __m256i hbits = __lasx_xvld((const __m256i*)x[i].hmask, 0);
+
+        // integer accumulator
+        __m256i sumi = __lasx_xvldi(0);
+
+        int bit = 0;
+        int is  = 0;
+        __m256i xvbit;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+            // load low 2 bits
+            const __m256i q3bits = __lasx_xvld((const __m256i*)q3, 0); q3 += 32;
+
+            xvbit = __lasx_xvreplgr2vr_h(bit);
+            // prepare low and high bits
+            const __m256i q3l_0 = __lasx_xvand_v(q3bits, m3);
+            const __m256i q3h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
+            ++bit;
+
+            xvbit = __lasx_xvreplgr2vr_h(bit);
+            const __m256i q3l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 2), m3);
+            const __m256i q3h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
+            ++bit;
+
+            xvbit = __lasx_xvreplgr2vr_h(bit);
+            const __m256i q3l_2 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 4), m3);
+            const __m256i q3h_2 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
+            ++bit;
+
+            xvbit = __lasx_xvreplgr2vr_h(bit);
+            const __m256i q3l_3 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 6), m3);
+            const __m256i q3h_3 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
+            ++bit;
+
+            // load Q8 quants
+            const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+            const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+            const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+            const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+
+            // Dot product: we multiply the 2-bit low parts and the high-bit part
+            // separately, so we can use lasx_maddubs_h, and then subtract.
+            // The high-bit part already carries the offset (it is 4 if the high
+            // bit was not set and 0 if it was set), which re-centers the quants.
+            __m256i q8s_0 = lasx_maddubs_h(q3h_0, q8_0);
+            __m256i q8s_1 = lasx_maddubs_h(q3h_1, q8_1);
+            __m256i q8s_2 = lasx_maddubs_h(q3h_2, q8_2);
+            __m256i q8s_3 = lasx_maddubs_h(q3h_3, q8_3);
+
+            __m256i p16_0 = lasx_maddubs_h(q3l_0, q8_0);
+            __m256i p16_1 = lasx_maddubs_h(q3l_1, q8_1);
+            __m256i p16_2 = lasx_maddubs_h(q3l_2, q8_2);
+            __m256i p16_3 = lasx_maddubs_h(q3l_3, q8_3);
+
+            p16_0 = __lasx_xvsub_h(p16_0, q8s_0);
+            p16_1 = __lasx_xvsub_h(p16_1, q8s_1);
+            p16_2 = __lasx_xvsub_h(p16_2, q8s_2);
+            p16_3 = __lasx_xvsub_h(p16_3, q8s_3);
+
+            // multiply with scales
+            p16_0 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0);
+            p16_1 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1);
+            p16_2 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2);
+            p16_3 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3);
+
+            // accumulate
+            p16_0 = __lasx_xvadd_w(p16_0, p16_1);
+            p16_2 = __lasx_xvadd_w(p16_2, p16_3);
+            sumi  = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_2));
+        }
+        // multiply with block scale and accumulate
+        acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc); //FIXME
+    }
+
+    *s = hsum_float_8(acc);
+
+#else
+    // scalar version
+    // This function is written this way so the compiler can manage to vectorize most of it.
+    // With -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so of the
+    // manually vectorized version above. Every other version I tried would run at least 4 times slower.
+    // The ideal situation would be if we could just write the code once and the compiler would
+    // automatically produce the best possible set of machine instructions, instead of us having to manually
+    // write vectorized versions for AVX, ARM_NEON, etc.
+
+    int8_t  aux8[QK_K];
+    int16_t aux16[8];
+    float   sums [8];
+    int32_t aux32[8];
+    memset(sums, 0, 8*sizeof(float));
+
+    uint32_t auxs[4];
+    const int8_t * scales = (const int8_t*)auxs;
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * restrict q3 = x[i].qs;
+        const uint8_t * restrict hm = x[i].hmask;
+        const  int8_t * restrict q8 = y[i].qs;
+        memset(aux32, 0, 8*sizeof(int32_t));
+        int8_t * restrict a = aux8;
+        uint8_t m = 1;
+        for (int j = 0; j < QK_K; j += 128) {
+            for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3;
+            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
+            a += 32; m <<= 1;
+            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3;
+            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
+            a += 32; m <<= 1;
+            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3;
+            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
+            a += 32; m <<= 1;
+            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3;
+            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
+            a += 32; m <<= 1;
+            q3 += 32;
+        }
+        a = aux8;
+
+        memcpy(auxs, x[i].scales, 12);
+        uint32_t tmp = auxs[2];
+        auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
+        auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
+        auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
+        auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
+        for (int j = 0; j < QK_K/16; ++j) {
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
+            q8 += 8; a += 8;
+        }
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
+    }
+    for (int l = 0; l < 8; ++l) sumf += sums[l];
+    *s = sumf;
+
+#endif
+
+}
+
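+// Dot product of one row of Q4_K-quantized weights with Q8_K-quantized
+// activations. Each 256-value superblock stores fp16 super-scales d/dmin plus
+// eight 6-bit sub-block scales and eight 6-bit mins packed into 12 bytes, so
+// every path below computes
+//     sum_i ( d_i * sum_j sc_j * dot(q4_j, q8_j) - dmin_i * sum_j min_j * bsum_j )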
+void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_q4_K * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+    static const uint32_t kmask1 = 0x3f3f3f3f;
+    static const uint32_t kmask2 = 0x0f0f0f0f;
+    static const uint32_t kmask3 = 0x03030303;
+
+    uint32_t utmp[4];
+
+#ifdef __ARM_NEON
+    const uint8x16_t m4b = vdupq_n_u8(0xf);
+    const int32x4_t mzero = vdupq_n_s32(0);
+
+    ggml_int8x16x2_t q4bytes;
+    ggml_int8x16x2_t q8bytes;
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
+
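+        // unpack the 12 packed bytes of 6-bit scales/mins: the eight mins go to
+        // mins8, the eight sub-block scales remain in utmp[0..1]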
+        memcpy(utmp, x[i].scales, 12);
+
+        uint32x2_t mins8 = { 0 };
+        mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0);
+        mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1);
+
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[0] &= kmask1;
+
+        const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
+        const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
+                                         vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
+        sumf -= dmin * vaddvq_s32(prod);
+
+        const uint8_t * scales = (const uint8_t *)utmp;
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        int32_t sumi1 = 0;
+        int32_t sumi2 = 0;
+
+        for (int j = 0; j < QK_K/64; ++j) {
+            const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4); q4 += 32;
+
+            q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
+            q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[0], m4b));
+            q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[1], m4b));
+
+            const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
+            sumi1 += vaddvq_s32(p1) * scales[2*j+0];
+
+            q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
+            q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
+            q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
+
+            const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
+
+            sumi2 += vaddvq_s32(p2) * scales[2*j+1];
+        }
+
+        sumf += d * (sumi1 + sumi2);
+
+    }
+
+    *s = sumf;
+
+#elif defined __AVX2__
+
+    const __m256i m4 = _mm256_set1_epi8(0xF);
+
+    __m256 acc = _mm256_setzero_ps();
+    __m128 acc_m = _mm_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
+
+        const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);
+        const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
+        const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
+        acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m);
+
+        const __m128i sc128  = _mm256_extracti128_si256(mins_and_scales, 0);
+        const __m256i scales = MM256_SET_M128I(sc128, sc128);
+
+        __m256i sumi = _mm256_setzero_si256();
+
+        for (int j = 0; j < QK_K/64; ++j) {
+
+            const __m256i scale_l = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0));
+            const __m256i scale_h = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1));
+
+            const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
+            const __m256i q4l = _mm256_and_si256(q4bits, m4);
+            const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4);
+
+            const __m256i q8l = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            __m256i p16l = _mm256_maddubs_epi16(q4l, q8l);
+            p16l = _mm256_madd_epi16(scale_l, p16l);
+
+            const __m256i q8h = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            __m256i p16h = _mm256_maddubs_epi16(q4h, q8h);
+            p16h = _mm256_madd_epi16(scale_h, p16h);
+            const __m256i sumj = _mm256_add_epi32(p16l, p16h);
+
+            sumi = _mm256_add_epi32(sumi, sumj);
+        }
+
+        __m256 vd = _mm256_set1_ps(d);
+        acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
+
+    }
+
+    acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));
+    acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));
+
+    *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m);
+
+#elif defined __AVX__
+
+    const __m128i m4 = _mm_set1_epi8(0xF);
+    const __m128i m2 = _mm_set1_epi8(0x2);
+
+    __m256 acc = _mm256_setzero_ps();
+    __m128 acc_m = _mm_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]);
+        const __m128i scales = _mm_cvtepu8_epi16(utmps);
+        const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps));
+
+        const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]);
+        const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]);
+        const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1);
+        const __m128i prod = _mm_madd_epi16(mins, q8s);
+        acc_m = _mm_add_ps(_mm_mul_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod)), acc_m);
+
+        __m128i sumi_0 = _mm_setzero_si128();
+        __m128i sumi_1 = _mm_setzero_si128();
+
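+        // shuffle = 0x0100 broadcasts bytes {0,1} (the first 16-bit scale) to
+        // every lane; adding 2 to each byte steps to the next sub-block scale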
+        __m128i shuffle = _mm_set1_epi16(0x0100);
+        for (int j = 0; j < QK_K/64; ++j) {
+
+            const __m128i scale_l = _mm_shuffle_epi8(scales, shuffle);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            const __m128i scale_h = _mm_shuffle_epi8(scales, shuffle);
+            shuffle = _mm_add_epi16(shuffle, m2);
+
+            __m128i q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
+            const __m128i q4l_0 = _mm_and_si128(q4bits, m4);
+            const __m128i q4h_0 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4);
+            q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
+            const __m128i q4l_1 = _mm_and_si128(q4bits, m4);
+            const __m128i q4h_1 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4);
+
+            const __m128i q8l_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            __m128i p16l = _mm_maddubs_epi16(q4l_0, q8l_0);
+            p16l = _mm_madd_epi16(scale_l, p16l);
+            sumi_0 = _mm_add_epi32(sumi_0, p16l);
+            const __m128i q8l_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            p16l = _mm_maddubs_epi16(q4l_1, q8l_1);
+            p16l = _mm_madd_epi16(scale_l, p16l);
+            sumi_1 = _mm_add_epi32(sumi_1, p16l);
+
+            const __m128i q8h_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            __m128i p16h = _mm_maddubs_epi16(q4h_0, q8h_0);
+            p16h = _mm_madd_epi16(scale_h, p16h);
+            sumi_0 = _mm_add_epi32(sumi_0, p16h);
+            const __m128i q8h_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            p16h = _mm_maddubs_epi16(q4h_1, q8h_1);
+            p16h = _mm_madd_epi16(scale_h, p16h);
+            sumi_1 = _mm_add_epi32(sumi_1, p16h);
+
+        }
+
+        __m256 vd = _mm256_set1_ps(d);
+        __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
+        acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
+
+    }
+
+    acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));
+    acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));
+
+    *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m);
+
+#elif defined __riscv_v_intrinsic
+
+    const uint8_t * scales = (const uint8_t*)&utmp[0];
+    const uint8_t * mins   = (const uint8_t*)&utmp[2];
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        size_t vl = 8;
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
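+        // strided loads (stride 4 bytes) split the 16 bsums into even/odd halves;
+        // adding them pairs the per-16 sums into the 8 per-32 sums matching the mins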
+        vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl);
+        vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl);
+        vint16mf2_t q8sums   = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl);
+
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        vuint8mf4_t mins8  = __riscv_vle8_v_u8mf4(mins, vl);
+        vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl));
+        vint32m1_t  prod   = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl);
+
+        vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
+        sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi);
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        vl = 32;
+
+        int32_t sum_1 = 0;
+        int32_t sum_2 = 0;
+
+        vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1);
+
+        for (int j = 0; j < QK_K/64; ++j) {
+            // load Q4
+            vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl);
+
+            // load Q8 and multiply it with lower Q4 nibble
+            vint8m1_t  q8_0 = __riscv_vle8_v_i8m1(q8, vl);
+            vint8m1_t  q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl));
+            vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl);
+            vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl);
+
+            sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0];
+
+            // load Q8 and multiply it with upper Q4 nibble
+            vint8m1_t  q8_1 = __riscv_vle8_v_i8m1(q8+32, vl);
+            vint8m1_t  q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl));
+            vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl);
+            vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl);
+
+            sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1];
+
+            q4 += 32;    q8 += 64;
+
+        }
+
+        sumf += d*(sum_1 + sum_2);
+
+    }
+
+    *s = sumf;
+
+#elif defined(__POWER9_VECTOR__)
+    const vector signed char lowMask = vec_splats((signed char)0xF);
+    const vector signed char lowMask1 = vec_splats((int8_t)0x3f);
+    const vector signed char lowMask2 = vec_splats((int8_t)0x30);
+    const vector int v0 = vec_splats((int32_t)0);
+    const vector unsigned char v2 = vec_splats((uint8_t)2);
+    const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+
+    vector float vsumf0 = vec_splats(0.0f);
+    vector float vsumf1 = vec_splats(0.0f);
+    vector float vsumf2 = vec_splats(0.0f);
+    vector float vsumf3 = vec_splats(0.0f);
+
+    for (int i = 0; i < nb; ++i) {
+        vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
+        vector float vyd = vec_splats(y[i].d);
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[i].dmin));
+        vector float vdmin = vec_mul(vxmin, vyd);
+
+        vector signed short q8ysums0 = vec_xl( 0, y[i].bsums);
+        vector signed short q8ysums1 = vec_xl(16, y[i].bsums);
+
+        UNUSED(kmask1);
+        UNUSED(kmask2);
+        UNUSED(kmask3);
+        UNUSED(utmp);
+
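+        // vector equivalent of the utmp[] unpacking used above: utmps ends up with
+        // the eight 6-bit sub-block scales in its low half and the eight 6-bit mins
+        // in its high half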
+        vector signed char u0 = (vector signed char)vec_xl_len(x[i].scales, 8);
+        vector signed char u1 = vec_and(vec_sr(u0, v2), lowMask2);
+        vector signed char u2 = (vector signed char)vec_xl_len(x[i].scales + 8, 4);
+        vector signed char u3 = vec_sr(u2, v4);
+
+        vector signed char u30 = u1;
+        vector signed char u31 = (vector signed char)vec_mergeh((vector signed int)vec_and(u2, lowMask), (vector signed int)u3);
+
+        u1 = vec_and(u0, lowMask1);
+        u2 = vec_or(u30, u31);
+
+        vector signed char utmps = (vector signed char)vec_mergeh((vector signed int)u1, (vector signed int)u2);
+
+        vector signed short vscales = vec_unpackh(utmps);
+        vector signed short q4xmins = vec_unpackl(utmps);
+        vector signed short q4xmins0 = vec_mergeh(q4xmins, q4xmins);
+        vector signed short q4xmins1 = vec_mergel(q4xmins, q4xmins);
+
+        vector signed int prod0 = vec_mule(q4xmins0, q8ysums0);
+        vector signed int prod1 = vec_mule(q4xmins1, q8ysums1);
+        vector signed int prod2 = vec_mulo(q4xmins0, q8ysums0);
+        vector signed int prod3 = vec_mulo(q4xmins1, q8ysums1);
+
+        vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0);
+        vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1);
+        vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2);
+        vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3);
+
+        vector signed int vsumi0 = v0;
+        vector signed int vsumi1 = v0;
+        vector signed int vsumi2 = v0;
+        vector signed int vsumi3 = v0;
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        for (int j = 0; j < QK_K/64; j+=2) {
+            __builtin_prefetch(q4, 0, 1);
+            __builtin_prefetch(q8, 0, 1);
+
+            vector signed char qxs0 = (vector signed char)vec_xl( 0, q4);
+            vector signed char qxs1 = (vector signed char)vec_xl(16, q4);
+            vector signed char qxs2 = (vector signed char)vec_xl(32, q4);
+            vector signed char qxs3 = (vector signed char)vec_xl(48, q4);
+            q4 += 64;
+
+            vector unsigned char q4x00 = (vector unsigned char)vec_and(qxs0, lowMask);
+            vector unsigned char q4x01 = (vector unsigned char)vec_sr(qxs0, v4);
+            vector unsigned char q4x10 = (vector unsigned char)vec_and(qxs1, lowMask);
+            vector unsigned char q4x11 = (vector unsigned char)vec_sr(qxs1, v4);
+            vector unsigned char q4x20 = (vector unsigned char)vec_and(qxs2, lowMask);
+            vector unsigned char q4x21 = (vector unsigned char)vec_sr(qxs2, v4);
+            vector unsigned char q4x30 = (vector unsigned char)vec_and(qxs3, lowMask);
+            vector unsigned char q4x31 = (vector unsigned char)vec_sr(qxs3, v4);
+
+            vector signed char q8y00 = vec_xl(  0, q8);
+            vector signed char q8y10 = vec_xl( 16, q8);
+            vector signed char q8y01 = vec_xl( 32, q8);
+            vector signed char q8y11 = vec_xl( 48, q8);
+            vector signed char q8y20 = vec_xl( 64, q8);
+            vector signed char q8y30 = vec_xl( 80, q8);
+            vector signed char q8y21 = vec_xl( 96, q8);
+            vector signed char q8y31 = vec_xl(112, q8);
+            q8 += 128;
+
+            vector signed int qv00 = vec_msum(q8y00, q4x00, v0);
+            vector signed int qv01 = vec_msum(q8y01, q4x01, v0);
+            vector signed int qv10 = vec_msum(q8y10, q4x10, v0);
+            vector signed int qv11 = vec_msum(q8y11, q4x11, v0);
+            vector signed int qv20 = vec_msum(q8y20, q4x20, v0);
+            vector signed int qv21 = vec_msum(q8y21, q4x21, v0);
+            vector signed int qv30 = vec_msum(q8y30, q4x30, v0);
+            vector signed int qv31 = vec_msum(q8y31, q4x31, v0);
+
+            vector signed int vscales_h = vec_unpackh(vscales);
+            vector signed int vs0 = vec_splat(vscales_h, 0);
+            vector signed int vs1 = vec_splat(vscales_h, 1);
+            vector signed int vs2 = vec_splat(vscales_h, 2);
+            vector signed int vs3 = vec_splat(vscales_h, 3);
+            vscales = vec_sld(vscales, vscales, 8);
+
+            vsumi0 = vec_add(vec_mul(qv00, vs0), vsumi0);
+            vsumi1 = vec_add(vec_mul(qv01, vs1), vsumi1);
+            vsumi2 = vec_add(vec_mul(qv20, vs2), vsumi2);
+            vsumi3 = vec_add(vec_mul(qv21, vs3), vsumi3);
+
+            vsumi0 = vec_add(vec_mul(qv10, vs0), vsumi0);
+            vsumi1 = vec_add(vec_mul(qv11, vs1), vsumi1);
+            vsumi2 = vec_add(vec_mul(qv30, vs2), vsumi2);
+            vsumi3 = vec_add(vec_mul(qv31, vs3), vsumi3);
+        }
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+    }
+
+    vsumf0 = vec_add(vsumf0, vsumf2);
+    vsumf1 = vec_add(vsumf1, vsumf3);
+
+    vsumf0 = vec_add(vsumf0, vsumf1);
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    *s = vec_extract(vsumf0, 0);
+
+#elif defined __loongarch_asx
+    GGML_UNUSED(kmask1);
+    GGML_UNUSED(kmask2);
+    GGML_UNUSED(kmask3);
+
+    const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
+
+    __m256 acc = (__m256)__lasx_xvldi(0);
+    __m128 acc_m = (__m128)__lsx_vldi(0);
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]));
+
+        const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0);
+        const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1));
+        const __m128i prod = lsx_madd_h(lasx_extracti128(mins_and_scales, 1), q8s);
+        acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m);
+
+        const __m128i sc128  = lasx_extracti128(mins_and_scales, 0);
+        const __m256i scales = lasx_insertf128(sc128, sc128);
+
+        __m256i sumi = __lasx_xvldi(0);
+
+        for (int j = 0; j < QK_K/64; ++j) {
+
+            const __m256i scale_l = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+0));
+            const __m256i scale_h = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+1));
+
+            const __m256i q4bits = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
+            const __m256i q4l = __lasx_xvand_v(q4bits, m4);
+            const __m256i q4h = __lasx_xvand_v(__lasx_xvsrli_h(q4bits, 4), m4);
+
+            const __m256i q8l = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+            __m256i p16l = lasx_maddubs_h(q4l, q8l);
+            p16l = lasx_madd_h(scale_l, p16l);
+
+            const __m256i q8h = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+            __m256i p16h = lasx_maddubs_h(q4h, q8h);
+            p16h = lasx_madd_h(scale_h, p16h);
+            const __m256i sumj = __lasx_xvadd_w(p16l, p16h);
+
+            sumi = __lasx_xvadd_w(sumi, sumj);
+        }
+
+        __m256 vd = __lasx_xvreplfr2vr_s(d);
+        acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc);
+
+    }
+
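+    // horizontal reduction of acc_m: fold the high 64 bits onto the low, then add lane 1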
+    acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vpermi_w((__m128i)acc_m, (__m128i)acc_m, 0xee));
+    __m128i tmp1 = __lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w((__m128i)acc_m, 1), 0);
+    acc_m = __lsx_vfadd_s(acc_m, (__m128)tmp1);
+
+    ft_union fi;
+    fi.i = __lsx_vpickve2gr_w(acc_m, 0);
+    *s = hsum_float_8(acc) + fi.f;
+#else
+
+    const uint8_t * scales = (const uint8_t*)&utmp[0];
+    const uint8_t * mins   = (const uint8_t*)&utmp[2];
+
+    int8_t  aux8[QK_K];
+    int16_t aux16[8];
+    float   sums [8];
+    int32_t aux32[8];
+    memset(sums, 0, 8*sizeof(float));
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * restrict q4 = x[i].qs;
+        const  int8_t * restrict q8 = y[i].qs;
+        memset(aux32, 0, 8*sizeof(int32_t));
+        int8_t * restrict a = aux8;
+        for (int j = 0; j < QK_K/64; ++j) {
+            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
+            a += 32;
+            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l]  >> 4);
+            a += 32; q4 += 32;
+        }
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        int sumi = 0;
+        for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
+        a = aux8;
+        int is = 0;
+        for (int j = 0; j < QK_K/32; ++j) {
+            int32_t scale = scales[is++];
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+        }
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
+        const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
+        sumf -= dmin * sumi;
+    }
+    for (int l = 0; l < 8; ++l) sumf += sums[l];
+    *s = sumf;
+#endif
+}
+
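+// Q5_K x Q8_K dot product. Q5_K stores the low 4 bits of each value in qs and
+// the fifth (high) bit in qh; the scales/mins use the same 12-byte packing as Q4_K.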
+void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy,  size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_q5_K * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+    static const uint32_t kmask1 = 0x3f3f3f3f;
+    static const uint32_t kmask2 = 0x0f0f0f0f;
+    static const uint32_t kmask3 = 0x03030303;
+
+    uint32_t utmp[4];
+
+#ifdef __ARM_NEON
+    const uint8x16_t m4b = vdupq_n_u8(0xf);
+    const uint8x16_t mone = vdupq_n_u8(1);
+    const uint8x16_t mtwo = vdupq_n_u8(2);
+    const int32x4_t mzero = vdupq_n_s32(0);
+
+    ggml_int8x16x4_t q5bytes;
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
+
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        const uint8x8_t mins8 = vld1_u8((const uint8_t*)utmp + 8);
+        const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(mins8));
+        const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
+                                         vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
+        int32_t sumi_mins = vaddvq_s32(prod);
+
+        const uint8_t * scales = (const uint8_t *)utmp;
+
+        const uint8_t * restrict q5 = x[i].qs;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh);
+
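+        // mone/mtwo pick two consecutive high bits per iteration; qhbits is
+        // shifted right by 2 after each j to expose the next pair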
+        ggml_uint8x16x4_t q5h;
+
+        int32_t sumi = 0;
+
+        for (int j = 0; j < QK_K/64; ++j) {
+
+            const ggml_uint8x16x2_t q5bits = ggml_vld1q_u8_x2(q5); q5 += 32;
+            const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
+
+            q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
+            q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
+            q5h.val[2] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[0]), 3);
+            q5h.val[3] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[1]), 3);
+            qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 2);
+            qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 2);
+
+            q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), q5h.val[0]));
+            q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), q5h.val[1]));
+            q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2]));
+            q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3]));
+
+            sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++;
+            sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++;
+        }
+
+        sumf += d * sumi - dmin * sumi_mins;
+    }
+
+    *s = sumf;
+
+#elif defined __AVX2__
+
+    const __m256i m4 = _mm256_set1_epi8(0xF);
+    const __m128i mzero = _mm_setzero_si128();
+    const __m256i mone  = _mm256_set1_epi8(1);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    float summs = 0.f;
+
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * restrict q5 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
+
+        const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);
+        const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
+        const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
+        const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);
+        summs += dmin * _mm_extract_epi32(hsum, 0);
+
+        const __m128i sc128  = _mm256_extracti128_si256(mins_and_scales, 0);
+        const __m256i scales = MM256_SET_M128I(sc128, sc128);
+
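+        // hbits holds one high bit per value; hmask/bit walk those bits and lift
+        // each one to position 4, so q5 = (low nibble) + 16*highbit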
+        const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].qh);
+        __m256i hmask = mone;
+
+        __m256i sumi = _mm256_setzero_si256();
+
+        int bit = 0;
+
+        for (int j = 0; j < QK_K/64; ++j) {
+
+            const __m256i scale_0 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0));
+            const __m256i scale_1 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1));
+
+            const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); q5 += 32;
+
+            const __m256i q5l_0 = _mm256_and_si256(q5bits, m4);
+            const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4);
+            const __m256i q5_0  = _mm256_add_epi8(q5l_0, q5h_0);
+            hmask = _mm256_slli_epi16(hmask, 1);
+
+            const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4);
+            const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4);
+            const __m256i q5_1  = _mm256_add_epi8(q5l_1, q5h_1);
+            hmask = _mm256_slli_epi16(hmask, 1);
+
+            const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+
+            __m256i p16_0 = _mm256_maddubs_epi16(q5_0, q8_0);
+            __m256i p16_1 = _mm256_maddubs_epi16(q5_1, q8_1);
+
+            p16_0 = _mm256_madd_epi16(scale_0, p16_0);
+            p16_1 = _mm256_madd_epi16(scale_1, p16_1);
+
+            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
+
+        }
+
+        __m256 vd = _mm256_set1_ps(d);
+        acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
+
+    }
+
+    *s = hsum_float_8(acc) + summs;
+
+#elif defined __AVX__
+
+    const __m128i m4 = _mm_set1_epi8(0xF);
+    const __m128i mzero = _mm_setzero_si128();
+    const __m128i mone  = _mm_set1_epi8(1);
+    const __m128i m2 = _mm_set1_epi8(2);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    float summs = 0.f;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const uint8_t * restrict q5 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]);
+        const __m128i scales = _mm_cvtepu8_epi16(utmps);
+        const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps));
+
+        const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]);
+        const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]);
+        const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1);
+        const __m128i prod = _mm_madd_epi16(mins, q8s);
+        const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);
+        summs += dmin * _mm_extract_epi32(hsum, 0);
+
+        const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].qh[0]);
+        const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].qh[16]);
+        __m128i hmask = mone;
+
+        __m128i sumi_0 = _mm_setzero_si128();
+        __m128i sumi_1 = _mm_setzero_si128();
+
+        int bit = 0;
+
+        __m128i shuffle = _mm_set1_epi16(0x0100);
+        for (int j = 0; j < QK_K/64; ++j) {
+
+            const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle);
+            shuffle = _mm_add_epi16(shuffle, m2);
+
+            const __m128i q5bits_0 = _mm_loadu_si128((const __m128i*)q5); q5 += 16;
+            const __m128i q5bits_1 = _mm_loadu_si128((const __m128i*)q5); q5 += 16;
+
+            __m128i q5l_0 = _mm_and_si128(q5bits_0, m4);
+            __m128i q5l_1 = _mm_and_si128(q5bits_1, m4);
+            __m128i q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4);
+            __m128i q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4);
+            __m128i q5_0  = _mm_add_epi8(q5l_0, q5h_0);
+            __m128i q5_1  = _mm_add_epi8(q5l_1, q5h_1);
+            hmask = _mm_slli_epi16(hmask, 1);
+
+            __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            __m128i p16_0 = _mm_maddubs_epi16(q5_0, q8_0);
+            __m128i p16_1 = _mm_maddubs_epi16(q5_1, q8_1);
+            p16_0 = _mm_madd_epi16(scale_0, p16_0);
+            p16_1 = _mm_madd_epi16(scale_0, p16_1);
+
+            q5l_0 = _mm_and_si128(_mm_srli_epi16(q5bits_0, 4), m4);
+            q5l_1 = _mm_and_si128(_mm_srli_epi16(q5bits_1, 4), m4);
+            q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4);
+            q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4);
+            q5_0  = _mm_add_epi8(q5l_0, q5h_0);
+            q5_1  = _mm_add_epi8(q5l_1, q5h_1);
+            hmask = _mm_slli_epi16(hmask, 1);
+
+            q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            __m128i p16_2 = _mm_maddubs_epi16(q5_0, q8_0);
+            __m128i p16_3 = _mm_maddubs_epi16(q5_1, q8_1);
+            p16_2 = _mm_madd_epi16(scale_1, p16_2);
+            p16_3 = _mm_madd_epi16(scale_1, p16_3);
+
+            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
+            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
+
+        }
+
+        __m256 vd = _mm256_set1_ps(d);
+        __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
+        acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
+
+    }
+
+    *s = hsum_float_8(acc) + summs;
+
+#elif defined __riscv_v_intrinsic
+
+    const uint8_t * scales = (const uint8_t*)&utmp[0];
+    const uint8_t * mins   = (const uint8_t*)&utmp[2];
+
+    float sumf = 0;
+    float sums = 0.0;
+
+    size_t vl;
+
+    for (int i = 0; i < nb; ++i) {
+
+        vl = 8;
+
+        const uint8_t * restrict q5 = x[i].qs;
+        const uint8_t * restrict hm = x[i].qh;
+        const  int8_t * restrict q8 = y[i].qs;
+
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
+
+        vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl);
+        vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl);
+        vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl);
+
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl);
+        vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl));
+        vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl);
+
+        vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
+        sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi);
+
+        vl = 32;
+        int32_t aux32 = 0;
+        int is = 0;
+
+        uint8_t m = 1;
+        vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+        vuint8m1_t vqh = __riscv_vle8_v_u8m1(hm, vl);
+
+        for (int j = 0; j < QK_K/64; ++j) {
+            // load Q5 and Q8
+            vuint8m1_t q5_x = __riscv_vle8_v_u8m1(q5, vl);
+            vint8m1_t  q8_y1 = __riscv_vle8_v_i8m1(q8, vl);
+            vint8m1_t  q8_y2 = __riscv_vle8_v_i8m1(q8+32, vl);
+
+            // where qh has the high bit set, add 16 to the low nibble via a masked add
+            vint8m1_t q5_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q5_x, 0x0F, vl));
+            vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl);
+            vbool8_t vmask_1 = __riscv_vmsne_vx_u8m1_b8(qh_m1, 0, vl);
+            vint8m1_t q5_m1 = __riscv_vadd_vx_i8m1_mu(vmask_1, q5_a, q5_a, 16, vl);
+            m <<= 1;
+
+            vint8m1_t q5_l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q5_x, 0x04, vl));
+            vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl);
+            vbool8_t vmask_2 = __riscv_vmsne_vx_u8m1_b8(qh_m2, 0, vl);
+            vint8m1_t q5_m2 = __riscv_vadd_vx_i8m1_mu(vmask_2, q5_l, q5_l, 16, vl);
+            m <<= 1;
+
+            vint16m2_t v0 = __riscv_vwmul_vv_i16m2(q5_m1, q8_y1, vl);
+            vint16m2_t v1 = __riscv_vwmul_vv_i16m2(q5_m2, q8_y2, vl);
+
+            vint32m4_t vs1 = __riscv_vwmul_vx_i32m4(v0, scales[is++], vl);
+            vint32m4_t vs2 = __riscv_vwmul_vx_i32m4(v1, scales[is++], vl);
+
+            vint32m1_t vacc1 = __riscv_vredsum_vs_i32m4_i32m1(vs1, vzero, vl);
+            vint32m1_t vacc2 = __riscv_vredsum_vs_i32m4_i32m1(vs2, vzero, vl);
+
+            aux32 += __riscv_vmv_x_s_i32m1_i32(vacc1) + __riscv_vmv_x_s_i32m1_i32(vacc2);
+            q5 += 32;    q8 += 64;
+
+        }
+
+        vfloat32m1_t vaux = __riscv_vfmul_vf_f32m1(__riscv_vfmv_v_f_f32m1(aux32, 1), d, 1);
+        sums += __riscv_vfmv_f_s_f32m1_f32(vaux);
+
+    }
+
+    *s = sumf + sums;
+
+#elif defined(__POWER9_VECTOR__)
+    const vector signed char lowMask = vec_splats((signed char)0xF);
+    const vector signed char lowMask1 = vec_splats((int8_t)0x3f);
+    const vector signed char lowMask2 = vec_splats((int8_t)0x30);
+    const vector int v0 = vec_splats((int32_t)0);
+    const vector unsigned char v1 = vec_splats((unsigned char)0x1);
+    const vector unsigned char v2 = vec_splats((unsigned char)0x2);
+    const vector unsigned char v3 = vec_splats((unsigned char)0x3);
+    const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+
+    vector float vsumf0 = vec_splats(0.0f);
+    vector float vsumf1 = vec_splats(0.0f);
+    vector float vsumf2 = vec_splats(0.0f);
+    vector float vsumf3 = vec_splats(0.0f);
+
+    for (int i = 0; i < nb; ++i) {
+        vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
+        vector float vyd = vec_splats(y[i].d);
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[i].dmin));
+        vector float vdmin = vec_mul(vxmin, vyd);
+
+        UNUSED(kmask1);
+        UNUSED(kmask2);
+        UNUSED(kmask3);
+        UNUSED(utmp);
+
+        vector signed char u0 = (vector signed char)vec_xl_len(x[i].scales, 8);
+        vector signed char u1 = vec_and(vec_sr(u0, v2), lowMask2);
+        vector signed char u2 = (vector signed char)vec_xl_len(x[i].scales + 8, 4);
+        vector signed char u3 = vec_sr(u2, v4);
+
+        vector signed char u30 = u1;
+        vector signed char u31 = (vector signed char)vec_mergeh((vector signed int)vec_and(u2, lowMask), (vector signed int)u3);
+
+        u1 = vec_and(u0, lowMask1);
+        u2 = vec_or(u30, u31);
+
+        vector signed char utmps = (vector signed char)vec_mergeh((vector signed int)u1, (vector signed int)u2);
+
+        vector signed short q8ysums0 = vec_xl( 0, y[i].bsums);
+        vector signed short q8ysums1 = vec_xl(16, y[i].bsums);
+
+        vector signed short vscales = vec_unpackh(utmps);
+
+        vector signed short q5xmins = vec_unpackl(utmps);
+        vector signed short q5xmins0 = vec_mergeh(q5xmins, q5xmins);
+        vector signed short q5xmins1 = vec_mergel(q5xmins, q5xmins);
+
+        vector signed int prod0 = vec_mule(q5xmins0, q8ysums0);
+        vector signed int prod1 = vec_mule(q5xmins1, q8ysums1);
+        vector signed int prod2 = vec_mulo(q5xmins0, q8ysums0);
+        vector signed int prod3 = vec_mulo(q5xmins1, q8ysums1);
+
+        vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0);
+        vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1);
+        vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2);
+        vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3);
+
+        vector signed char qxhs0 = (vector signed char)vec_xl( 0, x[i].qh);
+        vector signed char qxhs1 = (vector signed char)vec_xl(16, x[i].qh);
+
+        vector signed int vsumi0 = v0;
+        vector signed int vsumi1 = v0;
+        vector signed int vsumi2 = v0;
+        vector signed int vsumi3 = v0;
+
+        const uint8_t * restrict q5 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        for (int j = 0; j < QK_K/64; ++j) {
+            __builtin_prefetch(q5, 0, 1);
+            __builtin_prefetch(q8, 0, 1);
+
+            vector signed char qxs0 = (vector signed char)vec_xl( 0, q5);
+            vector signed char qxs1 = (vector signed char)vec_xl(16, q5);
+            q5 += 32;
+
+            vector signed char qxs00 = vec_and(qxs0, lowMask);
+            vector signed char qxs01 = vec_sr(qxs0, v4);
+            vector signed char qxs10 = vec_and(qxs1, lowMask);
+            vector signed char qxs11 = vec_sr(qxs1, v4);
+
+            vector signed char q5h00 = vec_sl(vec_and((vector signed char)v1, qxhs0), v4);
+            vector signed char q5h01 = vec_sl(vec_and((vector signed char)v2, qxhs0), v3);
+            vector signed char q5h10 = vec_sl(vec_and((vector signed char)v1, qxhs1), v4);
+            vector signed char q5h11 = vec_sl(vec_and((vector signed char)v2, qxhs1), v3);
+            qxhs0 = vec_sr(qxhs0, v2);
+            qxhs1 = vec_sr(qxhs1, v2);
+
+            vector unsigned char q5x00 = (vector unsigned char)vec_or(q5h00, qxs00);
+            vector unsigned char q5x01 = (vector unsigned char)vec_or(q5h01, qxs01);
+            vector unsigned char q5x10 = (vector unsigned char)vec_or(q5h10, qxs10);
+            vector unsigned char q5x11 = (vector unsigned char)vec_or(q5h11, qxs11);
+
+            vector signed char q8y00 = vec_xl( 0, q8);
+            vector signed char q8y10 = vec_xl(16, q8);
+            vector signed char q8y01 = vec_xl(32, q8);
+            vector signed char q8y11 = vec_xl(48, q8);
+            q8 += 64;
+
+            vector signed int qv00 = vec_msum(q8y00, q5x00, v0);
+            vector signed int qv01 = vec_msum(q8y01, q5x01, v0);
+            vector signed int qv10 = vec_msum(q8y10, q5x10, v0);
+            vector signed int qv11 = vec_msum(q8y11, q5x11, v0);
+
+            vector signed int vscales_h = vec_unpackh(vscales);
+            vector signed int vs0 = vec_splat(vscales_h, 0);
+            vector signed int vs1 = vec_splat(vscales_h, 1);
+            vscales = vec_sld(vscales, vscales, 12);
+
+            vsumi0 = vec_add(vec_mul(qv00, vs0), vsumi0);
+            vsumi1 = vec_add(vec_mul(qv10, vs0), vsumi1);
+            vsumi2 = vec_add(vec_mul(qv01, vs1), vsumi2);
+            vsumi3 = vec_add(vec_mul(qv11, vs1), vsumi3);
+        }
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+    }
+
+    vsumf0 = vec_add(vsumf0, vsumf2);
+    vsumf1 = vec_add(vsumf1, vsumf3);
+
+    vsumf0 = vec_add(vsumf0, vsumf1);
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    *s = vec_extract(vsumf0, 0);
+
+#elif defined __loongarch_asx
+    GGML_UNUSED(kmask1);
+    GGML_UNUSED(kmask2);
+    GGML_UNUSED(kmask3);
+
+    const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
+    const __m128i mzero = __lsx_vldi(0);
+    const __m256i mone  = __lasx_xvreplgr2vr_b(1);
+
+    __m256 acc = (__m256)__lasx_xvldi(0);
+
+    float summs = 0.f;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * restrict q5 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]));
+
+        const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0);
+        const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1));
+        const __m128i prod = lsx_madd_h(lasx_extracti128(mins_and_scales, 1), q8s);
+        const __m128i hsum = lsx_hadd_w(lsx_hadd_w(prod, mzero), mzero);
+        summs += dmin * __lsx_vpickve2gr_w(hsum, 0);    //TODO check
+
+        const __m128i sc128  = lasx_extracti128(mins_and_scales, 0);
+        const __m256i scales = lasx_insertf128(sc128, sc128);
+
+        const __m256i hbits = __lasx_xvld((const __m256i*)x[i].qh, 0);
+        __m256i hmask = mone;
+
+        __m256i sumi = __lasx_xvldi(0);
+
+        int bit = 0;
+        __m256i xvbit;
+
+        for (int j = 0; j < QK_K/64; ++j) {
+
+            const __m256i scale_0 = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+0));
+            const __m256i scale_1 = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+1));
+
+            const __m256i q5bits = __lasx_xvld((const __m256i*)q5, 0); q5 += 32;
+
+            xvbit = __lasx_xvreplgr2vr_h(bit++);
+            const __m256i q5l_0 = __lasx_xvand_v(q5bits, m4);
+            const __m256i q5h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4);
+            const __m256i q5_0  = __lasx_xvadd_b(q5l_0, q5h_0);
+            hmask = __lasx_xvslli_h(hmask, 1);
+
+            xvbit = __lasx_xvreplgr2vr_h(bit++);
+            const __m256i q5l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q5bits, 4), m4);
+            const __m256i q5h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4);
+            const __m256i q5_1  = __lasx_xvadd_b(q5l_1, q5h_1);
+            hmask = __lasx_xvslli_h(hmask, 1);
+
+            const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+            const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+
+            __m256i p16_0 = lasx_maddubs_h(q5_0, q8_0);
+            __m256i p16_1 = lasx_maddubs_h(q5_1, q8_1);
+
+            p16_0 = lasx_madd_h(scale_0, p16_0);
+            p16_1 = lasx_madd_h(scale_1, p16_1);
+
+            sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1));
+
+        }
+
+        __m256 vd = __lasx_xvreplfr2vr_s(d);
+        acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc);
+
+    }
+
+    *s = hsum_float_8(acc) + summs;
+
+#else
+
+    const uint8_t * scales = (const uint8_t*)&utmp[0];
+    const uint8_t * mins   = (const uint8_t*)&utmp[2];
+
+    int8_t  aux8[QK_K];
+    int16_t aux16[8];
+    float   sums [8];
+    int32_t aux32[8];
+    memset(sums, 0, 8*sizeof(float));
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * restrict q4 = x[i].qs;
+        const uint8_t * restrict hm = x[i].qh;
+        const  int8_t * restrict q8 = y[i].qs;
+        memset(aux32, 0, 8*sizeof(int32_t));
+        int8_t * restrict a = aux8;
+        uint8_t m = 1;
+        for (int j = 0; j < QK_K/64; ++j) {
+            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
+            for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
+            a += 32; m <<= 1;
+            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l]  >> 4);
+            for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
+            a += 32; m <<= 1;
+            q4 += 32;
+        }
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        int sumi = 0;
+        for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
+        a = aux8;
+        int is = 0;
+        for (int j = 0; j < QK_K/32; ++j) {
+            int32_t scale = scales[is++];
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+        }
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
+        const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
+        sumf -= dmin * sumi;
+    }
+    for (int l = 0; l < 8; ++l) sumf += sums[l];
+    *s = sumf;
+#endif
+}
+
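+// Q6_K x Q8_K dot product. Q6_K stores the low 4 bits of each value in ql, the
+// top 2 bits in qh, and sixteen signed 8-bit sub-block scales. Values carry a
+// +32 bias that the kernels below remove once per superblock rather than per byte.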
+void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_q6_K * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#ifdef __ARM_NEON
+    float sum = 0;
+
+    const uint8x16_t m4b = vdupq_n_u8(0xF);
+    const int32x4_t  vzero = vdupq_n_s32(0);
+
+    const uint8x16_t mone = vdupq_n_u8(3);
+
+    ggml_int8x16x4_t q6bytes;
+    ggml_uint8x16x4_t q6h;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d_all = GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict q6 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const int8_t * restrict scale = x[i].scales;
+
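+        // accumulate sum(scale_j * bsum_j) so the +32 bias can be removed once
+        // per superblock (isum - 32*isum_mins) instead of per byte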
+        const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);
+        const int8x16_t scales = vld1q_s8(scale);
+        const ggml_int16x8x2_t q6scales = {{vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))}};
+
+        const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])),
+                                                   vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))),
+                                         vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[1]), vget_low_s16 (q6scales.val[1])),
+                                                   vmull_s16(vget_high_s16(q8sums.val[1]), vget_high_s16(q6scales.val[1]))));
+        int32_t isum_mins = vaddvq_s32(prod);
+
+        int32_t isum = 0;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+
+            ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); qh += 32;
+            ggml_uint8x16x4_t q6bits = ggml_vld1q_u8_x4(q6); q6 += 64;
+            ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
+
+            q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
+            q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
+            uint8x16_t shifted = vshrq_n_u8(qhbits.val[0], 2);
+            q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+            shifted = vshrq_n_u8(qhbits.val[1], 2);
+            q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+
+            q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0]));
+            q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1]));
+            q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2]));
+            q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3]));
+
+            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
+                    vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
+                    vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
+                    vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
+
+            scale += 4;
+
+            q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
+
+            shifted = vshrq_n_u8(qhbits.val[0], 4);
+            q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+            shifted = vshrq_n_u8(qhbits.val[1], 4);
+            q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+            shifted = vshrq_n_u8(qhbits.val[0], 6);
+            q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+            shifted = vshrq_n_u8(qhbits.val[1], 6);
+            q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+
+            q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0]));
+            q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1]));
+            q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2]));
+            q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3]));
+
+            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
+                    vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
+                    vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
+                    vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
+            scale += 4;
+        }
+        sum += d_all * y[i].d * (isum - 32 * isum_mins);
+
+    }
+    *s = sum;
+
+#elif defined __AVX2__
+
+    const __m256i m4 = _mm256_set1_epi8(0xF);
+    const __m256i m2 = _mm256_set1_epi8(3);
+    const __m256i m32s = _mm256_set1_epi8(32);
+
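+    // _mm256_maddubs_epi16 treats its first operand as unsigned, so the
+    // 6-bit values stay unsigned (0..63) and the -32 offset is applied
+    // afterwards: p16 = maddubs(q6,q8) - maddubs(32,q8) = sum((q6-32)*q8)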
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict q4 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);
+
+        __m256i sumi = _mm256_setzero_si256();
+
+        int is = 0;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+
+            const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0));
+            const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1));
+            const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2));
+            const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3));
+            is += 4;
+
+            const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
+            const __m256i q4bits2 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
+            const __m256i q4bitsH = _mm256_loadu_si256((const __m256i*)qh); qh += 32;
+
+            const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m2), 4);
+            const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 2), m2), 4);
+            const __m256i q4h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 4), m2), 4);
+            const __m256i q4h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 6), m2), 4);
+
+            const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
+            const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1);
+            const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2);
+            const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3);
+
+            const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+
+            __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0);
+            __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1);
+            __m256i q8s_2 = _mm256_maddubs_epi16(m32s, q8_2);
+            __m256i q8s_3 = _mm256_maddubs_epi16(m32s, q8_3);
+
+            __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0);
+            __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1);
+            __m256i p16_2 = _mm256_maddubs_epi16(q4_2, q8_2);
+            __m256i p16_3 = _mm256_maddubs_epi16(q4_3, q8_3);
+
+            p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
+            p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
+            p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
+            p16_3 = _mm256_sub_epi16(p16_3, q8s_3);
+
+            p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0);
+            p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1);
+            p16_2 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_2), p16_2);
+            p16_3 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_3), p16_3);
+
+            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
+            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_2, p16_3));
+
+        }
+
+        acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
+    }
+
+    *s = hsum_float_8(acc);
+
+#elif defined __AVX__
+
+    const __m128i m3 = _mm_set1_epi8(3);
+    const __m128i m15 = _mm_set1_epi8(15);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict q4 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        // handle the q6_k -32 offset separately using bsums
+        const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)y[i].bsums);
+        const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)y[i].bsums + 1);
+        const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);
+        const __m128i scales_16_0 = _mm_cvtepi8_epi16(scales);
+        const __m128i scales_16_1 = _mm_cvtepi8_epi16(_mm_bsrli_si128(scales, 8));
+        const __m128i q8sclsub_0 = _mm_slli_epi32(_mm_madd_epi16(q8sums_0, scales_16_0), 5);
+        const __m128i q8sclsub_1 = _mm_slli_epi32(_mm_madd_epi16(q8sums_1, scales_16_1), 5);
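+        // (bsums . scales) << 5 equals 32 * sum(scale * q8) for each half,
+        // i.e. the full contribution of the -32 offset; subtracted from
+        // sumi_0/sumi_1 after the inner loop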
+
+        __m128i sumi_0 = _mm_setzero_si128();
+        __m128i sumi_1 = _mm_setzero_si128();
+
+        int is = 0;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+
+            const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i*)qh); qh += 16;
+            const __m128i q4bitsH_1 = _mm_loadu_si128((const __m128i*)qh); qh += 16;
+
+            const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4);
+            const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4);
+            const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(12)), 2);
+            const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(12)), 2);
+            const __m128i q4h_4 = _mm_and_si128(q4bitsH_0, _mm_set1_epi8(48));
+            const __m128i q4h_5 = _mm_and_si128(q4bitsH_1, _mm_set1_epi8(48));
+            const __m128i q4h_6 = _mm_srli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(-64)), 2);
+            const __m128i q4h_7 = _mm_srli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(-64)), 2);
+
+            const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
+            const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
+            const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
+            const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
+
+            const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m15), q4h_0);
+            const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m15), q4h_1);
+            const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m15), q4h_2);
+            const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m15), q4h_3);
+            const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m15), q4h_4);
+            const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m15), q4h_5);
+            const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m15), q4h_6);
+            const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m15), q4h_7);
+
+            const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+
+            __m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0);
+            __m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1);
+            __m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2);
+            __m128i p16_3 = _mm_maddubs_epi16(q4_3, q8_3);
+            __m128i p16_4 = _mm_maddubs_epi16(q4_4, q8_4);
+            __m128i p16_5 = _mm_maddubs_epi16(q4_5, q8_5);
+            __m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6);
+            __m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7);
+
+            const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0));
+            const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1));
+            const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2));
+            const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3));
+            is += 4;
+
+            p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
+            p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_0, 8)), p16_1);
+            p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
+            p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_1, 8)), p16_3);
+            p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4);
+            p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_2, 8)), p16_5);
+            p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6);
+            p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_3, 8)), p16_7);
+
+            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
+            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
+            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_4, p16_6));
+            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_5, p16_7));
+
+        }
+
+        sumi_0 = _mm_sub_epi32(sumi_0, q8sclsub_0);
+        sumi_1 = _mm_sub_epi32(sumi_1, q8sclsub_1);
+        const __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi)), acc);
+    }
+
+    *s = hsum_float_8(acc);
+
+#elif defined __riscv_v_intrinsic
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+
+        const uint8_t * restrict q6 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const  int8_t * restrict q8 = y[i].qs;
+
+        const int8_t * restrict scale = x[i].scales;
+
+        size_t vl;
+
+        vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+
+        int sum_t = 0;
+        int is = 0;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+
+            vl = 32;
+
+            // load qh
+            vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl);
+
+            // load Q6
+            vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl);
+            vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl);
+
+            vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl);
+            vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl);
+            vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl);
+            vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl);
+
+            vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl);
+            vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl);
+            vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl);
+            vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl);
+
+            vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl);
+            vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl);
+            vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl);
+            vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl);
+
+            vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl);
+            vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl);
+            vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl);
+            vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl);
+
+            // load Q8 and take product
+            vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl);
+            vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
+            vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
+            vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl);
+
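+            // switch to 16 lanes: each 32-element i16m2 product is reduced
+            // as two i16m1 halves, one per 16-value scale group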
+            vl = 16;
+
+            vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl);
+            vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl);
+            vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl);
+            vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl);
+            vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl);
+            vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl);
+            vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl);
+            vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl);
+
+            vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl);
+            vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl);
+            vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl);
+            vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl);
+
+            sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);
+
+            q6 += 64;   qh += 32;   q8 += 128;   is += 8;
+
+        }
+
+        sumf += d * sum_t;
+
+    }
+
+    *s = sumf;
+
+#elif defined(__POWER9_VECTOR__)
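+    // POWER9 VSX path; applies the -32 offset per value via the 'off' constant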
+    const vector signed char lowMask = vec_splats((signed char)0xF);
+    const vector int v0 = vec_splats((int32_t)0);
+    const vector unsigned char v2 = vec_splats((unsigned char)0x2);
+    const vector unsigned char v3 = vec_splats((unsigned char)0x3);
+    const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+    const vector unsigned char v6 = vec_splats((unsigned char)0x6);
+    const vector signed char off = vec_splats((signed char)0x20);
+
+    vector float vsumf0 = vec_splats(0.0f);
+    vector float vsumf1 = vec_splats(0.0f);
+    vector float vsumf2 = vec_splats(0.0f);
+    vector float vsumf3 = vec_splats(0.0f);
+
+    for (int i = 0; i < nb; ++i) {
+        vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
+        vector float vyd = vec_splats(y[i].d);
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector signed int vsumi0 = v0;
+        vector signed int vsumi1 = v0;
+        vector signed int vsumi2 = v0;
+        vector signed int vsumi3 = v0;
+        vector signed int vsumi4 = v0;
+        vector signed int vsumi5 = v0;
+        vector signed int vsumi6 = v0;
+        vector signed int vsumi7 = v0;
+
+        const uint8_t * restrict q6 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict qs = x[i].scales;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+            __builtin_prefetch(q6, 0, 0);
+            __builtin_prefetch(qh, 0, 0);
+            __builtin_prefetch(q8, 0, 0);
+
+            vector signed char qxs0 = (vector signed char)vec_xl( 0, q6);
+            vector signed char qxs1 = (vector signed char)vec_xl(16, q6);
+            vector signed char qxs2 = (vector signed char)vec_xl(32, q6);
+            vector signed char qxs3 = (vector signed char)vec_xl(48, q6);
+            q6 += 64;
+
+            vector signed char qxs00 = vec_and(qxs0, lowMask);
+            vector signed char qxs01 = vec_sr(qxs0, v4);
+            vector signed char qxs10 = vec_and(qxs1, lowMask);
+            vector signed char qxs11 = vec_sr(qxs1, v4);
+            vector signed char qxs20 = vec_and(qxs2, lowMask);
+            vector signed char qxs21 = vec_sr(qxs2, v4);
+            vector signed char qxs30 = vec_and(qxs3, lowMask);
+            vector signed char qxs31 = vec_sr(qxs3, v4);
+
+            vector signed char qxhs0 = (vector signed char)vec_xl( 0, qh);
+            vector signed char qxhs1 = (vector signed char)vec_xl(16, qh);
+            qh += 32;
+
+            vector signed char qxh00 = vec_sl(vec_and((vector signed char)v3, qxhs0), v4);
+            vector signed char qxh01 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v4)), v4);
+            vector signed char qxh10 = vec_sl(vec_and((vector signed char)v3, qxhs1), v4);
+            vector signed char qxh11 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs1, v4)), v4);
+            vector signed char qxh20 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v2)), v4);
+            vector signed char qxh21 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v6)), v4);
+            vector signed char qxh30 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs1, v2)), v4);
+            vector signed char qxh31 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs1, v6)), v4);
+
+            vector signed char q6x00 = vec_sub(vec_or(qxh00, qxs00), off);
+            vector signed char q6x01 = vec_sub(vec_or(qxh01, qxs01), off);
+            vector signed char q6x10 = vec_sub(vec_or(qxh10, qxs10), off);
+            vector signed char q6x11 = vec_sub(vec_or(qxh11, qxs11), off);
+            vector signed char q6x20 = vec_sub(vec_or(qxh20, qxs20), off);
+            vector signed char q6x21 = vec_sub(vec_or(qxh21, qxs21), off);
+            vector signed char q6x30 = vec_sub(vec_or(qxh30, qxs30), off);
+            vector signed char q6x31 = vec_sub(vec_or(qxh31, qxs31), off);
+
+            vector signed char q8y00 = vec_xl(  0, q8);
+            vector signed char q8y10 = vec_xl( 16, q8);
+            vector signed char q8y20 = vec_xl( 32, q8);
+            vector signed char q8y30 = vec_xl( 48, q8);
+            vector signed char q8y01 = vec_xl( 64, q8);
+            vector signed char q8y11 = vec_xl( 80, q8);
+            vector signed char q8y21 = vec_xl( 96, q8);
+            vector signed char q8y31 = vec_xl(112, q8);
+            q8 += 128;
+
+            vector signed short qv00 = vec_add(vec_mule(q6x00, q8y00), vec_mulo(q6x00, q8y00));
+            vector signed short qv10 = vec_add(vec_mule(q6x10, q8y10), vec_mulo(q6x10, q8y10));
+            vector signed short qv20 = vec_add(vec_mule(q6x20, q8y20), vec_mulo(q6x20, q8y20));
+            vector signed short qv30 = vec_add(vec_mule(q6x30, q8y30), vec_mulo(q6x30, q8y30));
+            vector signed short qv01 = vec_add(vec_mule(q6x01, q8y01), vec_mulo(q6x01, q8y01));
+            vector signed short qv11 = vec_add(vec_mule(q6x11, q8y11), vec_mulo(q6x11, q8y11));
+            vector signed short qv21 = vec_add(vec_mule(q6x21, q8y21), vec_mulo(q6x21, q8y21));
+            vector signed short qv31 = vec_add(vec_mule(q6x31, q8y31), vec_mulo(q6x31, q8y31));
+
+            vector signed short vscales = vec_unpackh(vec_xl_len(qs, 8));
+            qs += 8;
+
+            vector signed short vs0 = vec_splat(vscales, 0);
+            vector signed short vs1 = vec_splat(vscales, 1);
+            vector signed short vs2 = vec_splat(vscales, 2);
+            vector signed short vs3 = vec_splat(vscales, 3);
+            vector signed short vs4 = vec_splat(vscales, 4);
+            vector signed short vs5 = vec_splat(vscales, 5);
+            vector signed short vs6 = vec_splat(vscales, 6);
+            vector signed short vs7 = vec_splat(vscales, 7);
+
+            vsumi0 = vec_msum(qv00, vs0, vsumi0);
+            vsumi1 = vec_msum(qv01, vs4, vsumi1);
+            vsumi2 = vec_msum(qv10, vs1, vsumi2);
+            vsumi3 = vec_msum(qv11, vs5, vsumi3);
+            vsumi4 = vec_msum(qv20, vs2, vsumi4);
+            vsumi5 = vec_msum(qv21, vs6, vsumi5);
+            vsumi6 = vec_msum(qv30, vs3, vsumi6);
+            vsumi7 = vec_msum(qv31, vs7, vsumi7);
+        }
+
+        vsumi0 = vec_add(vsumi0, vsumi4);
+        vsumi1 = vec_add(vsumi1, vsumi5);
+        vsumi2 = vec_add(vsumi2, vsumi6);
+        vsumi3 = vec_add(vsumi3, vsumi7);
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+    }
+
+    vsumf0 = vec_add(vsumf0, vsumf2);
+    vsumf1 = vec_add(vsumf1, vsumf3);
+
+    vsumf0 = vec_add(vsumf0, vsumf1);
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    *s = vec_extract(vsumf0, 0);
+
+#elif defined __loongarch_asx
+
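+    // LoongArch LASX path; mirrors the AVX2 implementation above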
+    const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
+    const __m256i m2 = __lasx_xvreplgr2vr_b(3);
+    const __m256i m32s = __lasx_xvreplgr2vr_b(32);
+
+    __m256 acc = (__m256)__lasx_xvldi(0);
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict q4 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const __m128i scales = __lsx_vld((const __m128i*)x[i].scales, 0);
+
+        __m256i sumi = __lasx_xvldi(0);
+
+        int is = 0;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+
+            const __m128i scale_0 = lsx_shuffle_b(scales, get_scale_shuffle(is + 0));
+            const __m128i scale_1 = lsx_shuffle_b(scales, get_scale_shuffle(is + 1));
+            const __m128i scale_2 = lsx_shuffle_b(scales, get_scale_shuffle(is + 2));
+            const __m128i scale_3 = lsx_shuffle_b(scales, get_scale_shuffle(is + 3));
+            is += 4;
+
+            const __m256i q4bits1 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
+            const __m256i q4bits2 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
+            const __m256i q4bitsH = __lasx_xvld((const __m256i*)qh, 0); qh += 32;
+
+            const __m256i q4h_0 = __lasx_xvslli_h(__lasx_xvand_v(q4bitsH, m2), 4);
+            const __m256i q4h_1 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 2), m2), 4);
+            const __m256i q4h_2 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 4), m2), 4);
+            const __m256i q4h_3 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 6), m2), 4);
+
+            const __m256i q4_0 = __lasx_xvor_v(__lasx_xvand_v(q4bits1, m4), q4h_0);
+            const __m256i q4_1 = __lasx_xvor_v(__lasx_xvand_v(q4bits2, m4), q4h_1);
+            const __m256i q4_2 = __lasx_xvor_v(__lasx_xvand_v(__lasx_xvsrli_h(q4bits1, 4), m4), q4h_2);
+            const __m256i q4_3 = __lasx_xvor_v(__lasx_xvand_v(__lasx_xvsrli_h(q4bits2, 4), m4), q4h_3);
+
+            const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+            const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+            const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+            const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+
+            __m256i q8s_0 = lasx_maddubs_h(m32s, q8_0);
+            __m256i q8s_1 = lasx_maddubs_h(m32s, q8_1);
+            __m256i q8s_2 = lasx_maddubs_h(m32s, q8_2);
+            __m256i q8s_3 = lasx_maddubs_h(m32s, q8_3);
+
+            __m256i p16_0 = lasx_maddubs_h(q4_0, q8_0);
+            __m256i p16_1 = lasx_maddubs_h(q4_1, q8_1);
+            __m256i p16_2 = lasx_maddubs_h(q4_2, q8_2);
+            __m256i p16_3 = lasx_maddubs_h(q4_3, q8_3);
+
+            p16_0 = __lasx_xvsub_h(p16_0, q8s_0);
+            p16_1 = __lasx_xvsub_h(p16_1, q8s_1);
+            p16_2 = __lasx_xvsub_h(p16_2, q8s_2);
+            p16_3 = __lasx_xvsub_h(p16_3, q8s_3);
+
+            p16_0 = lasx_madd_h(lasx_ext8_16(scale_0), p16_0);
+            p16_1 = lasx_madd_h(lasx_ext8_16(scale_1), p16_1);
+            p16_2 = lasx_madd_h(lasx_ext8_16(scale_2), p16_2);
+            p16_3 = lasx_madd_h(lasx_ext8_16(scale_3), p16_3);
+
+            sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1));
+            sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_2, p16_3));
+        }
+
+        acc = __lasx_xvfmadd_s((__m256)__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);
+    }
+
+    *s = hsum_float_8(acc);
+
+#else
+
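+    // scalar reference path: unpack each 128-value chunk into aux8 with the
+    // -32 offset applied, then accumulate eight parallel partial sums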
+    int8_t  aux8[QK_K];
+    int16_t aux16[8];
+    float   sums [8];
+    int32_t aux32[8];
+    memset(sums, 0, 8*sizeof(float));
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * restrict q4 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const  int8_t * restrict q8 = y[i].qs;
+        memset(aux32, 0, 8*sizeof(int32_t));
+        int8_t * restrict a = aux8;
+        for (int j = 0; j < QK_K; j += 128) {
+            for (int l = 0; l < 32; ++l) {
+                a[l +  0] = (int8_t)((q4[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
+                a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
+                a[l + 64] = (int8_t)((q4[l +  0] >>  4) | (((qh[l] >> 4) & 3) << 4)) - 32;
+                a[l + 96] = (int8_t)((q4[l + 32] >>  4) | (((qh[l] >> 6) & 3) << 4)) - 32;
+            }
+            a  += 128;
+            q4 += 64;
+            qh += 32;
+        }
+        a = aux8;
+        int is = 0;
+        for (int j = 0; j < QK_K/16; ++j) {
+            int scale = x[i].scales[is++];
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+        }
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
+    }
+    for (int l = 0; l < 8; ++l) sumf += sums[l];
+    *s = sumf;
+#endif
+}
+
+#if defined (__AVX__) || defined (__AVX2__) || defined (__ARM_NEON) || defined (__POWER9_VECTOR__) || defined(__loongarch_asx)
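+// Sign patterns shared by the iq2 dot products: 128 rows of 8 signs
+// (+1/-1) each. Only 7 sign bits are stored per group of 8 weights; the
+// 8th sign is chosen so that the number of -1s is even, e.g. row 1 (bit 0
+// set) is { -1, 1, 1, 1, 1, 1, 1, -1 }. A 7-bit index selects one 8-byte row.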
+static const int8_t keven_signs_q2xs[1024] = {
+     1,  1,  1,  1,  1,  1,  1,  1, -1,  1,  1,  1,  1,  1,  1, -1,  1, -1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1,  1,  1,
+     1,  1, -1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1,  1,  1,  1,  1, -1, -1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1, -1,
+     1,  1,  1, -1,  1,  1,  1, -1, -1,  1,  1, -1,  1,  1,  1,  1,  1, -1,  1, -1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1,  1, -1,
+     1,  1, -1, -1,  1,  1,  1,  1, -1,  1, -1, -1,  1,  1,  1, -1,  1, -1, -1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1,  1,
+     1,  1,  1,  1, -1,  1,  1, -1, -1,  1,  1,  1, -1,  1,  1,  1,  1, -1,  1,  1, -1,  1,  1,  1, -1, -1,  1,  1, -1,  1,  1, -1,
+     1,  1, -1,  1, -1,  1,  1,  1, -1,  1, -1,  1, -1,  1,  1, -1,  1, -1, -1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1,  1,
+     1,  1,  1, -1, -1,  1,  1,  1, -1,  1,  1, -1, -1,  1,  1, -1,  1, -1,  1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1,  1,
+     1,  1, -1, -1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1,  1,  1, -1, -1, -1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1, -1,
+     1,  1,  1,  1,  1, -1,  1, -1, -1,  1,  1,  1,  1, -1,  1,  1,  1, -1,  1,  1,  1, -1,  1,  1, -1, -1,  1,  1,  1, -1,  1, -1,
+     1,  1, -1,  1,  1, -1,  1,  1, -1,  1, -1,  1,  1, -1,  1, -1,  1, -1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1,  1,  1,
+     1,  1,  1, -1,  1, -1,  1,  1, -1,  1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1,  1,  1,
+     1,  1, -1, -1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1,  1,  1,  1, -1, -1, -1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1, -1,
+     1,  1,  1,  1, -1, -1,  1,  1, -1,  1,  1,  1, -1, -1,  1, -1,  1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1, -1, -1,  1,  1,
+     1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1,  1, -1, -1,  1,  1,  1, -1, -1,  1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1, -1,
+     1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1, -1, -1, -1,  1,  1,  1, -1,  1, -1, -1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1, -1,
+     1,  1, -1, -1, -1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1, -1,  1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1,  1,
+     1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1,  1, -1,  1,  1, -1,  1,  1,  1,  1, -1,  1, -1, -1,  1,  1,  1,  1, -1, -1,
+     1,  1, -1,  1,  1,  1, -1,  1, -1,  1, -1,  1,  1,  1, -1, -1,  1, -1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1, -1,  1,
+     1,  1,  1, -1,  1,  1, -1,  1, -1,  1,  1, -1,  1,  1, -1, -1,  1, -1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1, -1,  1,
+     1,  1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1, -1,  1,  1, -1, -1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1, -1,
+     1,  1,  1,  1, -1,  1, -1,  1, -1,  1,  1,  1, -1,  1, -1, -1,  1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1,  1, -1,  1,
+     1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1,  1, -1,  1,  1, -1, -1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1, -1,
+     1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1, -1, -1,  1, -1,  1,  1, -1,  1, -1, -1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1, -1,
+     1,  1, -1, -1, -1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1, -1,  1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1,  1,
+     1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1,  1,  1, -1, -1, -1,  1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1, -1, -1,  1,
+     1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1, -1, -1,  1,  1, -1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1, -1, -1, -1,
+     1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1,  1, -1, -1,  1,  1, -1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1,  1, -1, -1, -1,
+     1,  1, -1, -1,  1, -1, -1,  1, -1,  1, -1, -1,  1, -1, -1, -1,  1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1,  1,
+     1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1, -1, -1, -1,  1,  1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1, -1, -1, -1, -1,
+     1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1,  1, -1, -1, -1, -1,  1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1,  1,
+     1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1, -1, -1, -1, -1, -1,  1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1,  1,
+     1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1,  1,  1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1,
+};
+#endif
+
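+// Dot product of an iq2_xxs-quantized row with a q8_K-quantized vector.
+// Each group of 32 values is described by two 32-bit words from x[i].qs:
+// the first packs four 8-bit indices into the 256-entry iq2xxs_grid
+// codebook (8 weights per entry), the second packs 4 x 7 sign bits (see
+// keven_signs_q2xs above) plus a 4-bit scale in its top nibble, applied as
+// 2*ls+1 with a trailing 0.125f normalization.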
+void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_iq2_xxs * restrict x = vx;
+    const block_q8_K    * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#if defined(__ARM_NEON)
+
+    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+    uint32_t aux32[4];
+    const uint8_t * aux8 = (const uint8_t *)aux32;
+
+    ggml_int8x16x4_t q2u;
+    ggml_int8x16x4_t q2s;
+    ggml_int8x16x4_t q8b;
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint16_t * restrict q2 = x[i].qs;
+        const int8_t   * restrict q8 = y[i].qs;
+        float sumf1 = 0, sumf2 = 0;
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
+            memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
+            q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 0])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 1])));
+            q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 2])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 3])));
+            q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 8])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 9])));
+            q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[10])), vld1_s8((const void *)(iq2xxs_grid + aux8[11])));
+            q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >>  0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >>  7) & 127))));
+            q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127))));
+            q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >>  0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >>  7) & 127))));
+            q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 21) & 127))));
+            q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]);
+            q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]);
+            q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]);
+            q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]);
+            const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]), q2u.val[1], q8b.val[1]);
+            const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]), q2u.val[3], q8b.val[3]);
+            sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[1] >> 28));
+            sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[3] >> 28));
+        }
+        sumf += d*(sumf1 + sumf2);
+    }
+    *s = 0.25f * sumf;
+
+#elif defined(__AVX2__)
+
+    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+    uint32_t aux32[4];
+    const uint8_t * aux8 = (const uint8_t *)aux32;
+
+    __m256 accumf = _mm256_setzero_ps();
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint16_t * restrict q2 = x[i].qs;
+        const int8_t   * restrict q8 = y[i].qs;
+        __m256i sumi1 = _mm256_setzero_si256();
+        __m256i sumi2 = _mm256_setzero_si256();
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
+            const __m256i q2_1 = _mm256_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
+            const __m256i q2_2 = _mm256_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]);
+            const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
+                                                   signs64[(aux32[1] >>  7) & 127], signs64[(aux32[1] >>  0) & 127]);
+            const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127],
+                                                   signs64[(aux32[3] >>  7) & 127], signs64[(aux32[3] >>  0) & 127]);
+            const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
+            const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
+            const __m256i dot1  = _mm256_maddubs_epi16(q2_1, q8s_1);
+            const __m256i dot2  = _mm256_maddubs_epi16(q2_2, q8s_2);
+            const uint16_t ls1 = aux32[1] >> 28;
+            const uint16_t ls2 = aux32[3] >> 28;
+            const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));
+            const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));
+            sumi1 = _mm256_add_epi32(sumi1, p1);
+            sumi2 = _mm256_add_epi32(sumi2, p2);
+        }
+
+        accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
+
+    }
+
+    *s = 0.125f * hsum_float_8(accumf);
+
+#elif defined(__AVX__)
+    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+    uint32_t aux32[4];
+    const uint8_t * aux8 = (const uint8_t *)aux32;
+
+    __m256 accumf = _mm256_setzero_ps();
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint16_t * restrict q2 = x[i].qs;
+        const int8_t   * restrict q8 = y[i].qs;
+        __m128i sumi1_0 = _mm_setzero_si128();
+        __m128i sumi1_1 = _mm_setzero_si128();
+        __m128i sumi2_0 = _mm_setzero_si128();
+        __m128i sumi2_1 = _mm_setzero_si128();
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
+            const __m128i q2_1_0 = _mm_set_epi64x(iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
+            const __m128i q2_1_1 = _mm_set_epi64x(iq2xxs_grid[aux8[3]], iq2xxs_grid[aux8[2]]);
+            const __m128i q2_2_0 = _mm_set_epi64x(iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]);
+            const __m128i q2_2_1 = _mm_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]]);
+            const __m128i s2_1_0 = _mm_set_epi64x(signs64[(aux32[1] >>  7) & 127], signs64[(aux32[1] >>  0) & 127]);
+            const __m128i s2_1_1 = _mm_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127]);
+            const __m128i s2_2_0 = _mm_set_epi64x(signs64[(aux32[3] >>  7) & 127], signs64[(aux32[3] >>  0) & 127]);
+            const __m128i s2_2_1 = _mm_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127]);
+            const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, s2_1_0);
+            const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, s2_1_1);
+            const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, s2_2_0);
+            const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, s2_2_1);
+            const __m128i dot1_0  = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
+            const __m128i dot1_1  = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
+            const __m128i dot2_0  = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
+            const __m128i dot2_1  = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
+            const uint16_t ls1 = aux32[1] >> 28;
+            const uint16_t ls2 = aux32[3] >> 28;
+            const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1));
+            const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1));
+            const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1));
+            const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1));
+            sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
+            sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
+            sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
+            sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
+        }
+
+        accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
+
+    }
+
+    *s = 0.125f * hsum_float_8(accumf);
+
+#elif defined(__POWER9_VECTOR__)
+    const vector int v0 = vec_splats((int32_t)0);
+    vector float vsumf0 = vec_splats(0.0f);
+    vector float vsumf1 = vec_splats(0.0f);
+    vector float vsumf2 = vec_splats(0.0f);
+    vector float vsumf3 = vec_splats(0.0f);
+
+    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+    for (int i = 0; i < nb; ++i) {
+        vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
+        vector float vyd = vec_splats(y[i].d);
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector signed int vsumi0 = v0;
+        vector signed int vsumi1 = v0;
+        vector signed int vsumi2 = v0;
+        vector signed int vsumi3 = v0;
+
+        const uint16_t * restrict q2 = x[i].qs;
+        const int8_t  *  restrict q8 = y[i].qs;
+
+        for (int j = 0; j < QK_K/32; j += 2) {
+            __builtin_prefetch(q2, 0, 1);
+            __builtin_prefetch(q8, 0, 1);
+
+            uint32_t aux32[4];
+            const uint8_t * aux8 = (const uint8_t *)aux32;
+
+            memcpy(aux32, q2, 4*sizeof(uint32_t));
+            q2 += 8;
+
+            vector signed long long aux64x2_0 = {*(const int64_t *)(iq2xxs_grid + aux8[ 0]), *(const int64_t *)(iq2xxs_grid + aux8[ 1])};
+            vector signed long long aux64x2_1 = {*(const int64_t *)(iq2xxs_grid + aux8[ 2]), *(const int64_t *)(iq2xxs_grid + aux8[ 3])};
+            vector signed long long aux64x2_2 = {*(const int64_t *)(iq2xxs_grid + aux8[ 8]), *(const int64_t *)(iq2xxs_grid + aux8[ 9])};
+            vector signed long long aux64x2_3 = {*(const int64_t *)(iq2xxs_grid + aux8[10]), *(const int64_t *)(iq2xxs_grid + aux8[11])};
+
+            vector signed long long vsigns0 = {*(const int64_t *)(signs64 + ((aux32[1] >>  0) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >>  7) & 127))};
+            vector signed long long vsigns1 = {*(const int64_t *)(signs64 + ((aux32[1] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >> 21) & 127))};
+            vector signed long long vsigns2 = {*(const int64_t *)(signs64 + ((aux32[3] >>  0) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >>  7) & 127))};
+            vector signed long long vsigns3 = {*(const int64_t *)(signs64 + ((aux32[3] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >> 21) & 127))};
+
+            vector signed char q2x0 = (vector signed char)vec_mul((vector signed char)vsigns0, (vector signed char)aux64x2_0);
+            vector signed char q2x1 = (vector signed char)vec_mul((vector signed char)vsigns1, (vector signed char)aux64x2_1);
+            vector signed char q2x2 = (vector signed char)vec_mul((vector signed char)vsigns2, (vector signed char)aux64x2_2);
+            vector signed char q2x3 = (vector signed char)vec_mul((vector signed char)vsigns3, (vector signed char)aux64x2_3);
+
+            vector signed char q8y0 = vec_xl( 0, q8);
+            vector signed char q8y1 = vec_xl(16, q8);
+            vector signed char q8y2 = vec_xl(32, q8);
+            vector signed char q8y3 = vec_xl(48, q8);
+            q8 += 64;
+
+            vector signed short qv0 = vec_add(vec_mule(q2x0, q8y0), vec_mulo(q2x0, q8y0));
+            vector signed short qv1 = vec_add(vec_mule(q2x1, q8y1), vec_mulo(q2x1, q8y1));
+            vector signed short qv2 = vec_add(vec_mule(q2x2, q8y2), vec_mulo(q2x2, q8y2));
+            vector signed short qv3 = vec_add(vec_mule(q2x3, q8y3), vec_mulo(q2x3, q8y3));
+
+            const uint16_t ls0 = aux32[1] >> 28;
+            const uint16_t ls1 = aux32[3] >> 28;
+
+            vector signed short vscales01 = vec_splats((int16_t)(2*ls0+1));
+            vector signed short vscales23 = vec_splats((int16_t)(2*ls1+1));
+
+            vsumi0 = vec_msum(qv0, vscales01, vsumi0);
+            vsumi1 = vec_msum(qv1, vscales01, vsumi1);
+            vsumi2 = vec_msum(qv2, vscales23, vsumi2);
+            vsumi3 = vec_msum(qv3, vscales23, vsumi3);
+        }
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+    }
+
+    vsumf0 = vec_add(vsumf0, vsumf2);
+    vsumf1 = vec_add(vsumf1, vsumf3);
+
+    vsumf0 = vec_add(vsumf0, vsumf1);
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    *s = 0.125f * vec_extract(vsumf0, 0);
+
+#elif defined(__loongarch_asx)
+
+    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+    uint32_t aux32[4];
+    const uint8_t * aux8 = (const uint8_t *)aux32;
+
+    __m256 accumf = (__m256)__lasx_xvldi(0);
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint16_t * restrict q2 = x[i].qs;
+        const int8_t   * restrict q8 = y[i].qs;
+        __m256i sumi1 = __lasx_xvldi(0);
+        __m256i sumi2 = __lasx_xvldi(0);
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+            const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+            memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
+
+            const __m256i q2_1 = lasx_set_d(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
+            const __m256i q2_2 = lasx_set_d(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]);
+            const __m256i s2_1 = lasx_set_d(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
+                                            signs64[(aux32[1] >>  7) & 127], signs64[(aux32[1] >>  0) & 127]);
+            const __m256i s2_2 = lasx_set_d(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127],
+                                            signs64[(aux32[3] >>  7) & 127], signs64[(aux32[3] >>  0) & 127]);
+            const __m256i q8s_1 = __lasx_xvsigncov_b(s2_1, q8_1);
+            const __m256i q8s_2 = __lasx_xvsigncov_b(s2_2, q8_2);
+            const __m256i dot1  = lasx_maddubs_h(q2_1, q8s_1);
+            const __m256i dot2  = lasx_maddubs_h(q2_2, q8s_2);
+            const uint16_t ls1 = aux32[1] >> 28;
+            const uint16_t ls2 = aux32[3] >> 28;
+            const __m256i p1 = lasx_madd_h(dot1, __lasx_xvreplgr2vr_h(2*ls1+1));
+            const __m256i p2 = lasx_madd_h(dot2, __lasx_xvreplgr2vr_h(2*ls2+1));
+            sumi1 = __lasx_xvadd_w(sumi1, p1);
+            sumi2 = __lasx_xvadd_w(sumi2, p2);
+        }
+
+        accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf);
+    }
+
+    *s = 0.125f * hsum_float_8(accumf);
+
+#else
+
+    uint32_t aux32[2];
+    const uint8_t * aux8 = (const uint8_t *)aux32;
+
+    float sumf = 0.f;
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint16_t * restrict q2 = x[i].qs;
+        const int8_t   * restrict q8 = y[i].qs;
+        int32_t bsum = 0;
+        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
+            memcpy(aux32, q2, 2*sizeof(uint32_t));
+            q2 += 4;
+            const uint32_t ls = 2*(aux32[1] >> 28) + 1;
+            int32_t sumi = 0;
+            for (int l = 0; l < 4; ++l) {
+                const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
+                const uint8_t  signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127];
+                for (int j = 0; j < 8; ++j) {
+                    sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
+                }
+                q8 += 8;
+            }
+            bsum += sumi * ls;
+        }
+        sumf += d * bsum;
+    }
+    *s = 0.125f * sumf;
+#endif
+}
+
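+// Dot product of an iq2_xs-quantized row with a q8_K-quantized vector.
+// Unlike iq2_xxs, each group of 8 weights gets its own 16-bit word: the
+// low 9 bits index the 512-entry iq2xs_grid codebook and the high 7 bits
+// select a sign pattern from keven_signs_q2xs. Scales are 4-bit nibbles in
+// x[i].scales, one per group of 16 values, applied as 2*ls+1 with the same
+// 0.125f normalization.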
+void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_iq2_xs * restrict x = vx;
+    const block_q8_K   * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#if defined(__ARM_NEON)
+
+    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+    ggml_int8x16x4_t q2u;
+    ggml_int8x16x4_t q2s;
+    ggml_int8x16x4_t q8b;
+
+    int32x4x4_t scales32;
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint16_t * restrict q2 = x[i].qs;
+        const int8_t   * restrict q8 = y[i].qs;
+        const uint8x8_t scales8 = vld1_u8(x[i].scales);
+        const uint8x8_t scales_l = vand_u8(scales8, vdup_n_u8(0xf));
+        const uint8x8_t scales_h = vshr_n_u8(scales8, 4);
+        uint8x16_t scales = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h));
+        scales = vaddq_u8(vshlq_n_u8(scales, 1), vdupq_n_u8(1));
+        const uint16x8_t scales1 = vmovl_u8(vget_low_u8(scales));
+        const uint16x8_t scales2 = vmovl_u8(vget_high_u8(scales));
+        scales32.val[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales1)));
+        scales32.val[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales1)));
+        scales32.val[2] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales2)));
+        scales32.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales2)));
+        int32x4_t sumi = vdupq_n_s32(0);
+        for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
+            q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
+            q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[0] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[1] & 511))));
+            q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[2] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[3] & 511))));
+            q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[4] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[5] & 511))));
+            q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[6] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[7] & 511))));
+            q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[0] >> 9))), vld1_s8((const void *)(signs64 + (q2[1] >> 9))));
+            q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[2] >> 9))), vld1_s8((const void *)(signs64 + (q2[3] >> 9))));
+            q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[4] >> 9))), vld1_s8((const void *)(signs64 + (q2[5] >> 9))));
+            q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[6] >> 9))), vld1_s8((const void *)(signs64 + (q2[7] >> 9))));
+            q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]);
+            q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]);
+            q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]);
+            q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]);
+            const int32x4_t p1 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]);
+            const int32x4_t p2 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[1], q8b.val[1]);
+            const int32x4_t p3 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]);
+            const int32x4_t p4 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[3], q8b.val[3]);
+            const int32x4_t p = vpaddq_s32(vpaddq_s32(p1, p2), vpaddq_s32(p3, p4));
+            sumi = vmlaq_s32(sumi, p, scales32.val[ib64]);
+            q2 += 8;
+        }
+        sumf += d*vaddvq_s32(sumi);
+    }
+    *s = 0.125f * sumf;
+
+#elif defined(__AVX2__)
+
+    const __m256i mone = _mm256_set1_epi8(1);
+    static const char block_sign_shuffle_mask_1[32] = {
+        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
+        0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,
+    };
+    static const char block_sign_shuffle_mask_2[32] = {
+        0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a,
+        0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e,
+    };
+    static const uint8_t bit_selector_mask_bytes[32] = {
+        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+    };
+
+    const __m256i bit_selector_mask = _mm256_loadu_si256((const __m256i*)bit_selector_mask_bytes);
+    const __m256i block_sign_shuffle_1 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_1);
+    const __m256i block_sign_shuffle_2 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_2);
+
+    static const uint8_t k_bit_helper[32] = {
+        0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
+        0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
+    };
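+    // k_bit_helper[n] is 0x80 when popcount(n & 0xF) is odd; combined with
+    // the xor of the two partial_sign_bits shifts below, it reconstructs the
+    // implicit even-parity 8th sign bit of each group of 8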
+    const __m256i bit_helper = _mm256_loadu_si256((const __m256i*)k_bit_helper);
+    const __m256i m511 = _mm256_set1_epi16(511);
+    const __m128i m4 = _mm_set1_epi8(0xf);
+    const __m128i m1 = _mm_set1_epi8(1);
+
+    uint64_t aux64;
+
+    // somewhat hacky, but gives a significant boost in performance
+    __m256i aux_gindex;
+    const uint16_t * gindex = (const uint16_t *)&aux_gindex;
+
+    __m256 accumf = _mm256_setzero_ps();
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint16_t * restrict q2 = x[i].qs;
+        const int8_t   * restrict q8 = y[i].qs;
+
+        memcpy(&aux64, x[i].scales, 8);
+        __m128i stmp = _mm_set1_epi64x(aux64);
+        stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4));
+        const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1);
+
+        __m256i sumi1 = _mm256_setzero_si256();
+        __m256i sumi2 = _mm256_setzero_si256();
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) {
+
+            const __m256i q2_data = _mm256_loadu_si256((const __m256i*)q2);  q2 += 16;
+            aux_gindex = _mm256_and_si256(q2_data, m511);
+
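+            // bits 9..15 of each 16-bit word are the 7 explicit sign bits; XOR-folding them and
+            // shuffling through bit_helper sets 0x80 exactly when their parity is odd, which
+            // recovers the implicit 8th sign bit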
+            const __m256i partial_sign_bits = _mm256_srli_epi16(q2_data, 9);
+            const __m256i partial_sign_bits_upper = _mm256_srli_epi16(q2_data, 13);
+            const __m256i partial_sign_bits_for_counting = _mm256_xor_si256(partial_sign_bits, partial_sign_bits_upper);
+
+            const __m256i odd_bits = _mm256_shuffle_epi8(bit_helper, partial_sign_bits_for_counting);
+            const __m256i full_sign_bits = _mm256_or_si256(partial_sign_bits, odd_bits);
+
+            const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            const __m256i q8_3 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            const __m256i q8_4 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+
+            const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[ 3]], iq2xs_grid[gindex[ 2]],
+                                                   iq2xs_grid[gindex[ 1]], iq2xs_grid[gindex[ 0]]);
+            const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[ 7]], iq2xs_grid[gindex[ 6]],
+                                                   iq2xs_grid[gindex[ 5]], iq2xs_grid[gindex[ 4]]);
+            const __m256i q2_3 = _mm256_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]],
+                                                   iq2xs_grid[gindex[ 9]], iq2xs_grid[gindex[ 8]]);
+            const __m256i q2_4 = _mm256_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]],
+                                                   iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]);
+
+            const __m128i full_signs_l = _mm256_castsi256_si128(full_sign_bits);
+            const __m128i full_signs_h = _mm256_extractf128_si256(full_sign_bits, 1);
+            const __m256i full_signs_1 = MM256_SET_M128I(full_signs_l, full_signs_l);
+            const __m256i full_signs_2 = MM256_SET_M128I(full_signs_h, full_signs_h);
+
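+            // replicate each sign word across the 8 bytes of its grid entry, test the
+            // per-weight bit, and turn the resulting 0x00/0xFF mask into +1/-1
+            // (0xFF | 1 == -1) for _mm256_sign_epi8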
+            __m256i signs;
+            signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_1);
+            signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
+            const __m256i q8s_1 = _mm256_sign_epi8(q8_1, _mm256_or_si256(signs, mone));
+
+            signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_2);
+            signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
+            const __m256i q8s_2 = _mm256_sign_epi8(q8_2, _mm256_or_si256(signs, mone));
+
+            signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_1);
+            signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
+            const __m256i q8s_3 = _mm256_sign_epi8(q8_3, _mm256_or_si256(signs, mone));
+
+            signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_2);
+            signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
+            const __m256i q8s_4 = _mm256_sign_epi8(q8_4, _mm256_or_si256(signs, mone));
+
+            const __m256i dot1  = _mm256_maddubs_epi16(q2_1, q8s_1);
+            const __m256i dot2  = _mm256_maddubs_epi16(q2_2, q8s_2);
+            const __m256i dot3  = _mm256_maddubs_epi16(q2_3, q8s_3);
+            const __m256i dot4  = _mm256_maddubs_epi16(q2_4, q8s_4);
+
+            const __m256i sc1 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0)));
+            const __m256i sc2 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1)));
+            const __m256i sc3 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2)));
+            const __m256i sc4 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3)));
+
+            sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot1, sc1));
+            sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot2, sc2));
+            sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot3, sc3));
+            sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot4, sc4));
+        }
+
+        accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
+
+    }
+
+    *s = 0.125f * hsum_float_8(accumf);
+
+#elif defined(__AVX__)
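+    // same algorithm as the AVX2 path above, with each 256-bit operation split into two
+    // 128-bit halves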
+    const __m128i mone = _mm_set1_epi8(1);
+    static const char block_sign_shuffle_mask_1[32] = {
+        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
+        0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,
+    };
+    static const char block_sign_shuffle_mask_2[32] = {
+        0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a,
+        0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e,
+    };
+    static const uint8_t bit_selector_mask_bytes[32] = {
+        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+    };
+
+    const __m128i bit_selector_mask_0 = _mm_loadu_si128((const __m128i*)bit_selector_mask_bytes);
+    const __m128i bit_selector_mask_1 = _mm_loadu_si128((const __m128i*)bit_selector_mask_bytes + 1);
+    const __m128i block_sign_shuffle_1_0 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_1);
+    const __m128i block_sign_shuffle_1_1 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_1 + 1);
+    const __m128i block_sign_shuffle_2_0 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_2);
+    const __m128i block_sign_shuffle_2_1 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_2 + 1);
+
+    static const uint8_t k_bit_helper[32] = {
+        0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
+        0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
+    };
+    const __m128i bit_helper_0 = _mm_loadu_si128((const __m128i*)k_bit_helper);
+    const __m128i bit_helper_1 = _mm_loadu_si128((const __m128i*)k_bit_helper + 1);
+    const __m128i m511 = _mm_set1_epi16(511);
+    const __m128i m4 = _mm_set1_epi8(0xf);
+    const __m128i m1 = _mm_set1_epi8(1);
+
+    uint64_t aux64;
+
+    // somewhat hacky: the masked grid indices are written to a vector register (aux_gindex)
+    // and read back through a scalar uint16_t pointer, but this gives a significant boost in performance
+    __m256i aux_gindex;
+    const uint16_t * gindex = (const uint16_t *)&aux_gindex;
+
+    __m256 accumf = _mm256_setzero_ps();
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint16_t * restrict q2 = x[i].qs;
+        const int8_t   * restrict q8 = y[i].qs;
+
+        memcpy(&aux64, x[i].scales, 8);
+        __m128i stmp = _mm_set1_epi64x(aux64);
+        stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4));
+        const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1);
+
+        __m128i sumi1_0 = _mm_setzero_si128();
+        __m128i sumi1_1 = _mm_setzero_si128();
+        __m128i sumi2_0 = _mm_setzero_si128();
+        __m128i sumi2_1 = _mm_setzero_si128();
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) {
+
+            const __m128i q2_data_0 = _mm_loadu_si128((const __m128i*)q2);
+            const __m128i q2_data_1 = _mm_loadu_si128((const __m128i*)q2 + 1);  q2 += 16;
+            aux_gindex = MM256_SET_M128I(_mm_and_si128(q2_data_1, m511), _mm_and_si128(q2_data_0, m511));
+
+            const __m128i partial_sign_bits_0 = _mm_srli_epi16(q2_data_0, 9);
+            const __m128i partial_sign_bits_1 = _mm_srli_epi16(q2_data_1, 9);
+            const __m128i partial_sign_bits_upper_0 = _mm_srli_epi16(q2_data_0, 13);
+            const __m128i partial_sign_bits_upper_1 = _mm_srli_epi16(q2_data_1, 13);
+            const __m128i partial_sign_bits_for_counting_0 = _mm_xor_si128(partial_sign_bits_0, partial_sign_bits_upper_0);
+            const __m128i partial_sign_bits_for_counting_1 = _mm_xor_si128(partial_sign_bits_1, partial_sign_bits_upper_1);
+
+            const __m128i odd_bits_0 = _mm_shuffle_epi8(bit_helper_0, partial_sign_bits_for_counting_0);
+            const __m128i odd_bits_1 = _mm_shuffle_epi8(bit_helper_1, partial_sign_bits_for_counting_1);
+            const __m128i full_sign_bits_0 = _mm_or_si128(partial_sign_bits_0, odd_bits_0);
+            const __m128i full_sign_bits_1 = _mm_or_si128(partial_sign_bits_1, odd_bits_1);
+
+            const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_3_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_3_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_4_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_4_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+
+            const __m128i q2_1_0 = _mm_set_epi64x(iq2xs_grid[gindex[1]], iq2xs_grid[gindex[0]]);
+            const __m128i q2_1_1 = _mm_set_epi64x(iq2xs_grid[gindex[3]], iq2xs_grid[gindex[2]]);
+            const __m128i q2_2_0 = _mm_set_epi64x(iq2xs_grid[gindex[5]], iq2xs_grid[gindex[4]]);
+            const __m128i q2_2_1 = _mm_set_epi64x(iq2xs_grid[gindex[7]], iq2xs_grid[gindex[6]]);
+            const __m128i q2_3_0 = _mm_set_epi64x(iq2xs_grid[gindex[9]], iq2xs_grid[gindex[8]]);
+            const __m128i q2_3_1 = _mm_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]]);
+            const __m128i q2_4_0 = _mm_set_epi64x(iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]);
+            const __m128i q2_4_1 = _mm_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]]);
+
+            // AVX2 full_signs_1 is full_sign_bits_0 here
+            // AVX2 full_signs_2 is full_sign_bits_1 here
+            __m128i signs_0, signs_1;
+            signs_0 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_0);
+            signs_1 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_1);
+            signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
+            signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
+            const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, _mm_or_si128(signs_0, mone));
+            const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, _mm_or_si128(signs_1, mone));
+
+            signs_0 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_2_0);
+            signs_1 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_2_1);
+            signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
+            signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
+            const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, _mm_or_si128(signs_0, mone));
+            const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, _mm_or_si128(signs_1, mone));
+
+            signs_0 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_1_0);
+            signs_1 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_1_1);
+            signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
+            signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
+            const __m128i q8s_3_0 = _mm_sign_epi8(q8_3_0, _mm_or_si128(signs_0, mone));
+            const __m128i q8s_3_1 = _mm_sign_epi8(q8_3_1, _mm_or_si128(signs_1, mone));
+
+            signs_0 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_2_0);
+            signs_1 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_2_1);
+            signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
+            signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
+            const __m128i q8s_4_0 = _mm_sign_epi8(q8_4_0, _mm_or_si128(signs_0, mone));
+            const __m128i q8s_4_1 = _mm_sign_epi8(q8_4_1, _mm_or_si128(signs_1, mone));
+
+            const __m128i dot1_0  = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
+            const __m128i dot1_1  = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
+            const __m128i dot2_0  = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
+            const __m128i dot2_1  = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
+            const __m128i dot3_0  = _mm_maddubs_epi16(q2_3_0, q8s_3_0);
+            const __m128i dot3_1  = _mm_maddubs_epi16(q2_3_1, q8s_3_1);
+            const __m128i dot4_0  = _mm_maddubs_epi16(q2_4_0, q8s_4_0);
+            const __m128i dot4_1  = _mm_maddubs_epi16(q2_4_1, q8s_4_1);
+
+            __m128i sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0));
+            const __m128i sc1_0 = _mm_cvtepi8_epi16(sc_tmp);
+            const __m128i sc1_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
+            sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1));
+            const __m128i sc2_0 = _mm_cvtepi8_epi16(sc_tmp);
+            const __m128i sc2_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
+            sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2));
+            const __m128i sc3_0 = _mm_cvtepi8_epi16(sc_tmp);
+            const __m128i sc3_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
+            sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3));
+            const __m128i sc4_0 = _mm_cvtepi8_epi16(sc_tmp);
+            const __m128i sc4_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
+
+            sumi1_0 = _mm_add_epi32(sumi1_0, _mm_madd_epi16(dot1_0, sc1_0));
+            sumi1_1 = _mm_add_epi32(sumi1_1, _mm_madd_epi16(dot1_1, sc1_1));
+            sumi2_0 = _mm_add_epi32(sumi2_0, _mm_madd_epi16(dot2_0, sc2_0));
+            sumi2_1 = _mm_add_epi32(sumi2_1, _mm_madd_epi16(dot2_1, sc2_1));
+            sumi1_0 = _mm_add_epi32(sumi1_0, _mm_madd_epi16(dot3_0, sc3_0));
+            sumi1_1 = _mm_add_epi32(sumi1_1, _mm_madd_epi16(dot3_1, sc3_1));
+            sumi2_0 = _mm_add_epi32(sumi2_0, _mm_madd_epi16(dot4_0, sc4_0));
+            sumi2_1 = _mm_add_epi32(sumi2_1, _mm_madd_epi16(dot4_1, sc4_1));
+        }
+
+        accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
+
+    }
+
+    *s = 0.125f * hsum_float_8(accumf);
+
+#elif defined(__loongarch_asx)
+
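+    // LASX translation of the AVX2 path; the lasx_*/lsx_* helpers defined earlier in this
+    // file emulate the corresponding x86 intrinsics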
+    const __m256i mone = __lasx_xvreplgr2vr_b(1);
+    static const char block_sign_shuffle_mask_1[32] = {
+        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
+        0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,
+    };
+    static const char block_sign_shuffle_mask_2[32] = {
+        0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a,
+        0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e,
+    };
+    static const uint8_t bit_selector_mask_bytes[32] = {
+        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+    };
+
+    const __m256i bit_selector_mask = __lasx_xvld((const __m256i*)bit_selector_mask_bytes, 0);
+    const __m256i block_sign_shuffle_1 = __lasx_xvld((const __m256i*)block_sign_shuffle_mask_1, 0);
+    const __m256i block_sign_shuffle_2 = __lasx_xvld((const __m256i*)block_sign_shuffle_mask_2, 0);
+
+    static const uint8_t k_bit_helper[32] = {
+        0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
+        0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
+    };
+    const __m256i bit_helper = __lasx_xvld((const __m256i*)k_bit_helper, 0);
+    const __m256i m511 = __lasx_xvreplgr2vr_h(511);
+    const __m128i m4 = __lsx_vreplgr2vr_b(0xf);
+    const __m128i m1 = __lsx_vreplgr2vr_b(1);
+
+    uint64_t aux64;
+
+    // somewhat hacky: the masked grid indices are written to a vector register (aux_gindex)
+    // and read back through a scalar uint16_t pointer, but this gives a significant boost in performance
+    __m256i aux_gindex;
+    const uint16_t * gindex = (const uint16_t *)&aux_gindex;
+
+    __m256 accumf = (__m256)__lasx_xvldi(0);
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint16_t * restrict q2 = x[i].qs;
+        const int8_t   * restrict q8 = y[i].qs;
+
+        memcpy(&aux64, x[i].scales, 8);
+        __m128i stmp = __lsx_vreplgr2vr_d(aux64);
+        stmp = __lsx_vilvl_b( __lsx_vand_v(__lsx_vsrli_h(stmp, 4), m4), __lsx_vand_v(stmp, m4));
+        const __m128i scales = __lsx_vadd_b(__lsx_vslli_h(stmp, 1), m1);
+
+        __m256i sumi1 = __lasx_xvldi(0);
+        __m256i sumi2 = __lasx_xvldi(0);
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) {
+
+            const __m256i q2_data = __lasx_xvld((const __m256i*)q2, 0);  q2 += 16;
+            aux_gindex = __lasx_xvand_v(q2_data, m511);
+
+            const __m256i partial_sign_bits = __lasx_xvsrli_h(q2_data, 9);
+            const __m256i partial_sign_bits_upper = __lasx_xvsrli_h(q2_data, 13);
+            const __m256i partial_sign_bits_for_counting = __lasx_xvxor_v(partial_sign_bits, partial_sign_bits_upper);
+
+            const __m256i odd_bits = lasx_shuffle_b(bit_helper, partial_sign_bits_for_counting);
+            const __m256i full_sign_bits = __lasx_xvor_v(partial_sign_bits, odd_bits);
+
+            const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+            const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+            const __m256i q8_3 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+            const __m256i q8_4 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+
+            const __m256i q2_1 = lasx_set_d(iq2xs_grid[gindex[ 3]], iq2xs_grid[gindex[ 2]],
+                                            iq2xs_grid[gindex[ 1]], iq2xs_grid[gindex[ 0]]);
+            const __m256i q2_2 = lasx_set_d(iq2xs_grid[gindex[ 7]], iq2xs_grid[gindex[ 6]],
+                                            iq2xs_grid[gindex[ 5]], iq2xs_grid[gindex[ 4]]);
+            const __m256i q2_3 = lasx_set_d(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]],
+                                            iq2xs_grid[gindex[ 9]], iq2xs_grid[gindex[ 8]]);
+            const __m256i q2_4 = lasx_set_d(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]],
+                                            iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]);
+
+            const __m128i full_signs_l = lasx_extracti128(full_sign_bits, 0);
+            const __m128i full_signs_h = lasx_extracti128(full_sign_bits, 1);
+            const __m256i full_signs_1 = lasx_insertf128(full_signs_l, full_signs_l);
+            const __m256i full_signs_2 = lasx_insertf128(full_signs_h, full_signs_h);
+
+            __m256i signs;
+            signs = lasx_shuffle_b(full_signs_1, block_sign_shuffle_1);
+            signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask);
+            const __m256i q8s_1 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_1);
+
+            signs = lasx_shuffle_b(full_signs_1, block_sign_shuffle_2);
+            signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask);
+            const __m256i q8s_2 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_2);
+
+            signs = lasx_shuffle_b(full_signs_2, block_sign_shuffle_1);
+            signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask);
+            const __m256i q8s_3 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_3);
+
+            signs = lasx_shuffle_b(full_signs_2, block_sign_shuffle_2);
+            signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask);
+            const __m256i q8s_4 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_4);
+
+            const __m256i dot1  = lasx_maddubs_h(q2_1, q8s_1);
+            const __m256i dot2  = lasx_maddubs_h(q2_2, q8s_2);
+            const __m256i dot3  = lasx_maddubs_h(q2_3, q8s_3);
+            const __m256i dot4  = lasx_maddubs_h(q2_4, q8s_4);
+
+            const __m256i sc1 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+0)));
+            const __m256i sc2 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+1)));
+            const __m256i sc3 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+2)));
+            const __m256i sc4 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+3)));
+
+            sumi1 = __lasx_xvadd_w(sumi1, lasx_madd_h(dot1, sc1));
+            sumi2 = __lasx_xvadd_w(sumi2, lasx_madd_h(dot2, sc2));
+            sumi1 = __lasx_xvadd_w(sumi1, lasx_madd_h(dot3, sc3));
+            sumi2 = __lasx_xvadd_w(sumi2, lasx_madd_h(dot4, sc4));
+        }
+
+        accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf);
+
+    }
+
+    *s = 0.125f * hsum_float_8(accumf);
+#elif defined(__POWER9_VECTOR__)
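+    // VSX path: grid rows and precomputed sign patterns are gathered per 16-bit word,
+    // multiplied into int16 partial products, and accumulated with vec_msum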
+    const vector int v0 = vec_splats((int32_t)0);
+    vector float vsumf0 = vec_splats(0.0f);
+    vector float vsumf1 = vec_splats(0.0f);
+    vector float vsumf2 = vec_splats(0.0f);
+    vector float vsumf3 = vec_splats(0.0f);
+
+    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+    for (int i = 0; i < nb; ++i) {
+        vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
+        vector float vyd = vec_splats(y[i].d);
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector signed int vsumi0 = v0;
+        vector signed int vsumi1 = v0;
+        vector signed int vsumi2 = v0;
+        vector signed int vsumi3 = v0;
+
+        const uint16_t * restrict q2 = x[i].qs;
+        const uint8_t  * restrict sc = x[i].scales;
+        const int8_t  *  restrict q8 = y[i].qs;
+
+        for (int j = 0; j < QK_K/64; ++j) {
+            __builtin_prefetch(q2, 0, 1);
+            __builtin_prefetch(q8, 0, 1);
+
+            vector signed long long aux64x2_0 = {*(const int64_t *)(iq2xs_grid + (q2[0] & 511)), *(const int64_t *)(iq2xs_grid + (q2[1] & 511))};
+            vector signed long long aux64x2_1 = {*(const int64_t *)(iq2xs_grid + (q2[2] & 511)), *(const int64_t *)(iq2xs_grid + (q2[3] & 511))};
+            vector signed long long aux64x2_2 = {*(const int64_t *)(iq2xs_grid + (q2[4] & 511)), *(const int64_t *)(iq2xs_grid + (q2[5] & 511))};
+            vector signed long long aux64x2_3 = {*(const int64_t *)(iq2xs_grid + (q2[6] & 511)), *(const int64_t *)(iq2xs_grid + (q2[7] & 511))};
+
+            vector signed long long vsigns0 = {*(const int64_t *)(signs64 + ((q2[0] >> 9))), *(const int64_t *)(signs64 + ((q2[1] >> 9)))};
+            vector signed long long vsigns1 = {*(const int64_t *)(signs64 + ((q2[2] >> 9))), *(const int64_t *)(signs64 + ((q2[3] >> 9)))};
+            vector signed long long vsigns2 = {*(const int64_t *)(signs64 + ((q2[4] >> 9))), *(const int64_t *)(signs64 + ((q2[5] >> 9)))};
+            vector signed long long vsigns3 = {*(const int64_t *)(signs64 + ((q2[6] >> 9))), *(const int64_t *)(signs64 + ((q2[7] >> 9)))};
+            q2 += 8;
+
+            vector signed char q2x0 = (vector signed char)vec_mul((vector signed char)vsigns0, (vector signed char)aux64x2_0);
+            vector signed char q2x1 = (vector signed char)vec_mul((vector signed char)vsigns1, (vector signed char)aux64x2_1);
+            vector signed char q2x2 = (vector signed char)vec_mul((vector signed char)vsigns2, (vector signed char)aux64x2_2);
+            vector signed char q2x3 = (vector signed char)vec_mul((vector signed char)vsigns3, (vector signed char)aux64x2_3);
+
+            vector signed char q8y0 = vec_xl( 0, q8);
+            vector signed char q8y1 = vec_xl(16, q8);
+            vector signed char q8y2 = vec_xl(32, q8);
+            vector signed char q8y3 = vec_xl(48, q8);
+            q8 += 64;
+
+            vector signed short qv0 = vec_add(vec_mule(q2x0, q8y0), vec_mulo(q2x0, q8y0));
+            vector signed short qv1 = vec_add(vec_mule(q2x1, q8y1), vec_mulo(q2x1, q8y1));
+            vector signed short qv2 = vec_add(vec_mule(q2x2, q8y2), vec_mulo(q2x2, q8y2));
+            vector signed short qv3 = vec_add(vec_mule(q2x3, q8y3), vec_mulo(q2x3, q8y3));
+
+            const uint16_t ls0 = (uint16_t)(sc[0] & 0xf);
+            const uint16_t ls1 = (uint16_t)(sc[0] >>  4);
+            const uint16_t ls2 = (uint16_t)(sc[1] & 0xf);
+            const uint16_t ls3 = (uint16_t)(sc[1] >>  4);
+            sc += 2;
+
+            vector signed short vscales0 = vec_splats((int16_t)(2*ls0+1));
+            vector signed short vscales1 = vec_splats((int16_t)(2*ls1+1));
+            vector signed short vscales2 = vec_splats((int16_t)(2*ls2+1));
+            vector signed short vscales3 = vec_splats((int16_t)(2*ls3+1));
+
+            vsumi0 = vec_msum(qv0, vscales0, vsumi0);
+            vsumi1 = vec_msum(qv1, vscales1, vsumi1);
+            vsumi2 = vec_msum(qv2, vscales2, vsumi2);
+            vsumi3 = vec_msum(qv3, vscales3, vsumi3);
+        }
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+    }
+
+    vsumf0 = vec_add(vsumf0, vsumf2);
+    vsumf1 = vec_add(vsumf1, vsumf3);
+
+    vsumf0 = vec_add(vsumf0, vsumf1);
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    *s = 0.125f * vec_extract(vsumf0, 0);
+#else
+
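+    // scalar reference path: decode each group of 8 weights from iq2xs_grid, flip signs
+    // according to ksigns_iq2xs, and weight the two half-block sums by 2*ls+1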
+    float sumf = 0.f;
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint16_t * restrict q2 = x[i].qs;
+        const uint8_t  * restrict sc = x[i].scales;
+        const int8_t   * restrict q8 = y[i].qs;
+        int32_t bsum = 0;
+        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
+            const uint16_t ls1 = 2*(sc[ib32] & 0xf) + 1;
+            const uint16_t ls2 = 2*(sc[ib32] >>  4) + 1;
+            int32_t sumi = 0;
+            for (int l = 0; l < 2; ++l) {
+                const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
+                const uint8_t  signs = ksigns_iq2xs[q2[l] >> 9];
+                for (int j = 0; j < 8; ++j) {
+                    sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
+                }
+                q8 += 8;
+            }
+            bsum += sumi * ls1;
+            sumi = 0;
+            for (int l = 2; l < 4; ++l) {
+                const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
+                const uint8_t  signs = ksigns_iq2xs[q2[l] >> 9];
+                for (int j = 0; j < 8; ++j) {
+                    sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
+                }
+                q8 += 8;
+            }
+            bsum += sumi * ls2;
+            q2 += 4;
+        }
+        sumf += d * bsum;
+    }
+    *s = 0.125f * sumf;
+#endif
+}
+
+void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_iq2_s * restrict x = vx;
+    const block_q8_K  * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#if defined(__ARM_NEON)
+
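+    // IQ2_S layout: 8-bit grid indices in qs, two extra high bits per index packed in qh,
+    // and explicit per-weight sign bits stored after the indices at x[i].qs + QK_K/8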
+    static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
+                                        0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
+    };
+
+    static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,};
+
+    const ggml_uint8x16x2_t mask1 = ggml_vld1q_u8_x2(k_mask1);
+    const uint8x16_t        mask2 = vld1q_u8(k_mask2);
+    const uint8x16_t m1 = vdupq_n_u8(1);
+    const int32x4_t vzero = vdupq_n_s32(0);
+
+    uint8x16x2_t vs;
+    ggml_int8x16x4_t q2s;
+    ggml_int8x16x4_t q8b;
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+
+        const uint8_t * restrict qs = x[i].qs;
+        const uint8_t * restrict qh = x[i].qh;
+        const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8);
+        const int8_t  * restrict q8 = y[i].qs;
+
+        int sumi1 = 0, sumi2 = 0;
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
+            q2s.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[0] | ((qh[ib32+0] << 8) & 0x300)))),
+                                     vld1_s8((const int8_t *)(iq2s_grid + (qs[1] | ((qh[ib32+0] << 6) & 0x300)))));
+            q2s.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[2] | ((qh[ib32+0] << 4) & 0x300)))),
+                                     vld1_s8((const int8_t *)(iq2s_grid + (qs[3] | ((qh[ib32+0] << 2) & 0x300)))));
+            q2s.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[4] | ((qh[ib32+1] << 8) & 0x300)))),
+                                     vld1_s8((const int8_t *)(iq2s_grid + (qs[5] | ((qh[ib32+1] << 6) & 0x300)))));
+            q2s.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[6] | ((qh[ib32+1] << 4) & 0x300)))),
+                                     vld1_s8((const int8_t *)(iq2s_grid + (qs[7] | ((qh[ib32+1] << 2) & 0x300)))));
+            qs += 8;
+
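+            // expand the packed sign bits into per-byte masks with table lookups, then map the
+            // 0x00/0xFF masks to +1/-1 (vorrq with 1) before multiplying into the grid values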
+            vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | ((uint32_t) signs[1] << 16)));
+            vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
+            vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
+            vs.val[0] = vceqq_u8(vs.val[0], mask2);
+            vs.val[1] = vceqq_u8(vs.val[1], mask2);
+
+            q2s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[0], m1)), q2s.val[0]);
+            q2s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[1]);
+
+            vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | ((uint32_t) signs[3] << 16)));
+            vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
+            vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
+            vs.val[0] = vceqq_u8(vs.val[0], mask2);
+            vs.val[1] = vceqq_u8(vs.val[1], mask2);
+
+            signs += 4;
+
+            q2s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[0], m1)), q2s.val[2]);
+            q2s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[3]);
+
+            const int32x4_t p1 = ggml_vdotq_s32(vzero, q2s.val[0], q8b.val[0]);
+            const int32x4_t p2 = ggml_vdotq_s32(vzero, q2s.val[1], q8b.val[1]);
+            const int32x4_t p3 = ggml_vdotq_s32(vzero, q2s.val[2], q8b.val[2]);
+            const int32x4_t p4 = ggml_vdotq_s32(vzero, q2s.val[3], q8b.val[3]);
+
+            sumi1 += vaddvq_s32(p1) * (1 + 2*(x[i].scales[ib32+0] & 0xf));
+            sumi2 += vaddvq_s32(p2) * (1 + 2*(x[i].scales[ib32+0] >>  4));
+            sumi1 += vaddvq_s32(p3) * (1 + 2*(x[i].scales[ib32+1] & 0xf));
+            sumi2 += vaddvq_s32(p4) * (1 + 2*(x[i].scales[ib32+1] >>  4));
+        }
+        sumf += d*(sumi1 + sumi2);
+    }
+
+    *s = 0.125f * sumf;
+
+#elif defined(__AVX2__)
+
+    static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
+                                        0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
+    };
+
+    static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+                                        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+    };
+
+    const __m128i m4 = _mm_set1_epi8(0xf);
+    const __m128i m1 = _mm_set1_epi8(1);
+
+    const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1);
+    const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2);
+
+    uint64_t aux64;
+
+    __m256 accumf = _mm256_setzero_ps();
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint8_t * restrict qs = x[i].qs;
+        const uint8_t * restrict qh = x[i].qh;
+        const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8);
+        const int8_t  * restrict q8 = y[i].qs;
+
+        memcpy(&aux64, x[i].scales, 8);
+        const __m128i scales8 = _mm_add_epi8(_mm_slli_epi16(_mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), m4), 1), m1);
+        const __m256i scales16 = _mm256_cvtepi8_epi16(scales8); // 0 2 4 6 8 10 12 14 1 3 5 7 9 11 13 15
+
+        __m256i sumi1 = _mm256_setzero_si256();
+        __m256i sumi2 = _mm256_setzero_si256();
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            const __m256i q2_1 = _mm256_set_epi64x(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)],
+                                                   iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)],
+                                                   iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)],
+                                                   iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]);
+            const __m256i q2_2 = _mm256_set_epi64x(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)],
+                                                   iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)],
+                                                   iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)],
+                                                   iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]);
+            qs += 8;
+
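+            // broadcast the 32 sign bits of two blocks, select each weight's bit with the
+            // mask1/mask2 shuffles, and negate q8 where set via the xor-and-subtract trick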
+            __m256i aux256 = _mm256_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16));
+            aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
+            const __m256i s2_1 = _mm256_cmpeq_epi8(aux256, mask2);
+            const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(s2_1, q8_1), s2_1);
+
+            aux256 = _mm256_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16));
+            aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
+            const __m256i s2_2 = _mm256_cmpeq_epi8(aux256, mask2);
+            const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(s2_2, q8_2), s2_2);
+
+            signs += 4;
+
+            const __m256i dot1  = _mm256_maddubs_epi16(q2_1, q8s_1); // blocks 2*ib32+0, 2*ib32+1
+            const __m256i dot2  = _mm256_maddubs_epi16(q2_2, q8s_2); // blocks 2*ib32+2, 2*ib32+3
+
+            const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_shuffle_epi8(scales16, get_scale_shuffle_k4(ib32+0)));
+            const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_shuffle_epi8(scales16, get_scale_shuffle_k4(ib32+1)));
+            sumi1 = _mm256_add_epi32(sumi1, p1);
+            sumi2 = _mm256_add_epi32(sumi2, p2);
+        }
+
+        accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
+
+    }
+
+    *s = 0.125f * hsum_float_8(accumf);
+
+#elif defined(__AVX__)
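+    // same algorithm as the AVX2 path above, split into two 128-bit halves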
+    static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
+                                        0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
+    };
+
+    static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+                                        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+    };
+
+    const __m128i m4 = _mm_set1_epi8(0xf);
+    const __m128i m1 = _mm_set1_epi8(1);
+
+    const __m128i mask1_0 = _mm_loadu_si128((const __m128i*)k_mask1);
+    const __m128i mask1_1 = _mm_loadu_si128((const __m128i*)k_mask1 + 1);
+    const __m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2);
+    const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1);
+
+    uint64_t aux64;
+
+    __m256 accumf = _mm256_setzero_ps();
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint8_t * restrict qs = x[i].qs;
+        const uint8_t * restrict qh = x[i].qh;
+        const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8);
+        const int8_t  * restrict q8 = y[i].qs;
+
+        memcpy(&aux64, x[i].scales, 8);
+        const __m128i scales8 = _mm_add_epi8(_mm_slli_epi16(_mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), m4), 1), m1);
+        const __m128i scales16_0 = _mm_cvtepi8_epi16(scales8);
+        const __m128i scales16_1 = _mm_cvtepi8_epi16(_mm_srli_si128(scales8, 8));
+
+        __m128i sumi1_0 = _mm_setzero_si128();
+        __m128i sumi1_1 = _mm_setzero_si128();
+        __m128i sumi2_0 = _mm_setzero_si128();
+        __m128i sumi2_1 = _mm_setzero_si128();
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q2_1_0 = _mm_set_epi64x(iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)],
+                                                  iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]);
+            const __m128i q2_1_1 = _mm_set_epi64x(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)],
+                                                  iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)]);
+            const __m128i q2_2_0 = _mm_set_epi64x(iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)],
+                                                  iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]);
+            const __m128i q2_2_1 = _mm_set_epi64x(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)],
+                                                  iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)]);
+            qs += 8;
+
+            __m128i aux128_0 = _mm_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16));
+            __m128i aux128_1 = aux128_0;
+            aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
+            aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
+            const __m128i s2_1_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
+            const __m128i s2_1_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
+            const __m128i q8s_1_0 = _mm_sub_epi8(_mm_xor_si128(s2_1_0, q8_1_0), s2_1_0);
+            const __m128i q8s_1_1 = _mm_sub_epi8(_mm_xor_si128(s2_1_1, q8_1_1), s2_1_1);
+
+            aux128_0 = _mm_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16));
+            aux128_1 = aux128_0;
+            aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
+            aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
+            const __m128i s2_2_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
+            const __m128i s2_2_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
+            const __m128i q8s_2_0 = _mm_sub_epi8(_mm_xor_si128(s2_2_0, q8_2_0), s2_2_0);
+            const __m128i q8s_2_1 = _mm_sub_epi8(_mm_xor_si128(s2_2_1, q8_2_1), s2_2_1);
+
+            signs += 4;
+
+            const __m128i dot1_0  = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
+            const __m128i dot1_1  = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
+            const __m128i dot2_0  = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
+            const __m128i dot2_1  = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
+
+            const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_shuffle_epi8(scales16_0, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+0), 0)));
+            const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_shuffle_epi8(scales16_1, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+0), 1)));
+            const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_shuffle_epi8(scales16_0, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+1), 0)));
+            const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_shuffle_epi8(scales16_1, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+1), 1)));
+            sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
+            sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
+            sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
+            sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
+        }
+
+        accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
+
+    }
+
+    *s = 0.125f * hsum_float_8(accumf);
+
+#elif defined(__POWER9_VECTOR__)
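+    // VSX path: vec_perm/vec_cmpeq expand the packed sign bits into byte masks, which are
+    // applied with the xor-and-subtract trick before vec_msum accumulates the block sums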
+    static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
+                                        0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
+    };
+
+    static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,};
+
+    const vector int v0 = vec_splats((int32_t)0);
+
+    vector float vsumf0 = vec_splats(0.0f);
+    vector float vsumf1 = vec_splats(0.0f);
+    vector float vsumf2 = vec_splats(0.0f);
+    vector float vsumf3 = vec_splats(0.0f);
+
+    const vector unsigned char mask0 = vec_xl( 0, k_mask1);
+    const vector unsigned char mask1 = vec_xl(16, k_mask1);
+    const vector signed char mask2 = (vector signed char)vec_xl( 0, k_mask2);
+
+    for (int i = 0; i < nb; ++i) {
+        vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
+        vector float vyd = vec_splats(y[i].d);
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector signed int vsumi0 = v0;
+        vector signed int vsumi1 = v0;
+        vector signed int vsumi2 = v0;
+        vector signed int vsumi3 = v0;
+
+        const uint8_t *  restrict q2 = x[i].qs;
+        const uint8_t *  restrict qh = x[i].qh;
+        const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8);
+        const uint8_t *  restrict sc = x[i].scales;
+        const int8_t  *  restrict q8 = y[i].qs;
+
+        for (int j = 0; j < QK_K/32; j += 2) {
+            __builtin_prefetch(q2, 0, 1);
+            __builtin_prefetch(q8, 0, 1);
+
+            vector signed long long aux64x2_0 = {*(const int64_t *)(iq2s_grid + (q2[0] | ((qh[0] << 8) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[1] | ((qh[0] << 6) & 0x300)))};
+            vector signed long long aux64x2_1 = {*(const int64_t *)(iq2s_grid + (q2[2] | ((qh[0] << 4) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[3] | ((qh[0] << 2) & 0x300)))};
+            vector signed long long aux64x2_2 = {*(const int64_t *)(iq2s_grid + (q2[4] | ((qh[1] << 8) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[5] | ((qh[1] << 6) & 0x300)))};
+            vector signed long long aux64x2_3 = {*(const int64_t *)(iq2s_grid + (q2[6] | ((qh[1] << 4) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[7] | ((qh[1] << 2) & 0x300)))};
+            q2 += 8;
+            qh += 2;
+
+            vector signed char vsigns01 = (vector signed char)vec_splats(*(const uint32_t *)&signs[0]);
+            vector signed char vsigns23 = (vector signed char)vec_splats(*(const uint32_t *)&signs[2]);
+            signs += 4;
+
+            vector signed char vsigns0 = vec_perm(vsigns01, vsigns01, mask0);
+            vector signed char vsigns1 = vec_perm(vsigns01, vsigns01, mask1);
+            vector signed char vsigns2 = vec_perm(vsigns23, vsigns23, mask0);
+            vector signed char vsigns3 = vec_perm(vsigns23, vsigns23, mask1);
+
+            vsigns0 = (vector signed char)vec_cmpeq(vec_and(vsigns0, mask2), mask2);
+            vsigns1 = (vector signed char)vec_cmpeq(vec_and(vsigns1, mask2), mask2);
+            vsigns2 = (vector signed char)vec_cmpeq(vec_and(vsigns2, mask2), mask2);
+            vsigns3 = (vector signed char)vec_cmpeq(vec_and(vsigns3, mask2), mask2);
+
+            vector signed char q2x0 = vec_sub(vec_xor(vsigns0, (vector signed char)aux64x2_0), vsigns0);
+            vector signed char q2x1 = vec_sub(vec_xor(vsigns1, (vector signed char)aux64x2_1), vsigns1);
+            vector signed char q2x2 = vec_sub(vec_xor(vsigns2, (vector signed char)aux64x2_2), vsigns2);
+            vector signed char q2x3 = vec_sub(vec_xor(vsigns3, (vector signed char)aux64x2_3), vsigns3);
+
+            vector signed char q8y0 = vec_xl( 0, q8);
+            vector signed char q8y1 = vec_xl(16, q8);
+            vector signed char q8y2 = vec_xl(32, q8);
+            vector signed char q8y3 = vec_xl(48, q8);
+            q8 += 64;
+
+            vector signed short qv0 = vec_add(vec_mule(q2x0, q8y0), vec_mulo(q2x0, q8y0));
+            vector signed short qv1 = vec_add(vec_mule(q2x1, q8y1), vec_mulo(q2x1, q8y1));
+            vector signed short qv2 = vec_add(vec_mule(q2x2, q8y2), vec_mulo(q2x2, q8y2));
+            vector signed short qv3 = vec_add(vec_mule(q2x3, q8y3), vec_mulo(q2x3, q8y3));
+
+            const uint16_t ls0 = (uint16_t)(sc[0] & 0xf);
+            const uint16_t ls1 = (uint16_t)(sc[0] >>  4);
+            const uint16_t ls2 = (uint16_t)(sc[1] & 0xf);
+            const uint16_t ls3 = (uint16_t)(sc[1] >>  4);
+            sc += 2;
+
+            vector signed short vscales0 = vec_splats((int16_t)(2*ls0+1));
+            vector signed short vscales1 = vec_splats((int16_t)(2*ls1+1));
+            vector signed short vscales2 = vec_splats((int16_t)(2*ls2+1));
+            vector signed short vscales3 = vec_splats((int16_t)(2*ls3+1));
+
+            vsumi0 = vec_msum(qv0, vscales0, vsumi0);
+            vsumi1 = vec_msum(qv1, vscales1, vsumi1);
+            vsumi2 = vec_msum(qv2, vscales2, vsumi2);
+            vsumi3 = vec_msum(qv3, vscales3, vsumi3);
+        }
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+    }
+
+    vsumf0 = vec_add(vsumf0, vsumf2);
+    vsumf1 = vec_add(vsumf1, vsumf3);
+
+    vsumf0 = vec_add(vsumf0, vsumf1);
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    *s = 0.125f * vec_extract(vsumf0, 0);
+
+#elif defined(__loongarch_asx)
+
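+    // LASX translation of the AVX2 path above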
+    static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
+                                        0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
+    };
+
+    static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+                                        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+    };
+
+    const __m128i m4 = __lsx_vreplgr2vr_b(0xf);
+    const __m128i m1 = __lsx_vreplgr2vr_b(1);
+
+    const __m256i mask1 = __lasx_xvld((const __m256i*)k_mask1, 0);
+    const __m256i mask2 = __lasx_xvld((const __m256i*)k_mask2, 0);
+    uint64_t aux64;
+
+    __m256 accumf = (__m256)__lasx_xvldi(0);
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint8_t * restrict qs = x[i].qs;
+        const uint8_t * restrict qh = x[i].qh;
+        const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8);
+        const int8_t  * restrict q8 = y[i].qs;
+
+        __m128i tmp1 = __lsx_vldi(0); // zero-initialize: each insert below writes only one 64-bit lane
+        memcpy(&aux64, x[i].scales, 8);
+        tmp1 = __lsx_vinsgr2vr_d(tmp1, aux64, 0);
+        tmp1 = __lsx_vinsgr2vr_d(tmp1, aux64 >> 4, 1);
+        const __m128i scales8 = __lsx_vadd_b(__lsx_vslli_h(__lsx_vand_v(tmp1, m4), 1), m1);
+        const __m256i scales16 = lasx_ext8_16(scales8); // 0 2 4 6 8 10 12 14 1 3 5 7 9 11 13 15
+
+        __m256i sumi1 = __lasx_xvldi(0);
+        __m256i sumi2 = __lasx_xvldi(0);
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+            const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+            const __m256i q2_1 = lasx_set_d(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)],
+                                            iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)],
+                                            iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)],
+                                            iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]);
+            const __m256i q2_2 = lasx_set_d(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)],
+                                            iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)],
+                                            iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)],
+                                            iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]);
+            qs += 8;
+
+            __m256i aux256 = __lasx_xvreplgr2vr_w(signs[0] | ((uint32_t) signs[1] << 16));
+            aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2);
+            const __m256i s2_1 = __lasx_xvseq_b(aux256, mask2);
+            const __m256i q8s_1 = __lasx_xvsub_b(__lasx_xvxor_v(s2_1, q8_1), s2_1);
+
+            aux256 = __lasx_xvreplgr2vr_w(signs[2] | ((uint32_t) signs[3] << 16));
+            aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2);
+            const __m256i s2_2 = __lasx_xvseq_b(aux256, mask2);
+            const __m256i q8s_2 = __lasx_xvsub_b(__lasx_xvxor_v(s2_2, q8_2), s2_2);
+
+            signs += 4;
+
+            const __m256i dot1  = lasx_maddubs_h(q2_1, q8s_1); // blocks 2*ib32+0, 2*ib32+1
+            const __m256i dot2  = lasx_maddubs_h(q2_2, q8s_2); // blocks 2*ib32+2, 2*ib32+3
+
+            const __m256i p1 = lasx_madd_h(dot1, lasx_shuffle_b(scales16, get_scale_shuffle_k4(ib32+0)));
+            const __m256i p2 = lasx_madd_h(dot2, lasx_shuffle_b(scales16, get_scale_shuffle_k4(ib32+1)));
+            sumi1 = __lasx_xvadd_w(sumi1, p1);
+            sumi2 = __lasx_xvadd_w(sumi2, p2);
+        }
+
+        accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf);
+    }
+
+    *s = 0.125f * hsum_float_8(accumf);
+
+#else
+
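+    // scalar reference path: rebuild each 10-bit grid index from qs and qh, apply the
+    // per-weight sign bits, and weight the half-block sums by 1 + 2*scale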
+    float sumf = 0;
+    for (int i = 0; i < nb; i++) {
+
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const int8_t  * q8 = y[i].qs;
+        const uint8_t * qs = x[i].qs;
+        const uint8_t * qh = x[i].qh;
+        const uint8_t * signs = qs + QK_K/8;
+
+        int bsum = 0;
+        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
+            int ls1 = 1 + 2*(x[i].scales[ib32] & 0xf);
+            int ls2 = 1 + 2*(x[i].scales[ib32] >>  4);
+            int sumi1 = 0, sumi2 = 0;
+            for (int l = 0; l < 2; ++l) {
+                const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | ((qh[ib32] << (8-2*l)) & 0x300)));
+                for (int j = 0; j < 8; ++j) {
+                    sumi1 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1);
+                }
+                q8 += 8;
+            }
+            for (int l = 2; l < 4; ++l) {
+                const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | ((qh[ib32] << (8-2*l)) & 0x300)));
+                for (int j = 0; j < 8; ++j) {
+                    sumi2 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1);
+                }
+                q8 += 8;
+            }
+            bsum += ls1 * sumi1 + ls2 * sumi2;
+            qs += 4;
+            signs += 4;
+        }
+
+        sumf += d * bsum;
+    }
+
+    *s = 0.125f * sumf;
+
+#endif
+
+}
+
+void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_iq3_xxs * restrict x = vx;
+    const block_q8_K    * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#if defined(__ARM_NEON)
+
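+    // IQ3_XXS layout: 8-bit grid indices in qs, then one 32-bit word per 32 weights at
+    // qs + QK_K/4 packing four 7-bit sign-table indices and a 4-bit scale in the top nibble.
+    // keven_signs_q2xs holds precomputed +1/-1 byte patterns for each 7-bit index (the 8th
+    // sign completes even parity), so signs are applied with a plain multiply.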
+    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+    uint32_t aux32[2];
+
+    ggml_int8x16x4_t q3s;
+    ggml_int8x16x4_t q8b;
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint8_t * restrict q3 = x[i].qs;
+        const uint8_t * restrict gas = x[i].qs + QK_K/4;
+        const int8_t   * restrict q8 = y[i].qs;
+        float sumf1 = 0, sumf2 = 0;
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
+            memcpy(aux32, gas, 2*sizeof(uint32_t)); gas += 2*sizeof(uint32_t);
+            const uint32x4_t aux32x4_0 = ggml_vld1q_u32(iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]);
+            const uint32x4_t aux32x4_1 = ggml_vld1q_u32(iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]);
+            const uint32x4_t aux32x4_2 = ggml_vld1q_u32(iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]);
+            const uint32x4_t aux32x4_3 = ggml_vld1q_u32(iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]);
+            q3 += 16;
+            q3s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >>  0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >>  7) & 127))));
+            q3s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 21) & 127))));
+            q3s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >>  0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >>  7) & 127))));
+            q3s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127))));
+            q3s.val[0] = vmulq_s8(q3s.val[0], vreinterpretq_s8_u32(aux32x4_0));
+            q3s.val[1] = vmulq_s8(q3s.val[1], vreinterpretq_s8_u32(aux32x4_1));
+            q3s.val[2] = vmulq_s8(q3s.val[2], vreinterpretq_s8_u32(aux32x4_2));
+            q3s.val[3] = vmulq_s8(q3s.val[3], vreinterpretq_s8_u32(aux32x4_3));
+            const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]);
+            const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]);
+            sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[0] >> 28));
+            sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[1] >> 28));
+        }
+        sumf += d*(sumf1 + sumf2);
+    }
+    *s = 0.5f * sumf;
+
+#elif defined(__AVX2__)
+
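+    // the 4-bit block scales enter here as integers 2*ls+1 and the final factor is 0.25f;
+    // the NEON path above applies (0.5f + ls) in float and scales by 0.5f -- same result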
+    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+    uint32_t aux32[2];
+
+    __m256 accumf = _mm256_setzero_ps();
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint8_t * restrict q3 = x[i].qs;
+        const uint8_t * restrict gas = x[i].qs + QK_K/4;
+        const int8_t  * restrict q8 = y[i].qs;
+        __m256i sumi1 = _mm256_setzero_si256();
+        __m256i sumi2 = _mm256_setzero_si256();
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            const __m256i q2_1 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],
+                                                  iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
+            q3 += 8;
+            const __m256i q2_2 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],
+                                                  iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
+            q3 += 8;
+            memcpy(aux32, gas, 8); gas += 8;
+            const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127],
+                                                   signs64[(aux32[0] >>  7) & 127], signs64[(aux32[0] >>  0) & 127]);
+            const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
+                                                   signs64[(aux32[1] >>  7) & 127], signs64[(aux32[1] >>  0) & 127]);
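+            // keven_signs_q2xs packs eight ±1 bytes per 7-bit index; applying the
+            // signs to q8 keeps q2 unsigned, as _mm256_maddubs_epi16 requires.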
+            const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
+            const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
+            const __m256i dot1  = _mm256_maddubs_epi16(q2_1, q8s_1);
+            const __m256i dot2  = _mm256_maddubs_epi16(q2_2, q8s_2);
+            const uint16_t ls1 = aux32[0] >> 28;
+            const uint16_t ls2 = aux32[1] >> 28;
+            const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));
+            const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));
+            sumi1 = _mm256_add_epi32(sumi1, p1);
+            sumi2 = _mm256_add_epi32(sumi2, p2);
+        }
+
+        accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
+
+    }
+
+    *s = 0.25f * hsum_float_8(accumf);
+
+#elif defined(__AVX__)
+    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+    uint32_t aux32[2];
+
+    __m256 accumf = _mm256_setzero_ps();
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint8_t * restrict q3 = x[i].qs;
+        const uint8_t * restrict gas = x[i].qs + QK_K/4;
+        const int8_t  * restrict q8 = y[i].qs;
+        __m128i sumi1_0 = _mm_setzero_si128();
+        __m128i sumi1_1 = _mm_setzero_si128();
+        __m128i sumi2_0 = _mm_setzero_si128();
+        __m128i sumi2_1 = _mm_setzero_si128();
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q2_1_0 = _mm_set_epi32(iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
+            const __m128i q2_1_1 = _mm_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]]);
+            q3 += 8;
+            const __m128i q2_2_0 = _mm_set_epi32(iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
+            const __m128i q2_2_1 = _mm_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]]);
+            q3 += 8;
+            memcpy(aux32, gas, 8); gas += 8;
+            const __m128i s2_1_0 = _mm_set_epi64x(signs64[(aux32[0] >>  7) & 127], signs64[(aux32[0] >>  0) & 127]);
+            const __m128i s2_1_1 = _mm_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127]);
+            const __m128i s2_2_0 = _mm_set_epi64x(signs64[(aux32[1] >>  7) & 127], signs64[(aux32[1] >>  0) & 127]);
+            const __m128i s2_2_1 = _mm_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127]);
+            const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, s2_1_0);
+            const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, s2_1_1);
+            const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, s2_2_0);
+            const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, s2_2_1);
+            const __m128i dot1_0  = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
+            const __m128i dot1_1  = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
+            const __m128i dot2_0  = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
+            const __m128i dot2_1  = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
+            const uint16_t ls1 = aux32[0] >> 28;
+            const uint16_t ls2 = aux32[1] >> 28;
+            const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1));
+            const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1));
+            const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1));
+            const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1));
+            sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
+            sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
+            sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
+            sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
+        }
+
+        accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
+
+    }
+
+    *s = 0.25f * hsum_float_8(accumf);
+
+#elif defined(__POWER9_VECTOR__)
+    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+    const vector int v0 = vec_splats((int32_t)0);
+
+    vector float vsumf0 = vec_splats(0.0f);
+    vector float vsumf1 = vec_splats(0.0f);
+    vector float vsumf2 = vec_splats(0.0f);
+    vector float vsumf3 = vec_splats(0.0f);
+
+    for (int i = 0; i < nb; ++i) {
+        vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
+        vector float vyd = vec_splats(y[i].d);
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector signed int vsumi0 = v0;
+        vector signed int vsumi1 = v0;
+        vector signed int vsumi2 = v0;
+        vector signed int vsumi3 = v0;
+
+        const uint8_t * restrict q3 = x[i].qs;
+        const uint32_t * restrict signs = (const uint32_t *)(x[i].qs + QK_K/4);
+        const int8_t  * restrict q8 = y[i].qs;
+
+#pragma GCC unroll 1
+        for (int j = 0; j < QK_K/32; j += 2) {
+            __builtin_prefetch(q3, 0, 1);
+            __builtin_prefetch(q8, 0, 1);
+
+            vector unsigned int aux32x4_0 = {iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]};
+            vector unsigned int aux32x4_1 = {iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]};
+            vector unsigned int aux32x4_2 = {iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]};
+            vector unsigned int aux32x4_3 = {iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]};
+            q3 += 16;
+
+            vector unsigned long long aux64x2_0 = {(uint64_t)(signs64[(signs[0] >>  0) & 127]), (uint64_t)(signs64[(signs[0] >>  7) & 127])};
+            vector unsigned long long aux64x2_1 = {(uint64_t)(signs64[(signs[0] >> 14) & 127]), (uint64_t)(signs64[(signs[0] >> 21) & 127])};
+            vector unsigned long long aux64x2_2 = {(uint64_t)(signs64[(signs[1] >>  0) & 127]), (uint64_t)(signs64[(signs[1] >>  7) & 127])};
+            vector unsigned long long aux64x2_3 = {(uint64_t)(signs64[(signs[1] >> 14) & 127]), (uint64_t)(signs64[(signs[1] >> 21) & 127])};
+
+            vector signed char q3x0 = vec_mul((vector signed char)aux64x2_0, (vector signed char)aux32x4_0);
+            vector signed char q3x1 = vec_mul((vector signed char)aux64x2_1, (vector signed char)aux32x4_1);
+            vector signed char q3x2 = vec_mul((vector signed char)aux64x2_2, (vector signed char)aux32x4_2);
+            vector signed char q3x3 = vec_mul((vector signed char)aux64x2_3, (vector signed char)aux32x4_3);
+
+            vector signed char q8y0 = vec_xl( 0, q8);
+            vector signed char q8y1 = vec_xl(16, q8);
+            vector signed char q8y2 = vec_xl(32, q8);
+            vector signed char q8y3 = vec_xl(48, q8);
+            q8 += 64;
+
+            vector signed short qv0 = vec_add(vec_mule(q3x0, q8y0), vec_mulo(q3x0, q8y0));
+            vector signed short qv1 = vec_add(vec_mule(q3x1, q8y1), vec_mulo(q3x1, q8y1));
+            vector signed short qv2 = vec_add(vec_mule(q3x2, q8y2), vec_mulo(q3x2, q8y2));
+            vector signed short qv3 = vec_add(vec_mule(q3x3, q8y3), vec_mulo(q3x3, q8y3));
+
+            const uint16_t ls0 = (uint16_t)(signs[0] >> 28);
+            const uint16_t ls1 = (uint16_t)(signs[1] >> 28);
+            signs += 2;
+
+            vector signed short vscales01 = (vector signed short)vec_splats((uint16_t)(2*ls0+1));
+            vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1));
+
+            vsumi0 = vec_msum(qv0, vscales01, vsumi0);
+            vsumi1 = vec_msum(qv1, vscales01, vsumi1);
+            vsumi2 = vec_msum(qv2, vscales23, vsumi2);
+            vsumi3 = vec_msum(qv3, vscales23, vsumi3);
+        }
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+    }
+
+    vsumf0 = vec_add(vsumf0, vsumf2);
+    vsumf1 = vec_add(vsumf1, vsumf3);
+
+    vsumf0 = vec_add(vsumf0, vsumf1);
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    *s = 0.25f * vec_extract(vsumf0, 0);
+
+#elif defined(__loongarch_asx)
+
+    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+    uint32_t aux32[2];
+
+    __m256 accumf = (__m256)__lasx_xvldi(0);
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint8_t * restrict q3 = x[i].qs;
+        const uint8_t * restrict gas = x[i].qs + QK_K/4;
+        const int8_t  * restrict q8 = y[i].qs;
+        __m256i sumi1 = __lasx_xvldi(0);
+        __m256i sumi2 = __lasx_xvldi(0);
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+            const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+            const __m256i q2_1 = lasx_set_w(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],
+                                                iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
+            q3 += 8;
+            const __m256i q2_2 = lasx_set_w(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],
+                                                iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
+            q3 += 8;
+            memcpy(aux32, gas, 8); gas += 8;
+
+            const __m256i s2_1 = lasx_set_d(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127],
+                                                   signs64[(aux32[0] >>  7) & 127], signs64[(aux32[0] >>  0) & 127]);
+            const __m256i s2_2 = lasx_set_d(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
+                                                   signs64[(aux32[1] >>  7) & 127], signs64[(aux32[1] >>  0) & 127]);
+            const __m256i q8s_1 = __lasx_xvsigncov_b(s2_1, q8_1);
+            const __m256i q8s_2 = __lasx_xvsigncov_b(s2_2, q8_2);
+            const __m256i dot1  = lasx_maddubs_h(q2_1, q8s_1);
+            const __m256i dot2  = lasx_maddubs_h(q2_2, q8s_2);
+            const uint16_t ls1 = aux32[0] >> 28;
+            const uint16_t ls2 = aux32[1] >> 28;
+
+            const __m256i p1 = lasx_madd_h(dot1, __lasx_xvreplgr2vr_h(2*ls1+1));
+            const __m256i p2 = lasx_madd_h(dot2, __lasx_xvreplgr2vr_h(2*ls2+1));
+            sumi1 = __lasx_xvadd_w(sumi1, p1);
+            sumi2 = __lasx_xvadd_w(sumi2, p2);
+        }
+
+        accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf);
+    }
+
+    *s = 0.25f * hsum_float_8(accumf);
+
+#else
+
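+    // Scalar reference path; mirrors the SIMD kernels one 8-value grid pair
+    // at a time.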
+    uint32_t aux32;
+
+    float sumf = 0.f;
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint8_t * restrict q3 = x[i].qs;
+        const uint8_t * restrict gas = x[i].qs + QK_K/4;
+        const int8_t  * restrict q8 = y[i].qs;
+        int32_t bsum = 0;
+        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
+            memcpy(&aux32, gas, sizeof(uint32_t)); gas += sizeof(uint32_t);
+            const uint32_t ls = 2*(aux32 >> 28) + 1;
+            int32_t sumi = 0;
+            for (int l = 0; l < 4; ++l) {
+                const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*l+0]);
+                const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*l+1]);
+                const uint8_t  signs = ksigns_iq2xs[(aux32 >> 7*l) & 127];
+                for (int j = 0; j < 4; ++j) {
+                    sumi += grid1[j] * q8[j+0] * (signs & kmask_iq2xs[j+0] ? -1 : 1);
+                    sumi += grid2[j] * q8[j+4] * (signs & kmask_iq2xs[j+4] ? -1 : 1);
+                }
+                q8 += 8;
+            }
+            q3 += 8;
+            bsum += sumi * ls;
+        }
+        sumf += d * bsum;
+    }
+    *s = 0.25f * sumf;
+#endif
+}
+
+void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_iq3_s * restrict x = vx;
+    const block_q8_K  * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#if defined(__ARM_NEON)
+
+    typedef union {
+        uint16x8_t vec_index;
+        uint16_t   index[8];
+    } vec_index_t;
+
+    static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
+                                        0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
+    };
+
+    static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,};
+
+    static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1};
+
+    const ggml_uint8x16x2_t mask1 = ggml_vld1q_u8_x2(k_mask1);
+    const uint8x16_t        mask2 = vld1q_u8(k_mask2);
+
+    const int16x8_t  hshift = vld1q_s16(k_shift);
+    const uint16x8_t m256   = vdupq_n_u16(256);
+    const uint8x16_t m1     = vdupq_n_u8(1);
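+    // hshift moves bit k of a qh byte into bit 8 of lane k; OR-ing with the
+    // low byte from qs yields the 9-bit indices used by iq3s_grid.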
+
+    uint8x16x2_t vs;
+    ggml_int8x16x4_t q3s;
+    ggml_int8x16x4_t q8b;
+    vec_index_t idx;
+
+    uint32_t scales32[2];
+    const uint8_t * scales8 = (const uint8_t *)scales32;
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint8_t * restrict qs = x[i].qs;
+        const uint8_t * restrict qh = x[i].qh;
+        const uint16_t * restrict signs = (const uint16_t *)x[i].signs;
+        const int8_t   * restrict q8 = y[i].qs;
+
+        memcpy(scales32, x[i].scales, 4);
+        scales32[1] = (((scales32[0] >> 4) & 0x0f0f0f0f) << 1) | 0x01010101;
+        scales32[0] = ((scales32[0] & 0x0f0f0f0f) << 1) | 0x01010101;
+
+        int sumi1 = 0, sumi2 = 0;
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
+
+            const uint8x16_t idx_l = vld1q_u8(qs); qs += 16;
+            idx.vec_index = vorrq_u16(vmovl_u8(vget_low_u8 (idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+0]), hshift), m256));
+            const uint32x4_t aux32x4_0 = ggml_vld1q_u32(iq3s_grid[idx.index[0]], iq3s_grid[idx.index[1]],
+                                                        iq3s_grid[idx.index[2]], iq3s_grid[idx.index[3]]);
+            const uint32x4_t aux32x4_1 = ggml_vld1q_u32(iq3s_grid[idx.index[4]], iq3s_grid[idx.index[5]],
+                                                        iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]);
+            idx.vec_index = vorrq_u16(vmovl_u8(vget_high_u8(idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+1]), hshift), m256));
+            const uint32x4_t aux32x4_2 = ggml_vld1q_u32(iq3s_grid[idx.index[0]], iq3s_grid[idx.index[1]],
+                                                        iq3s_grid[idx.index[2]], iq3s_grid[idx.index[3]]);
+            const uint32x4_t aux32x4_3 = ggml_vld1q_u32(iq3s_grid[idx.index[4]], iq3s_grid[idx.index[5]],
+                                                        iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]);
+
+
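+            // Expand the 32 sign bits into ±1 bytes: the table lookup with
+            // mask1 replicates each sign byte, mask2 isolates one bit per lane,
+            // and vceqq/vorrq turn set bits into -1 and clear bits into +1.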
+            vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | ((uint32_t) signs[1] << 16)));
+            vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
+            vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
+            vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1);
+            vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), m1);
+
+            q3s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_0));
+            q3s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_1));
+
+            vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | ((uint32_t) signs[3] << 16)));
+            vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
+            vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
+            vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1);
+            vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), m1);
+
+            signs += 4;
+
+            q3s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_2));
+            q3s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_3));
+
+            const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]);
+            const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]);
+
+            sumi1 += vaddvq_s32(p1) * scales8[ib32/2+0];
+            sumi2 += vaddvq_s32(p2) * scales8[ib32/2+4];
+        }
+        sumf += d*(sumi1 + sumi2);
+    }
+    *s = sumf;
+
+#elif defined(__AVX2__)
+
+    static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
+                                        0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
+    };
+
+    static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+                                        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+    };
+
+    const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1);
+    const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2);
+
+    const __m256i idx_shift = _mm256_set_epi32(1, 2, 3, 4, 5, 6, 7, 8);
+    const __m256i idx_mask  = _mm256_set1_epi32(256);
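+    // idx_shift/idx_mask lift bit k of each qh byte into bit 8 of 32-bit lane
+    // k, extending the 8-bit qs indices to the 9 bits used by iq3s_grid.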
+
+    typedef union {
+        __m256i  vec[2];
+        uint32_t index[16];
+    } index_t;
+
+    index_t idx;
+
+    __m256 accumf = _mm256_setzero_ps();
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint8_t * restrict qs = x[i].qs;
+        const uint8_t * restrict qh = x[i].qh;
+        const uint16_t * restrict signs = (const uint16_t *)x[i].signs;
+        const int8_t  * restrict q8 = y[i].qs;
+        __m256i sumi1 = _mm256_setzero_si256();
+        __m256i sumi2 = _mm256_setzero_si256();
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            const __m256i idx_l = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)qs)); qs += 16;
+            idx.vec[0] = _mm256_set1_epi32(qh[ib32+0]);
+            idx.vec[1] = _mm256_set1_epi32(qh[ib32+1]);
+            idx.vec[0] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[0], idx_shift), idx_mask);
+            idx.vec[1] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[1], idx_shift), idx_mask);
+            idx.vec[0] = _mm256_or_si256(idx.vec[0], _mm256_cvtepi16_epi32(_mm256_castsi256_si128(idx_l)));
+            idx.vec[1] = _mm256_or_si256(idx.vec[1], _mm256_cvtepi16_epi32(_mm256_extractf128_si256(idx_l, 1)));
+
+            // At least on my CPU (Ryzen 7950X), using _mm256_i32gather_epi32 is slower than _mm256_set_epi32. Strange.
+            //const __m256i q2_1 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[0], 4);
+            //const __m256i q2_2 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[1], 4);
+            const __m256i q2_1 = _mm256_set_epi32(
+                    iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]],
+                    iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]]
+            );
+            const __m256i q2_2 = _mm256_set_epi32(
+                    iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]],
+                    iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[ 9]], iq3s_grid[idx.index[ 8]]
+            );
+
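+            // Replicate each byte of the 32 sign bits across 8 lanes (mask1),
+            // isolate one bit per lane (mask2); matching bytes become 0xff and
+            // (q8 ^ s) - s negates exactly those q8 values.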
+            __m256i aux256 = _mm256_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16));
+            aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
+            const __m256i s2_1 = _mm256_cmpeq_epi8(aux256, mask2);
+            const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(s2_1, q8_1), s2_1);
+
+            aux256 = _mm256_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16));
+            aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
+            const __m256i s2_2 = _mm256_cmpeq_epi8(aux256, mask2);
+            const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(s2_2, q8_2), s2_2);
+
+            signs += 4;
+
+            const __m256i dot1  = _mm256_maddubs_epi16(q2_1, q8s_1);
+            const __m256i dot2  = _mm256_maddubs_epi16(q2_2, q8s_2);
+            const uint16_t ls1 = x[i].scales[ib32/2] & 0xf;
+            const uint16_t ls2 = x[i].scales[ib32/2] >>  4;
+            const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));
+            const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));
+            sumi1 = _mm256_add_epi32(sumi1, p1);
+            sumi2 = _mm256_add_epi32(sumi2, p2);
+        }
+
+        accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
+
+    }
+
+    *s = hsum_float_8(accumf);
+
+#elif defined(__AVX__)
+    static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
+                                        0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
+    };
+
+    static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+                                        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+    };
+
+    const __m128i mask1_0 = _mm_loadu_si128((const __m128i*)k_mask1);
+    const __m128i mask1_1 = _mm_loadu_si128((const __m128i*)k_mask1 + 1);
+    const __m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2);
+    const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1);
+
+    const __m128i idx_mul_0 = _mm_set_epi32(32, 64, 128, 256);
+    const __m128i idx_mul_1 = _mm_set_epi32(2, 4, 8, 16);
+    const __m128i idx_mask  = _mm_set1_epi32(256);
+
+    typedef union {
+        __m128i  vec[4];
+        uint32_t index[16];
+    } index_t;
+
+    index_t idx;
+
+    __m256 accumf = _mm256_setzero_ps();
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint8_t * restrict qs = x[i].qs;
+        const uint8_t * restrict qh = x[i].qh;
+        const uint16_t * restrict signs = (const uint16_t *)x[i].signs;
+        const int8_t  * restrict q8 = y[i].qs;
+        __m128i sumi1_0 = _mm_setzero_si128();
+        __m128i sumi1_1 = _mm_setzero_si128();
+        __m128i sumi2_0 = _mm_setzero_si128();
+        __m128i sumi2_1 = _mm_setzero_si128();
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i qs_tmp = _mm_loadu_si128((const __m128i *)qs);
+            const __m128i idx_l_0 = _mm_cvtepu8_epi16(qs_tmp);
+            const __m128i idx_l_1 = _mm_cvtepu8_epi16(_mm_srli_si128(qs_tmp, 8)); qs += 16;
+            idx.vec[0] = _mm_set1_epi32(qh[ib32+0]);
+            idx.vec[1] = idx.vec[0];
+            idx.vec[2] = _mm_set1_epi32(qh[ib32+1]);
+            idx.vec[3] = idx.vec[2];
+
+            idx.vec[0] = _mm_and_si128(_mm_mullo_epi32(idx.vec[0], idx_mul_0), idx_mask);
+            idx.vec[1] = _mm_and_si128(_mm_mullo_epi32(idx.vec[1], idx_mul_1), idx_mask);
+            idx.vec[2] = _mm_and_si128(_mm_mullo_epi32(idx.vec[2], idx_mul_0), idx_mask);
+            idx.vec[3] = _mm_and_si128(_mm_mullo_epi32(idx.vec[3], idx_mul_1), idx_mask);
+
+            idx.vec[0] = _mm_or_si128(idx.vec[0], _mm_cvtepi16_epi32(idx_l_0));
+            idx.vec[1] = _mm_or_si128(idx.vec[1], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_0, 8)));
+            idx.vec[2] = _mm_or_si128(idx.vec[2], _mm_cvtepi16_epi32(idx_l_1));
+            idx.vec[3] = _mm_or_si128(idx.vec[3], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_1, 8)));
+
+            const __m128i q2_1_0 = _mm_set_epi32(iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]]);
+            const __m128i q2_1_1 = _mm_set_epi32(iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]]);
+            const __m128i q2_2_0 = _mm_set_epi32(iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[9]], iq3s_grid[idx.index[8]]);
+            const __m128i q2_2_1 = _mm_set_epi32(iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]]);
+
+            __m128i aux128_0 = _mm_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16));
+            __m128i aux128_1 = aux128_0;
+            aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
+            aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
+            const __m128i s2_1_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
+            const __m128i s2_1_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
+            const __m128i q8s_1_0 = _mm_sub_epi8(_mm_xor_si128(s2_1_0, q8_1_0), s2_1_0);
+            const __m128i q8s_1_1 = _mm_sub_epi8(_mm_xor_si128(s2_1_1, q8_1_1), s2_1_1);
+
+            aux128_0 = _mm_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16));
+            aux128_1 = aux128_0;
+            aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
+            aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
+            const __m128i s2_2_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
+            const __m128i s2_2_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
+            const __m128i q8s_2_0 = _mm_sub_epi8(_mm_xor_si128(s2_2_0, q8_2_0), s2_2_0);
+            const __m128i q8s_2_1 = _mm_sub_epi8(_mm_xor_si128(s2_2_1, q8_2_1), s2_2_1);
+
+            signs += 4;
+
+            const __m128i dot1_0  = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
+            const __m128i dot1_1  = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
+            const __m128i dot2_0  = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
+            const __m128i dot2_1  = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
+            const uint16_t ls1 = x[i].scales[ib32/2] & 0xf;
+            const uint16_t ls2 = x[i].scales[ib32/2] >>  4;
+            const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1));
+            const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1));
+            const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1));
+            const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1));
+            sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
+            sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
+            sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
+            sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
+        }
+
+        accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
+
+    }
+
+    *s = hsum_float_8(accumf);
+
+#elif defined(__POWER9_VECTOR__)
+    static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
+                                        0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
+    };
+
+    static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,};
+
+    const vector int v0 = vec_splats((int32_t)0);
+
+    vector float vsumf0 = vec_splats(0.0f);
+    vector float vsumf1 = vec_splats(0.0f);
+    vector float vsumf2 = vec_splats(0.0f);
+    vector float vsumf3 = vec_splats(0.0f);
+
+    const vector unsigned char mask0 = vec_xl( 0, k_mask1);
+    const vector unsigned char mask1 = vec_xl(16, k_mask1);
+    const vector signed char mask2 = (vector signed char)vec_xl( 0, k_mask2);
+
+    for (int i = 0; i < nb; ++i) {
+        vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
+        vector float vyd = vec_splats(y[i].d);
+        vector float vd = vec_mul(vxd, vyd);
+
+        const uint8_t *  restrict q3 = x[i].qs;
+        const uint8_t *  restrict qh = x[i].qh;
+        const uint16_t * restrict signs = (const uint16_t *)(x[i].signs);
+        const uint8_t *  restrict sc = x[i].scales;
+        const int8_t  *  restrict q8 = y[i].qs;
+
+        vector signed int vsumi0 = v0;
+        vector signed int vsumi1 = v0;
+        vector signed int vsumi2 = v0;
+        vector signed int vsumi3 = v0;
+
+        for (int j = 0; j < QK_K/32; j += 2) {
+            __builtin_prefetch(q3, 0, 1);
+            __builtin_prefetch(q8, 0, 1);
+
+            vector unsigned int aux32x4_0 = {iq3s_grid[q3[ 0] | ((qh[0] << 8) & 256)], iq3s_grid[q3[ 1] | ((qh[0] << 7) & 256)],
+                                             iq3s_grid[q3[ 2] | ((qh[0] << 6) & 256)], iq3s_grid[q3[ 3] | ((qh[0] << 5) & 256)]};
+            vector unsigned int aux32x4_1 = {iq3s_grid[q3[ 4] | ((qh[0] << 4) & 256)], iq3s_grid[q3[ 5] | ((qh[0] << 3) & 256)],
+                                             iq3s_grid[q3[ 6] | ((qh[0] << 2) & 256)], iq3s_grid[q3[ 7] | ((qh[0] << 1) & 256)]};
+            vector unsigned int aux32x4_2 = {iq3s_grid[q3[ 8] | ((qh[1] << 8) & 256)], iq3s_grid[q3[ 9] | ((qh[1] << 7) & 256)],
+                                             iq3s_grid[q3[10] | ((qh[1] << 6) & 256)], iq3s_grid[q3[11] | ((qh[1] << 5) & 256)]};
+            vector unsigned int aux32x4_3 = {iq3s_grid[q3[12] | ((qh[1] << 4) & 256)], iq3s_grid[q3[13] | ((qh[1] << 3) & 256)],
+                                             iq3s_grid[q3[14] | ((qh[1] << 2) & 256)], iq3s_grid[q3[15] | ((qh[1] << 1) & 256)]};
+            q3 += 16;
+            qh += 2;
+
+            vector signed char vsigns01 = (vector signed char)vec_splats(*(const uint32_t *)&signs[0]);
+            vector signed char vsigns02 = (vector signed char)vec_splats(*(const uint32_t *)&signs[2]);
+            signs += 4;
+
+            vector signed char vsigns0 = vec_perm(vsigns01, vsigns01, mask0);
+            vector signed char vsigns1 = vec_perm(vsigns01, vsigns01, mask1);
+            vector signed char vsigns2 = vec_perm(vsigns02, vsigns02, mask0);
+            vector signed char vsigns3 = vec_perm(vsigns02, vsigns02, mask1);
+
+            vsigns0 = (vector signed char)vec_cmpeq(vec_and(vsigns0, mask2), mask2);
+            vsigns1 = (vector signed char)vec_cmpeq(vec_and(vsigns1, mask2), mask2);
+            vsigns2 = (vector signed char)vec_cmpeq(vec_and(vsigns2, mask2), mask2);
+            vsigns3 = (vector signed char)vec_cmpeq(vec_and(vsigns3, mask2), mask2);
+
+            vector signed char q3x0 = vec_sub(vec_xor(vsigns0, (vector signed char)aux32x4_0), vsigns0);
+            vector signed char q3x1 = vec_sub(vec_xor(vsigns1, (vector signed char)aux32x4_1), vsigns1);
+            vector signed char q3x2 = vec_sub(vec_xor(vsigns2, (vector signed char)aux32x4_2), vsigns2);
+            vector signed char q3x3 = vec_sub(vec_xor(vsigns3, (vector signed char)aux32x4_3), vsigns3);
+
+            vector signed char q8y0 = vec_xl( 0, q8);
+            vector signed char q8y1 = vec_xl(16, q8);
+            vector signed char q8y2 = vec_xl(32, q8);
+            vector signed char q8y3 = vec_xl(48, q8);
+            q8 += 64;
+
+            vector signed short qv0 = vec_add(vec_mule(q3x0, q8y0), vec_mulo(q3x0, q8y0));
+            vector signed short qv1 = vec_add(vec_mule(q3x1, q8y1), vec_mulo(q3x1, q8y1));
+            vector signed short qv2 = vec_add(vec_mule(q3x2, q8y2), vec_mulo(q3x2, q8y2));
+            vector signed short qv3 = vec_add(vec_mule(q3x3, q8y3), vec_mulo(q3x3, q8y3));
+
+            const uint16_t ls0 = (uint16_t)(sc[0] & 0xf);
+            const uint16_t ls1 = (uint16_t)(sc[0] >>  4);
+            sc++;
+
+            vector signed short vscales01 = (vector signed short)vec_splats((uint16_t)(2*ls0+1));
+            vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1));
+
+            vsumi0 = vec_msum(qv0, vscales01, vsumi0);
+            vsumi1 = vec_msum(qv1, vscales01, vsumi1);
+            vsumi2 = vec_msum(qv2, vscales23, vsumi2);
+            vsumi3 = vec_msum(qv3, vscales23, vsumi3);
+        }
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+    }
+
+    vsumf0 = vec_add(vsumf0, vsumf2);
+    vsumf1 = vec_add(vsumf1, vsumf3);
+
+    vsumf0 = vec_add(vsumf0, vsumf1);
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    *s = vec_extract(vsumf0, 0);
+
+#elif defined(__loongarch_asx)
+
+    static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
+                                        0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
+    };
+
+    static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+                                        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+    };
+
+    const __m256i mask1 = __lasx_xvld((const __m256i*)k_mask1, 0);
+    const __m256i mask2 = __lasx_xvld((const __m256i*)k_mask2, 0);
+
+    __m256i idx_shift = lasx_set_w(1, 2, 3, 4, 5, 6, 7, 8);
+    const __m256i idx_mask  = __lasx_xvreplgr2vr_w(256);
+
+    typedef union {
+        __m256i  vec[2];
+        uint32_t index[16];
+    } index_t;
+
+    index_t idx;
+
+    __m256 accumf = (__m256)__lasx_xvldi(0);
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint8_t * restrict qs = x[i].qs;
+        const uint8_t * restrict qh = x[i].qh;
+        const uint16_t * restrict signs = (const uint16_t *)x[i].signs;
+        const int8_t  * restrict q8 = y[i].qs;
+        __m256i sumi1 = __lasx_xvldi(0);
+        __m256i sumi2 = __lasx_xvldi(0);
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+            const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+            const __m256i idx_l = lasx_extu8_16(__lsx_vld(qs, 0)); qs += 16;
+            idx.vec[0] = __lasx_xvreplgr2vr_w(qh[ib32+0]);
+            idx.vec[1] = __lasx_xvreplgr2vr_w(qh[ib32+1]);
+            idx.vec[0] = __lasx_xvand_v(__lasx_xvsll_w(idx.vec[0], idx_shift), idx_mask);
+            idx.vec[1] = __lasx_xvand_v(__lasx_xvsll_w(idx.vec[1], idx_shift), idx_mask);
+            idx.vec[0] = __lasx_xvor_v(idx.vec[0], lasx_ext16_32(lasx_extracti128(idx_l, 0)));
+            idx.vec[1] = __lasx_xvor_v(idx.vec[1], lasx_ext16_32(lasx_extracti128(idx_l, 1)));
+
+            // Note carried over from the AVX2 path: at least on a Ryzen 7950X, _mm256_i32gather_epi32 was slower than _mm256_set_epi32. Strange.
+            //const __m256i q2_1 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[0], 4);
+            //const __m256i q2_2 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[1], 4);
+            const __m256i q2_1 = lasx_set_w(
+                    iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]],
+                    iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]]
+            );
+            const __m256i q2_2 = lasx_set_w(
+                    iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]],
+                    iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[ 9]], iq3s_grid[idx.index[ 8]]
+            );
+
+            __m256i aux256 = __lasx_xvreplgr2vr_w(signs[0] | ((uint32_t) signs[1] << 16));
+            aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2);
+            const __m256i s2_1 = __lasx_xvseq_b(aux256, mask2);
+            const __m256i q8s_1 = __lasx_xvsub_b(__lasx_xvxor_v(s2_1, q8_1), s2_1);
+
+            aux256 = __lasx_xvreplgr2vr_w(signs[2] | ((uint32_t) signs[3] << 16));
+            aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2);
+            const __m256i s2_2 = __lasx_xvseq_b(aux256, mask2);
+            const __m256i q8s_2 = __lasx_xvsub_b(__lasx_xvxor_v(s2_2, q8_2), s2_2);
+
+            signs += 4;
+
+            const __m256i dot1  = lasx_maddubs_h(q2_1, q8s_1);
+            const __m256i dot2  = lasx_maddubs_h(q2_2, q8s_2);
+            const uint16_t ls1 = x[i].scales[ib32/2] & 0xf;
+            const uint16_t ls2 = x[i].scales[ib32/2] >>  4;
+            const __m256i p1 = lasx_madd_h(dot1, __lasx_xvreplgr2vr_h(2*ls1+1));
+            const __m256i p2 = lasx_madd_h(dot2, __lasx_xvreplgr2vr_h(2*ls2+1));
+            sumi1 = __lasx_xvadd_w(sumi1, p1);
+            sumi2 = __lasx_xvadd_w(sumi2, p2);
+        }
+
+        accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf);
+    }
+
+    *s = hsum_float_8(accumf);
+
+#else
+
+    float sumf = 0.f;
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint8_t * restrict qs = x[i].qs;
+        const uint8_t * restrict qh = x[i].qh;
+        const uint8_t * restrict signs = x[i].signs;
+        const int8_t  * restrict q8 = y[i].qs;
+        int32_t bsum = 0;
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            const uint32_t ls1 = 2*(x[i].scales[ib32/2] & 0xf) + 1;
+            const uint32_t ls2 = 2*(x[i].scales[ib32/2] >>  4) + 1;
+            int32_t sumi = 0;
+            for (int l = 0; l < 4; ++l) {
+                const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+0] << (8-2*l)) & 256)));
+                const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+0] << (7-2*l)) & 256)));
+                for (int j = 0; j < 4; ++j) {
+                    sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1);
+                    sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1);
+                }
+                q8 += 8;
+            }
+            qs += 8;
+            signs += 4;
+            bsum += sumi * ls1;
+            sumi = 0;
+            for (int l = 0; l < 4; ++l) {
+                const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+1] << (8-2*l)) & 256)));
+                const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+1] << (7-2*l)) & 256)));
+                for (int j = 0; j < 4; ++j) {
+                    sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1);
+                    sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1);
+                }
+                q8 += 8;
+            }
+            qs += 8;
+            signs += 4;
+            bsum += sumi * ls2;
+        }
+        sumf += d * bsum;
+    }
+    *s = sumf;
+#endif
+}
+
+#if defined(__AVX2__)
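+// Signed int8 multiply-add: taking |x| and applying x's sign to y lets
+// _mm256_maddubs_epi16 (unsigned x signed) compute the x*y pair sums.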
+static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
+    const __m256i ax = _mm256_sign_epi8(x, x);
+    const __m256i sy = _mm256_sign_epi8(y, x);
+    return _mm256_maddubs_epi16(ax, sy);
+}
+#elif defined(__loongarch_asx)
+static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
+    const __m256i ax = __lasx_xvsigncov_b(x, x);
+    const __m256i sy = __lasx_xvsigncov_b(x, y);
+    __m256i tmp1, tmp2, tmp3;
+    tmp1 = __lasx_xvmulwev_h_bu_b(ax, sy);
+    tmp2 = __lasx_xvmulwod_h_bu_b(ax, sy);
+    tmp3 = __lasx_xvadd_h(tmp1, tmp2);
+    return __lasx_xvsat_h(tmp3, 15);
+}
+#endif
+
+void ggml_vec_dot_iq1_s_q8_K  (int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_iq1_s * restrict x = vx;
+    const block_q8_K  * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#if defined __ARM_NEON
+
+    ggml_int8x16x4_t q1b;
+    ggml_int8x16x4_t q8b;
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+
+        const int8_t   * q8 = y[i].qs;
+        const uint8_t  * qs = x[i].qs;
+        const uint16_t * qh = x[i].qh;
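+        // Each 16-bit qh word packs the three high index bits for four grid
+        // cells (hence the <<8, <<5, <<2, >>1 shifts below), a 3-bit scale in
+        // bits 12..14, and the IQ1S_DELTA sign in bit 15.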
+
+        int sumi1 = 0, sumi2 = 0, sumi3 = 0;
+
+        for (int ib = 0; ib < QK_K/32; ib += 2) {
+
+            q1b.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[0] | ((qh[ib+0] << 8) & 0x700)))),
+                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[1] | ((qh[ib+0] << 5) & 0x700)))));
+            q1b.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[2] | ((qh[ib+0] << 2) & 0x700)))),
+                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[3] | ((qh[ib+0] >> 1) & 0x700)))));
+            q1b.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[4] | ((qh[ib+1] << 8) & 0x700)))),
+                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[5] | ((qh[ib+1] << 5) & 0x700)))));
+            q1b.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[6] | ((qh[ib+1] << 2) & 0x700)))),
+                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[7] | ((qh[ib+1] >> 1) & 0x700)))));
+            qs += 8;
+
+            q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
+
+            const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[0], q8b.val[0]), q1b.val[1], q8b.val[1]);
+            const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[2], q8b.val[2]), q1b.val[3], q8b.val[3]);
+
+            const int ls1 = 2*((qh[ib+0] >> 12) & 7) + 1;
+            const int ls2 = 2*((qh[ib+1] >> 12) & 7) + 1;
+            sumi1 += vaddvq_s32(p1) * ls1;
+            sumi2 += vaddvq_s32(p2) * ls2;
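+            // Correction for the shared ±IQ1S_DELTA offset of each sub-block:
+            // bit 15 of qh gives the sign, bsums supply the q8 partial sums.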
+            sumi3 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * ls1 * (qh[ib+0] & 0x8000 ? -1 : 1)
+                   + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * ls2 * (qh[ib+1] & 0x8000 ? -1 : 1);
+
+        }
+
+        sumf += y[i].d * GGML_FP16_TO_FP32(x[i].d) * (sumi1 + sumi2 + IQ1S_DELTA * sumi3);
+    }
+
+    *s = sumf;
+
+#elif defined __AVX2__
+
+    __m256 accum = _mm256_setzero_ps();
+    float accum1 = 0;
+    for (int i = 0; i < nb; ++i) {
+
+        const int8_t   * q8 = y[i].qs;
+        const uint8_t  * qs = x[i].qs;
+        const uint16_t * qh = x[i].qh;
+
+        __m256i sumi = _mm256_setzero_si256();
+        int sumi1 = 0;
+        for (int ib = 0; ib < QK_K/32; ib += 2) {
+            const __m256i q1b_1 = _mm256_set_epi64x(iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)],
+                                                    iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)]);
+            const __m256i q1b_2 = _mm256_set_epi64x(iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)],
+                                                    iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)]);
+            qs += 8;
+            const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+
+            const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1);
+            const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2);
+            const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1;
+            const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1;
+            const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(ls1));
+            const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(ls2));
+
+            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p1, p2));
+            sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1
+                   + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2;
+        }
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        accum = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi), accum);
+        accum1 += d * sumi1;
+
+    }
+
+    *s = hsum_float_8(accum) + IQ1S_DELTA * accum1;
+
+#elif defined __AVX__
+    __m256 accum = _mm256_setzero_ps();
+    float accum1 = 0;
+    for (int i = 0; i < nb; ++i) {
+
+        const int8_t   * q8 = y[i].qs;
+        const uint8_t  * qs = x[i].qs;
+        const uint16_t * qh = x[i].qh;
+
+        __m128i sumi1_0 = _mm_setzero_si128();
+        __m128i sumi1_1 = _mm_setzero_si128();
+        int sumi1 = 0;
+        for (int ib = 0; ib < QK_K/32; ib += 2) {
+            const __m128i q1b_1_0 = _mm_set_epi64x(iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)]);
+            const __m128i q1b_1_1 = _mm_set_epi64x(iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)]);
+            const __m128i q1b_2_0 = _mm_set_epi64x(iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)]);
+            const __m128i q1b_2_1 = _mm_set_epi64x(iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)]);
+            qs += 8;
+            const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+
+            const __m128i dot1_0 = mul_add_epi8_sse(q1b_1_0, q8b_1_0);
+            const __m128i dot1_1 = mul_add_epi8_sse(q1b_1_1, q8b_1_1);
+            const __m128i dot2_0 = mul_add_epi8_sse(q1b_2_0, q8b_2_0);
+            const __m128i dot2_1 = mul_add_epi8_sse(q1b_2_1, q8b_2_1);
+            const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1;
+            const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1;
+            const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(ls1));
+            const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(ls1));
+            const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(ls2));
+            const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(ls2));
+
+            sumi1_0 = _mm_add_epi32(sumi1_0, _mm_add_epi32(p1_0, p2_0));
+            sumi1_1 = _mm_add_epi32(sumi1_1, _mm_add_epi32(p1_1, p2_1));
+            sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1
+                   + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2;
+        }
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        accum = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(sumi1_1, sumi1_0))), accum);
+        accum1 += d * sumi1;
+
+    }
+
+    *s = hsum_float_8(accum) + IQ1S_DELTA * accum1;
+
+#elif defined(__POWER9_VECTOR__)
+    const vector unsigned char v0 = vec_splats((unsigned char)0x0);
+    const vector unsigned short vsign = vec_splats((unsigned short)0x8000);
+
+    vector float vsumf0 = vec_splats(0.0f);
+    vector float vsumf1 = vec_splats(0.0f);
+    vector float vsumf2 = vec_splats(0.0f);
+    vector float vsumf3 = vec_splats(0.0f);
+
+    for (int i = 0; i < nb; ++i) {
+        vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
+        vector float vyd = vec_splats(y[i].d);
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector signed int vsumi0 = vec_splats((int32_t)0);
+        vector signed int vsumi1 = vec_splats((int32_t)0);
+        vector signed int vsumi2 = vec_splats((int32_t)0);
+        vector signed int vsumi3 = vec_splats((int32_t)0);
+        vector signed int vsumi8 = vec_splats((int32_t)0);
+
+        const uint8_t  * restrict q1 = x[i].qs;
+        const uint16_t * restrict qh = x[i].qh;
+        const int8_t   * restrict q8 = y[i].qs;
+        const int16_t  * restrict qs = y[i].bsums;
+
+        for (int j = 0; j < QK_K/32; j += 2) {
+            __builtin_prefetch(q1, 0, 1);
+            __builtin_prefetch(qh, 0, 1);
+            __builtin_prefetch(q8, 0, 1);
+
+            vector signed long long aux64x2_0 = {*(const int64_t *)(iq1s_grid + (q1[0] | ((qh[0] << 8) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[1] | ((qh[0] << 5) & 0x700)))};
+            vector signed long long aux64x2_1 = {*(const int64_t *)(iq1s_grid + (q1[2] | ((qh[0] << 2) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[3] | ((qh[0] >> 1) & 0x700)))};
+            vector signed long long aux64x2_2 = {*(const int64_t *)(iq1s_grid + (q1[4] | ((qh[1] << 8) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[5] | ((qh[1] << 5) & 0x700)))};
+            vector signed long long aux64x2_3 = {*(const int64_t *)(iq1s_grid + (q1[6] | ((qh[1] << 2) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[7] | ((qh[1] >> 1) & 0x700)))};
+            q1 += 8;
+
+            vector signed char q1x0 = (vector signed char)aux64x2_0;
+            vector signed char q1x1 = (vector signed char)aux64x2_1;
+            vector signed char q1x2 = (vector signed char)aux64x2_2;
+            vector signed char q1x3 = (vector signed char)aux64x2_3;
+
+            vector signed char q8y0 = vec_xl( 0, q8);
+            vector signed char q8y1 = vec_xl(16, q8);
+            vector signed char q8y2 = vec_xl(32, q8);
+            vector signed char q8y3 = vec_xl(48, q8);
+            q8 += 64;
+
+            vector signed short qv0 = vec_add(vec_mule(q1x0, q8y0), vec_mulo(q1x0, q8y0));
+            vector signed short qv1 = vec_add(vec_mule(q1x1, q8y1), vec_mulo(q1x1, q8y1));
+            vector signed short qv2 = vec_add(vec_mule(q1x2, q8y2), vec_mulo(q1x2, q8y2));
+            vector signed short qv3 = vec_add(vec_mule(q1x3, q8y3), vec_mulo(q1x3, q8y3));
+
+            const uint16_t ls0 = (uint16_t)((qh[0] >> 12) & 7);
+            const uint16_t ls1 = (uint16_t)((qh[1] >> 12) & 7);
+
+            vector signed short vscales01 = (vector signed short)vec_splats((uint16_t)(2*ls0+1));
+            vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1));
+            vector signed short vscales = vec_sld(vscales23, vscales01, 8);
+
+            vsumi0 = vec_msum(qv0, vscales01, vsumi0);
+            vsumi1 = vec_msum(qv1, vscales01, vsumi1);
+            vsumi2 = vec_msum(qv2, vscales23, vsumi2);
+            vsumi3 = vec_msum(qv3, vscales23, vsumi3);
+
+            vector signed short q8ysums = vec_xl_len(qs, 8);
+            qs += 4;
+            q8ysums = vec_mergeh(q8ysums, (vector signed short)v0);
+
+            vector signed short qxh = (vector signed short)vec_sld(vec_splats(qh[1]), vec_splats(qh[0]), 8);
+            qh += 2;
+            vector __bool short vsel = vec_cmpge(qxh, (vector signed short)v0);
+
+            vector signed short q8ysum = vec_sel((vector signed short)vec_xor((vector unsigned short)q8ysums, vsign), q8ysums, vsel);
+
+            vsumi8 = vec_add(vec_mule(q8ysum, vscales), vsumi8);
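+            // vsumi8 accumulates the bsums-based correction term; the
+            // IQ1S_DELTA factor is applied once after the inner loop.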
+        }
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+
+        vsumf0 = vec_madd(vec_ctf(vsumi8, 0), vec_mul(vd, vec_splats(IQ1S_DELTA)), vsumf0);
+    }
+
+    vsumf0 = vec_add(vsumf0, vsumf2);
+    vsumf1 = vec_add(vsumf1, vsumf3);
+
+    vsumf0 = vec_add(vsumf0, vsumf1);
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    *s = vec_extract(vsumf0, 0);
+
+#elif defined(__loongarch_asx)
+
+    __m256 accum = (__m256)__lasx_xvldi(0);
+    float accum1 = 0;
+    for (int i = 0; i < nb; ++i) {
+
+        const int8_t   * q8 = y[i].qs;
+        const uint8_t  * qs = x[i].qs;
+        const uint16_t * qh = x[i].qh;
+
+        __m256i sumi = __lasx_xvldi(0);
+        int sumi1 = 0;
+        for (int ib = 0; ib < QK_K/32; ib += 2) {
+            // zero-initialize before inserting each 64-bit grid entry
+            __m256i q1b_1 = __lasx_xvldi(0);
+            q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)], 0);
+            q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], 1);
+            q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)], 2);
+            q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], 3);
+
+            __m256i q1b_2 = __lasx_xvldi(0);
+            q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)], 0);
+            q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], 1);
+            q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)], 2);
+            q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], 3);
+
+            qs += 8;
+            const __m256i q8b_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+            const __m256i q8b_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+
+            const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1);
+            const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2);
+            const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1;
+            const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1;
+
+            __m256i tmp1, tmp5, tmp6;
+            tmp1 = __lasx_xvreplgr2vr_h(ls1);
+            tmp5 = __lasx_xvmulwev_w_h(dot1, tmp1);
+            tmp6 = __lasx_xvmulwod_w_h(dot1, tmp1);
+            const __m256i p1 = __lasx_xvadd_w(tmp5, tmp6);
+
+            tmp1 = __lasx_xvreplgr2vr_h(ls2);
+            tmp5 = __lasx_xvmulwev_w_h(dot2, tmp1);
+            tmp6 = __lasx_xvmulwod_w_h(dot2, tmp1);
+            const __m256i p2 = __lasx_xvadd_w(tmp5, tmp6);
+
+            sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p1, p2));
+            sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1
+                   + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2;
+        }
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        accum = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), accum);
+        accum1 += d * sumi1;
+    }
+
+    *s = hsum_float_8(accum) + IQ1S_DELTA * accum1;
+
+#else
+
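+    // scalar reference: per 32-weight block, qh bits 0..11 extend the four grid indices,
+    // bits 12..14 hold the 3-bit scale and bit 15 selects the sign of the IQ1S_DELTA term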
+    float sumf = 0;
+    for (int i = 0; i < nb; i++) {
+
+        const int8_t   * q8 = y[i].qs;
+        const uint8_t  * qs = x[i].qs;
+        const uint16_t * qh = x[i].qh;
+
+        int sumi = 0, sumi1 = 0;
+        for (int ib = 0; ib < QK_K/32; ++ib) {
+            const int ls = 2*((qh[ib] >> 12) & 7) + 1;
+            const int delta = qh[ib] & 0x8000 ? -1 : 1;
+            int lsum = 0;
+            for (int l = 0; l < 4; ++l) {
+                const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((qh[ib] >> 3*l) & 7) << 8)));
+                for (int j = 0; j < 8; ++j) {
+                    lsum += q8[j] * grid[j];
+                }
+                q8 += 8;
+            }
+            sumi  += ls * lsum;
+            sumi1 += ls * delta * (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]);
+            qs += 4;
+        }
+
+        sumf += GGML_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1);
+    }
+
+    *s = sumf;
+
+#endif
+}
+
+void ggml_vec_dot_iq1_m_q8_K  (int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_iq1_m * restrict x = vx;
+    const block_q8_K  * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+    iq1m_scale_t scale;
+
+#if defined __ARM_NEON
+    const int32x4_t mask  = vdupq_n_s32(0x7);
+    const int32x4_t mone  = vdupq_n_s32(1);
+    const int32x4_t mzero = vdupq_n_s32(0);
+
+    ggml_int8x16x4_t deltas;
+    deltas.val[0] = vcombine_s8(vdup_n_s8(+1), vdup_n_s8(+1));
+    deltas.val[1] = vcombine_s8(vdup_n_s8(-1), vdup_n_s8(+1));
+    deltas.val[2] = vcombine_s8(vdup_n_s8(+1), vdup_n_s8(-1));
+    deltas.val[3] = vcombine_s8(vdup_n_s8(-1), vdup_n_s8(-1));
+
+    ggml_int8x16x4_t q1b;
+    ggml_int8x16x4_t q8b;
+
+    uint32_t aux32;
+    const uint8_t * aux8 = (const uint8_t *)&aux32;
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+
+        const int8_t   * q8 = y[i].qs;
+        const uint8_t  * qs = x[i].qs;
+        const uint8_t  * qh = x[i].qh;
+        const uint16_t * sc = (const uint16_t *)x[i].scales;
+
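+        // reassemble the fp16 super-block scale from the top nibble of each of the four 16-bit scale words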
+        scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+
+        int32x4_t sumi1 = mzero;
+        int32x4_t sumi2 = mzero;
+
+        for (int ib = 0; ib < QK_K/32; ib += 2) {
+
+            q1b.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[0] | ((qh[0] << 8) & 0x700)))),
+                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[1] | ((qh[0] << 4) & 0x700)))));
+            q1b.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[2] | ((qh[1] << 8) & 0x700)))),
+                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[3] | ((qh[1] << 4) & 0x700)))));
+            q1b.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[4] | ((qh[2] << 8) & 0x700)))),
+                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[5] | ((qh[2] << 4) & 0x700)))));
+            q1b.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[6] | ((qh[3] << 8) & 0x700)))),
+                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[7] | ((qh[3] << 4) & 0x700)))));
+
+            q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
+
+            const int32x4_t p1 = vpaddq_s32(ggml_vdotq_s32(mzero, q1b.val[0], q8b.val[0]), ggml_vdotq_s32(mzero, q1b.val[1], q8b.val[1]));
+            const int32x4_t p2 = vpaddq_s32(ggml_vdotq_s32(mzero, q1b.val[2], q8b.val[2]), ggml_vdotq_s32(mzero, q1b.val[3], q8b.val[3]));
+            const int32x4_t p12 = vpaddq_s32(p1, p2);
+
+            const uint32_t * qh32 = (const uint32_t *)qh; // qh is 4-byte aligned here, so the 32-bit load is safe
+            aux32 = ((qh32[0] >> 3) & 0x01010101) | ((qh32[0] >> 6) & 0x02020202);
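+            // aux32 packs one 2-bit delta selector per 16-byte group; each byte of aux8 indexes deltas.val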
+
+            const int32x4_t p3 = vpaddq_s32(ggml_vdotq_s32(mzero, deltas.val[aux8[0]], q8b.val[0]), ggml_vdotq_s32(mzero, deltas.val[aux8[1]], q8b.val[1]));
+            const int32x4_t p4 = vpaddq_s32(ggml_vdotq_s32(mzero, deltas.val[aux8[2]], q8b.val[2]), ggml_vdotq_s32(mzero, deltas.val[aux8[3]], q8b.val[3]));
+            const int32x4_t p34 = vpaddq_s32(p3, p4);
+
+            int32x4_t scales_4 = ggml_vld1q_u32(sc[ib/2] >> 0, sc[ib/2] >> 3, sc[ib/2] >> 6, sc[ib/2] >> 9);
+
+            scales_4 = vaddq_s32(vshlq_n_s32(vandq_s32(scales_4, mask), 1), mone);
+
+            sumi1 = vmlaq_s32(sumi1, scales_4, p12);
+            sumi2 = vmlaq_s32(sumi2, scales_4, p34);
+
+            qs += 8; qh += 4;
+
+        }
+
+        sumf += y[i].d * GGML_FP16_TO_FP32(scale.f16) * (vaddvq_s32(sumi1) + IQ1M_DELTA * vaddvq_s32(sumi2));
+    }
+
+    *s = sumf;
+
+#elif defined __AVX2__
+
+    const __m256i mask = _mm256_set1_epi16(0x7);
+    const __m256i mone = _mm256_set1_epi16(1);
+
+    __m256 accum1 = _mm256_setzero_ps();
+    __m256 accum2 = _mm256_setzero_ps();
+    for (int i = 0; i < nb; ++i) {
+
+        const int8_t   * q8 = y[i].qs;
+        const uint8_t  * qs = x[i].qs;
+        const uint8_t  * qh = x[i].qh;
+        const uint16_t * sc = (const uint16_t *)x[i].scales;
+
+        scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+
+        __m256i sumi1 = _mm256_setzero_si256();
+        __m256i sumi2 = _mm256_setzero_si256();
+        for (int ib = 0; ib < QK_K/32; ib += 2) {
+            const __m256i q1b_1 = _mm256_set_epi64x(
+                    iq1s_grid[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)],
+                    iq1s_grid[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)]
+            );
+            const __m256i q1b_2 = _mm256_set_epi64x(
+                    iq1s_grid[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)], iq1s_grid[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)],
+                    iq1s_grid[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)], iq1s_grid[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)]
+            );
+            const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+
+            const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1);
+            const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2);
+
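+            // ±1 delta vectors from qh bits 3 and 7: an all-0xff half multiplies as -1, an all-0x01 half as +1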
+            const __m256i delta1 = _mm256_set_epi64x(qh[1] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
+                                                     qh[1] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101,
+                                                     qh[0] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
+                                                     qh[0] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
+            const __m256i delta2 = _mm256_set_epi64x(qh[3] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
+                                                     qh[3] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101,
+                                                     qh[2] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
+                                                     qh[2] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
+
+            const __m256i dot3 = mul_add_epi8(delta1, q8b_1);
+            const __m256i dot4 = mul_add_epi8(delta2, q8b_2);
+
+            __m256i scale1 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 3), _mm_set1_epi16(sc[ib/2] >> 0));
+            __m256i scale2 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 9), _mm_set1_epi16(sc[ib/2] >> 6));
+
+            scale1 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale1, mask), 1), mone);
+            scale2 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale2, mask), 1), mone);
+            const __m256i p1 = _mm256_madd_epi16(dot1, scale1);
+            const __m256i p2 = _mm256_madd_epi16(dot2, scale2);
+            const __m256i p3 = _mm256_madd_epi16(dot3, scale1);
+            const __m256i p4 = _mm256_madd_epi16(dot4, scale2);
+
+            sumi1 = _mm256_add_epi32(sumi1, _mm256_add_epi32(p1, p2));
+            sumi2 = _mm256_add_epi32(sumi2, _mm256_add_epi32(p3, p4));
+
+            qs += 8; qh += 4;
+        }
+
+        const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(scale.f16));
+
+        accum1 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi1), accum1);
+        accum2 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi2), accum2);
+    }
+
+    *s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2);
+
+#elif defined __AVX__
+    const __m128i mask = _mm_set1_epi16(0x7);
+    const __m128i mone = _mm_set1_epi16(1);
+
+    __m256 accum1 = _mm256_setzero_ps();
+    __m256 accum2 = _mm256_setzero_ps();
+    for (int i = 0; i < nb; ++i) {
+
+        const int8_t   * q8 = y[i].qs;
+        const uint8_t  * qs = x[i].qs;
+        const uint8_t  * qh = x[i].qh;
+        const uint16_t * sc = (const uint16_t *)x[i].scales;
+
+        scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+
+        __m128i sumi1_0 = _mm_setzero_si128();
+        __m128i sumi1_1 = _mm_setzero_si128();
+        __m128i sumi2_0 = _mm_setzero_si128();
+        __m128i sumi2_1 = _mm_setzero_si128();
+        for (int ib = 0; ib < QK_K/32; ib += 2) {
+            const __m128i q1b_1_0 = _mm_set_epi64x(
+                    iq1s_grid[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)]);
+            const __m128i q1b_1_1 = _mm_set_epi64x(
+                    iq1s_grid[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)]);
+            const __m128i q1b_2_0 = _mm_set_epi64x(
+                    iq1s_grid[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)], iq1s_grid[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)]);
+            const __m128i q1b_2_1 = _mm_set_epi64x(
+                    iq1s_grid[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)], iq1s_grid[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)]);
+            const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+
+            const __m128i dot1_0 = mul_add_epi8_sse(q1b_1_0, q8b_1_0);
+            const __m128i dot1_1 = mul_add_epi8_sse(q1b_1_1, q8b_1_1);
+            const __m128i dot2_0 = mul_add_epi8_sse(q1b_2_0, q8b_2_0);
+            const __m128i dot2_1 = mul_add_epi8_sse(q1b_2_1, q8b_2_1);
+
+            const __m128i delta1_0 = _mm_set_epi64x(qh[0] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
+                                                     qh[0] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
+            const __m128i delta1_1 = _mm_set_epi64x(qh[1] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
+                                                     qh[1] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
+            const __m128i delta2_0 = _mm_set_epi64x(qh[2] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
+                                                     qh[2] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
+            const __m128i delta2_1 = _mm_set_epi64x(qh[3] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
+                                                     qh[3] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
+
+            const __m128i dot3_0 = mul_add_epi8_sse(delta1_0, q8b_1_0);
+            const __m128i dot3_1 = mul_add_epi8_sse(delta1_1, q8b_1_1);
+            const __m128i dot4_0 = mul_add_epi8_sse(delta2_0, q8b_2_0);
+            const __m128i dot4_1 = mul_add_epi8_sse(delta2_1, q8b_2_1);
+
+            __m128i scale1_0 = _mm_set1_epi16(sc[ib/2] >> 0);
+            __m128i scale1_1 = _mm_set1_epi16(sc[ib/2] >> 3);
+            __m128i scale2_0 = _mm_set1_epi16(sc[ib/2] >> 6);
+            __m128i scale2_1 = _mm_set1_epi16(sc[ib/2] >> 9);
+
+            scale1_0 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale1_0, mask), 1), mone);
+            scale1_1 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale1_1, mask), 1), mone);
+            scale2_0 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale2_0, mask), 1), mone);
+            scale2_1 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale2_1, mask), 1), mone);
+            const __m128i p1_0 = _mm_madd_epi16(dot1_0, scale1_0);
+            const __m128i p1_1 = _mm_madd_epi16(dot1_1, scale1_1);
+            const __m128i p2_0 = _mm_madd_epi16(dot2_0, scale2_0);
+            const __m128i p2_1 = _mm_madd_epi16(dot2_1, scale2_1);
+            const __m128i p3_0 = _mm_madd_epi16(dot3_0, scale1_0);
+            const __m128i p3_1 = _mm_madd_epi16(dot3_1, scale1_1);
+            const __m128i p4_0 = _mm_madd_epi16(dot4_0, scale2_0);
+            const __m128i p4_1 = _mm_madd_epi16(dot4_1, scale2_1);
+
+            sumi1_0 = _mm_add_epi32(sumi1_0, _mm_add_epi32(p1_0, p2_0));
+            sumi1_1 = _mm_add_epi32(sumi1_1, _mm_add_epi32(p1_1, p2_1));
+            sumi2_0 = _mm_add_epi32(sumi2_0, _mm_add_epi32(p3_0, p4_0));
+            sumi2_1 = _mm_add_epi32(sumi2_1, _mm_add_epi32(p3_1, p4_1));
+
+            qs += 8; qh += 4;
+        }
+
+        const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(scale.f16));
+
+        accum1 = _mm256_add_ps(_mm256_mul_ps(d, _mm256_cvtepi32_ps(MM256_SET_M128I(sumi1_1, sumi1_0))), accum1);
+        accum2 = _mm256_add_ps(_mm256_mul_ps(d, _mm256_cvtepi32_ps(MM256_SET_M128I(sumi2_1, sumi2_0))), accum2);
+    }
+
+    *s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2);
+
+#else
+
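+    // scalar reference for iq1_m: the per-half delta sign comes from qh bits 3 and 7, and each 16-bit
+    // scale word packs four 3-bit sub-scales below the nibble reserved for the fp16 super-block scale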
+    int sum1[2], sum2[2], delta[4];
+
+    float sumf = 0;
+    for (int i = 0; i < nb; i++) {
+
+        const int8_t   * q8 = y[i].qs;
+        const uint8_t  * qs = x[i].qs;
+        const uint8_t  * qh = x[i].qh;
+        const uint16_t * sc = (const uint16_t *)x[i].scales;
+
+        scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+
+        int sumi1 = 0, sumi2 = 0;
+        for (int ib = 0; ib < QK_K/32; ++ib) {
+            delta[0] = qh[0] & 0x08 ? -1 : 1;
+            delta[1] = qh[0] & 0x80 ? -1 : 1;
+            delta[2] = qh[1] & 0x08 ? -1 : 1;
+            delta[3] = qh[1] & 0x80 ? -1 : 1;
+            sum1[0] = sum1[1] = sum2[0] = sum2[1] = 0;
+            for (int l = 0; l < 4; ++l) {
+                const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((uint16_t)qh[l/2] << (8 - 4*(l%2))) & 0x700)));
+                int lsum1 = 0, lsum2 = 0;
+                for (int j = 0; j < 8; ++j) {
+                    lsum1 += q8[j] * grid[j];
+                    lsum2 += q8[j];
+                }
+                q8 += 8;
+                sum1[l/2] += lsum1;
+                sum2[l/2] += lsum2*delta[l];
+            }
+
+            const int ls1 = 2*((sc[ib/2] >> (6*(ib%2)+0)) & 0x7) + 1;
+            const int ls2 = 2*((sc[ib/2] >> (6*(ib%2)+3)) & 0x7) + 1;
+
+            sumi1 += sum1[0] * ls1 + sum1[1] * ls2;
+            sumi2 += sum2[0] * ls1 + sum2[1] * ls2;
+            qs += 4;
+            qh += 2;
+        }
+
+        sumf += GGML_FP16_TO_FP32(scale.f16) * y[i].d * (sumi1 + IQ1M_DELTA * sumi2);
+    }
+
+    *s = sumf;
+
+#endif
+}
+
+void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+    assert(n % QK4_NL == 0);
+    static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same");
+
+    const block_iq4_nl * restrict x = vx;
+    const block_q8_0   * restrict y = vy;
+
+    const int nb = n / QK4_NL;
+
+    int ib = 0;
+    float sumf = 0;
+
+#if defined __ARM_NEON
+    const int8x16_t values = vld1q_s8(kvalues_iq4nl);
+    const uint8x16_t m4b = vdupq_n_u8(0x0f);
+    uint8x16x2_t q4bits;
+    int8x16x4_t q4b;
+    int8x16x4_t q8b;
+    int32x4_t prod_1, prod_2;
+
+    for (; ib + 1 < nb; ib += 2) {
+
+        q4bits.val[0] = vld1q_u8(x[ib + 0].qs);
+        q4bits.val[1] = vld1q_u8(x[ib + 1].qs);
+        q8b.val[0]    = vld1q_s8(y[ib + 0].qs);
+        q8b.val[1]    = vld1q_s8(y[ib + 0].qs + 16);
+        q8b.val[2]    = vld1q_s8(y[ib + 1].qs);
+        q8b.val[3]    = vld1q_s8(y[ib + 1].qs + 16);
+
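+        // ggml_vqtbl1q_s8 maps each 4-bit index to its entry in the 16-value non-linear codebook kvalues_iq4nl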
+        q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8  (q4bits.val[0], m4b));
+        q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4));
+        q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8  (q4bits.val[1], m4b));
+        q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4));
+
+        prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]);
+        prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]);
+
+        sumf +=
+            GGML_FP16_TO_FP32(x[ib+0].d) * GGML_FP16_TO_FP32(y[ib + 0].d) * vaddvq_s32(prod_1) +
+            GGML_FP16_TO_FP32(x[ib+1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) * vaddvq_s32(prod_2);
+    }
+
+#elif defined __AVX2__
+
+    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
+    const __m128i m4b  = _mm_set1_epi8(0x0f);
+    const __m256i mone = _mm256_set1_epi16(1);
+
+    __m256 accum1 = _mm256_setzero_ps();
+    __m256 accum2 = _mm256_setzero_ps();
+    for (; ib + 1 < nb; ib += 2) {
+        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs);
+        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs);
+        const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[ib + 0].qs);
+        const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[ib + 1].qs);
+        const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
+                                              _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
+        const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
+                                              _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
+        const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
+        const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
+        const __m256i p_1 = _mm256_madd_epi16(p16_1, mone);
+        const __m256i p_2 = _mm256_madd_epi16(p16_2, mone);
+        accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 0].d)*GGML_FP16_TO_FP32(x[ib + 0].d)),
+                _mm256_cvtepi32_ps(p_1), accum1);
+        accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 1].d)*GGML_FP16_TO_FP32(x[ib + 1].d)),
+                _mm256_cvtepi32_ps(p_2), accum2);
+    }
+
+    sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
+
+#elif defined __AVX__
+    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
+    const __m128i m4b  = _mm_set1_epi8(0x0f);
+
+    __m256 accum = _mm256_setzero_ps();
+    for (; ib + 1 < nb; ib += 2) {
+        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs);
+        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
+        const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs);
+        const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1);
+        const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
+        const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);
+
+        const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b));
+        const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b));
+        const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b));
+        const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b));
+
+        const __m256 p = mul_sum_i8_quad_float(q4b_1_0, q4b_1_1, q4b_2_0, q4b_2_1, q8b_1_0, q8b_1_1, q8b_2_0, q8b_2_1);
+        const __m256 deltas = quad_fp16_delta_float(x[ib].d, y[ib].d, x[ib + 1].d, y[ib + 1].d);
+        accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum);
+    }
+
+    sumf = hsum_float_8(accum);
+
+#elif defined(__POWER9_VECTOR__)
+    const vector signed char lowMask = vec_splats((signed char)0xF);
+    const vector signed int v0 = vec_splats((int32_t)0);
+    const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+
+    vector float vsumf0 = vec_splats(0.0f);
+    vector float vsumf1 = vec_splats(0.0f);
+
+    const vector signed char values = vec_xl( 0, kvalues_iq4nl);
+
+#pragma GCC unroll 4
+    for (; ib < nb; ++ib) {
+        __builtin_prefetch(x[ib].qs, 0, 1);
+        __builtin_prefetch(y[ib].qs, 0, 1);
+
+        vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d));
+        vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d));
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs);
+        vector signed char q4x0 = vec_and(qxs, lowMask);
+        vector signed char q4x1 = vec_sr(qxs, v4);
+
+        q4x0 = vec_perm(values, values, (vector unsigned char)q4x0);
+        q4x1 = vec_perm(values, values, (vector unsigned char)q4x1);
+
+        vector signed char q8y0 = vec_xl( 0, y[ib].qs);
+        vector signed char q8y1 = vec_xl(16, y[ib].qs);
+
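+        // vec_mule/vec_mulo form even/odd byte products widened to int16; their sum is the dot of each adjacent byte pair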
+        vector signed short qv0 = vec_add(vec_mule(q4x0, q8y0), vec_mulo(q4x0, q8y0));
+        vector signed short qv1 = vec_add(vec_mule(q4x1, q8y1), vec_mulo(q4x1, q8y1));
+
+        vector signed int vsumi0 = v0;
+        vector signed int vsumi1 = v0;
+
+        vsumi0 = vec_sum4s(qv0, vsumi0);
+        vsumi1 = vec_sum4s(qv1, vsumi1);
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+    }
+
+    vsumf0 = vec_add(vsumf0, vsumf1);
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    sumf = vec_extract(vsumf0, 0);
+
+#elif defined (__loongarch_asx)
+
+    const __m128i values128 = __lsx_vld((const __m128i*)kvalues_iq4nl, 0);
+    const __m128i m4b  = __lsx_vreplgr2vr_b(0x0f);
+    const __m256i mone = __lasx_xvreplgr2vr_h(1);
+
+    __m256 accum1 = (__m256)__lasx_xvldi(0);
+    __m256 accum2 = (__m256)__lasx_xvldi(0);
+    for (; ib + 1 < nb; ib += 2) {
+        const __m128i q4bits_1 = __lsx_vld((const __m128i*)x[ib + 0].qs, 0);
+        const __m128i q4bits_2 = __lsx_vld((const __m128i*)x[ib + 1].qs, 0);
+        const __m256i q8b_1 = __lasx_xvld((const __m256i *)y[ib + 0].qs, 0);
+        const __m256i q8b_2 = __lasx_xvld((const __m256i *)y[ib + 1].qs, 0);
+        const __m256i q4b_1 = lasx_insertf128(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_1, 4), m4b)),
+                                              lsx_shuffle_b(values128, __lsx_vand_v(q4bits_1, m4b)));
+        const __m256i q4b_2 = lasx_insertf128(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b)),
+                                              lsx_shuffle_b(values128, __lsx_vand_v(q4bits_2, m4b)));
+        const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
+        const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
+        const __m256i p_1 = lasx_madd_h(p16_1, mone);
+        const __m256i p_2 = lasx_madd_h(p16_2, mone);
+        accum1 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[ib + 0].d)*GGML_FP16_TO_FP32(x[ib + 0].d)),
+                __lasx_xvffint_s_w(p_1), accum1);
+        accum2 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[ib + 1].d)*GGML_FP16_TO_FP32(x[ib + 1].d)),
+                __lasx_xvffint_s_w(p_2), accum2);
+    }
+
+    sumf = hsum_float_8(__lasx_xvfadd_s(accum1, accum2));
+
+#endif
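+    // scalar fallback: finishes whatever blocks the vector loops above did not consume
+    // (the whole range when no SIMD path is compiled in)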
+    for (; ib < nb; ++ib) {
+        const float d = GGML_FP16_TO_FP32(y[ib].d)*GGML_FP16_TO_FP32(x[ib].d);
+        int sumi1 = 0, sumi2 = 0;
+        for (int j = 0; j < QK4_NL/2; ++j) {
+            sumi1 += y[ib].qs[j+       0] * kvalues_iq4nl[x[ib].qs[j] & 0xf];
+            sumi2 += y[ib].qs[j+QK4_NL/2] * kvalues_iq4nl[x[ib].qs[j] >>  4];
+        }
+        sumf += d * (sumi1 + sumi2);
+    }
+    *s = sumf;
+}
+
+void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+    assert(n % QK_K == 0);
+
+    const block_iq4_xs * restrict x = vx;
+    const block_q8_K   * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#if defined __ARM_NEON
+    const int8x16_t values = vld1q_s8(kvalues_iq4nl);
+    const uint8x16_t m4b = vdupq_n_u8(0x0f);
+    ggml_uint8x16x2_t q4bits;
+    ggml_int8x16x4_t q4b;
+    ggml_int8x16x4_t q8b;
+    int32x4_t prod_1, prod_2;
+
+    float sumf = 0;
+
+    for (int ibl = 0; ibl < nb; ++ibl) {
+
+        const int8_t  * q8 = y[ibl].qs;
+        const uint8_t * q4 = x[ibl].qs;
+        uint16_t h = x[ibl].scales_h;
+
+        int sumi1 = 0, sumi2 = 0;
+        for (int ib = 0; ib < QK_K/64; ++ib) {
+
+            q4bits = ggml_vld1q_u8_x2(q4); q4 += 32;
+            q8b    = ggml_vld1q_s8_x4(q8); q8 += 64;
+
+            q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8  (q4bits.val[0], m4b));
+            q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4));
+            q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8  (q4bits.val[1], m4b));
+            q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4));
+
+            prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]);
+            prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]);
+
+            int ls1 = ((x[ibl].scales_l[ib] & 0xf) | ((h << 4) & 0x30)) - 32;
+            int ls2 = ((x[ibl].scales_l[ib] >>  4) | ((h << 2) & 0x30)) - 32;
+            h >>= 4;
+            sumi1 += vaddvq_s32(prod_1) * ls1;
+            sumi2 += vaddvq_s32(prod_2) * ls2;
+
+        }
+
+        sumf += GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2);
+    }
+
+    *s = sumf;
+
+#elif defined __AVX2__
+
+    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
+    const __m128i m4b  = _mm_set1_epi8(0x0f);
+
+    __m256 accum = _mm256_setzero_ps();
+    for (int ibl = 0; ibl < nb; ++ibl) {
+        const uint8_t * qs = x[ibl].qs;
+        const int8_t  * q8 = y[ibl].qs;
+        uint16_t sh = x[ibl].scales_h;
+        __m256i sumi1 = _mm256_setzero_si256();
+        __m256i sumi2 = _mm256_setzero_si256();
+        for (int ib = 0; ib < QK_K/32; ib += 2) {
+            const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)qs);  qs += 16;
+            const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)qs);  qs += 16;
+            const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
+                                                  _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
+            const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
+                                                  _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
+            const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
+            const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
+            const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32;
+            const int16_t ls2 = ((x[ibl].scales_l[ib/2] >>  4) | ((sh << 2) & 0x30)) - 32;
+            sh >>= 4;
+            const __m256i p_1 = _mm256_madd_epi16(p16_1, _mm256_set1_epi16(ls1));
+            const __m256i p_2 = _mm256_madd_epi16(p16_2, _mm256_set1_epi16(ls2));
+            sumi1 = _mm256_add_epi32(p_1, sumi1);
+            sumi2 = _mm256_add_epi32(p_2, sumi2);
+        }
+        accum = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[ibl].d)*y[ibl].d),
+                _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accum);
+    }
+
+    *s = hsum_float_8(accum);
+
+#elif defined __AVX__
+    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
+    const __m128i m4b  = _mm_set1_epi8(0x0f);
+
+    __m256 accum = _mm256_setzero_ps();
+    for (int ibl = 0; ibl < nb; ++ibl) {
+        const uint8_t * qs = x[ibl].qs;
+        const int8_t  * q8 = y[ibl].qs;
+        uint16_t sh = x[ibl].scales_h;
+        __m128i sumi1_0 = _mm_setzero_si128();
+        __m128i sumi1_1 = _mm_setzero_si128();
+        __m128i sumi2_0 = _mm_setzero_si128();
+        __m128i sumi2_1 = _mm_setzero_si128();
+        for (int ib = 0; ib < QK_K/32; ib += 2) {
+            const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)qs); qs += 16;
+            const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)qs); qs += 16;
+            const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b));
+            const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b));
+            const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b));
+            const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b));
+            const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0);
+            const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1);
+            const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0);
+            const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1);
+            const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32;
+            const int16_t ls2 = ((x[ibl].scales_l[ib/2] >>  4) | ((sh << 2) & 0x30)) - 32;
+            sh >>= 4;
+            const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, _mm_set1_epi16(ls1));
+            const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, _mm_set1_epi16(ls1));
+            const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, _mm_set1_epi16(ls2));
+            const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, _mm_set1_epi16(ls2));
+            sumi1_0 = _mm_add_epi32(p_1_0, sumi1_0);
+            sumi1_1 = _mm_add_epi32(p_1_1, sumi1_1);
+            sumi2_0 = _mm_add_epi32(p_2_0, sumi2_0);
+            sumi2_1 = _mm_add_epi32(p_2_1, sumi2_1);
+        }
+        __m128i sumi12_0 = _mm_add_epi32(sumi1_0, sumi2_0);
+        __m128i sumi12_1 = _mm_add_epi32(sumi1_1, sumi2_1);
+        accum = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[ibl].d)*y[ibl].d),
+                _mm256_cvtepi32_ps(MM256_SET_M128I(sumi12_1, sumi12_0))), accum);
+    }
+
+    *s = hsum_float_8(accum);
+
+#elif defined(__POWER9_VECTOR__)
+    const vector signed char lowMask = vec_splats((signed char)0xF);
+    const vector int v0 = vec_splats((int32_t)0);
+    const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+
+    vector float vsumf0 = vec_splats(0.0f);
+    vector float vsumf1 = vec_splats(0.0f);
+    vector float vsumf2 = vec_splats(0.0f);
+    vector float vsumf3 = vec_splats(0.0f);
+
+    const vector signed char values = vec_xl( 0, kvalues_iq4nl);
+
+    for (int ibl = 0; ibl < nb; ++ibl) {
+
+        vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ibl].d));
+        vector float vyd = vec_splats(y[ibl].d);
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector signed int vsumi0 = v0;
+        vector signed int vsumi1 = v0;
+        vector signed int vsumi2 = v0;
+        vector signed int vsumi3 = v0;
+
+        uint16_t h = x[ibl].scales_h;
+
+        const uint8_t * restrict q4 = x[ibl].qs;
+        const uint8_t * restrict sc = x[ibl].scales_l;
+        const int8_t  * restrict q8 = y[ibl].qs;
+
+        for (int ib = 0; ib < QK_K/64; ++ib) {
+            __builtin_prefetch(q4, 0, 1);
+            __builtin_prefetch(q8, 0, 1);
+
+            vector signed char qxs0 = (vector signed char)vec_xl( 0, q4);
+            vector signed char qxs1 = (vector signed char)vec_xl(16, q4);
+            q4 += 32;
+
+            vector signed char q4x00 = (vector signed char)vec_and(qxs0, lowMask);
+            vector signed char q4x01 = (vector signed char)vec_sr(qxs0, v4);
+            vector signed char q4x10 = (vector signed char)vec_and(qxs1, lowMask);
+            vector signed char q4x11 = (vector signed char)vec_sr(qxs1, v4);
+
+            q4x00 = vec_perm(values, values, (vector unsigned char)q4x00);
+            q4x01 = vec_perm(values, values, (vector unsigned char)q4x01);
+            q4x10 = vec_perm(values, values, (vector unsigned char)q4x10);
+            q4x11 = vec_perm(values, values, (vector unsigned char)q4x11);
+
+            vector signed char q8y0 = vec_xl( 0, q8);
+            vector signed char q8y1 = vec_xl(16, q8);
+            vector signed char q8y2 = vec_xl(32, q8);
+            vector signed char q8y3 = vec_xl(48, q8);
+            q8 += 64;
+
+            vector signed short qv0 = vec_add(vec_mule(q4x00, q8y0), vec_mulo(q4x00, q8y0));
+            vector signed short qv1 = vec_add(vec_mule(q4x01, q8y1), vec_mulo(q4x01, q8y1));
+            vector signed short qv2 = vec_add(vec_mule(q4x10, q8y2), vec_mulo(q4x10, q8y2));
+            vector signed short qv3 = vec_add(vec_mule(q4x11, q8y3), vec_mulo(q4x11, q8y3));
+
+            const uint16_t ls0 = (uint16_t)(((sc[0] & 0xf) | ((h << 4) & 0x30)) - 32);
+            const uint16_t ls1 = (uint16_t)(((sc[0] >>  4) | ((h << 2) & 0x30)) - 32);
+            h >>= 4;
+            sc++;
+
+            vector signed short vscales01 = vec_splats((int16_t)ls0);
+            vector signed short vscales23 = vec_splats((int16_t)ls1);
+
+            vsumi0 = vec_msum(qv0, vscales01, vsumi0);
+            vsumi1 = vec_msum(qv1, vscales01, vsumi1);
+            vsumi2 = vec_msum(qv2, vscales23, vsumi2);
+            vsumi3 = vec_msum(qv3, vscales23, vsumi3);
+        }
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+    }
+
+    vsumf0 = vec_add(vsumf0, vsumf2);
+    vsumf1 = vec_add(vsumf1, vsumf3);
+
+    vsumf0 = vec_add(vsumf0, vsumf1);
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    *s = vec_extract(vsumf0, 0);
+
+#elif defined(__loongarch_asx)
+
+    const __m128i values128 = __lsx_vld((const __m128i*)kvalues_iq4nl, 0);
+    const __m128i m4b  = __lsx_vreplgr2vr_b(0x0f);
+
+    __m256 accum = (__m256)__lasx_xvldi(0);
+    __m256i tmp1;
+    __m128i tmp0, tmp2, tmp3, tmp4, mask_8f, mask;
+
+    mask_8f = __lsx_vreplgr2vr_b(0x8f);
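+    // open-coded variant of the lsx_shuffle_b helper used above: emulates SSE pshufb on top of
+    // __lsx_vshuf_b, zeroing a lane whenever bit 7 of its index byte is set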
+    for (int ibl = 0; ibl < nb; ++ibl) {
+        const uint8_t * qs = x[ibl].qs;
+        const int8_t  * q8 = y[ibl].qs;
+        uint16_t sh = x[ibl].scales_h;
+        __m256i sumi1 = __lasx_xvldi(0);
+        __m256i sumi2 = __lasx_xvldi(0);
+        __m128i zero = __lsx_vldi(0);
+        for (int ib = 0; ib < QK_K/32; ib += 2) {
+            const __m128i q4bits_1 = __lsx_vld((const __m128i*)qs, 0);  qs += 16;
+            const __m128i q4bits_2 = __lsx_vld((const __m128i*)qs, 0);  qs += 16;
+            const __m256i q8b_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+            const __m256i q8b_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+            tmp2 = __lsx_vand_v(__lsx_vand_v(__lsx_vsrli_h(q4bits_1, 4), m4b), mask_8f);
+            tmp0 = __lsx_vori_b(tmp2, 0x10);
+            mask = __lsx_vsle_b(zero, tmp2);
+            tmp3 = __lsx_vand_v(tmp0, mask);
+            tmp3 = __lsx_vshuf_b(values128, zero, tmp3);
+
+            tmp2 = __lsx_vand_v(__lsx_vand_v(q4bits_1, m4b), mask_8f);
+            tmp0 = __lsx_vori_b(tmp2, 0x10);
+            mask = __lsx_vsle_b(zero, tmp2);
+            tmp4 = __lsx_vand_v(tmp0, mask);
+            tmp4 = __lsx_vshuf_b(values128, zero, tmp4);
+
+            const __m256i q4b_1 = lasx_insertf128(tmp3, tmp4);
+
+            tmp2 = __lsx_vand_v(__lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b), mask_8f);
+            tmp0 = __lsx_vori_b(tmp2, 0x10);
+            mask = __lsx_vsle_b(zero, tmp2);
+            tmp3 = __lsx_vand_v(tmp0, mask);
+            tmp3 = __lsx_vshuf_b(values128, zero, tmp3);
+
+            tmp2 = __lsx_vand_v(__lsx_vand_v(q4bits_2, m4b), mask_8f);
+            tmp0 = __lsx_vori_b(tmp2, 0x10);
+            mask = __lsx_vsle_b(zero, tmp2);
+            tmp4 = __lsx_vand_v(tmp0, mask);
+            tmp4 = __lsx_vshuf_b(values128, zero, tmp4);
+
+            const __m256i q4b_2 = lasx_insertf128(tmp3, tmp4);
+
+            const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
+            const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
+            const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32;
+            const int16_t ls2 = ((x[ibl].scales_l[ib/2] >>  4) | ((sh << 2) & 0x30)) - 32;
+            sh >>= 4;
+            __m256i tmp5, tmp6;
+            tmp1 = __lasx_xvreplgr2vr_h(ls1);
+            tmp5 = __lasx_xvmulwev_w_h(p16_1, tmp1);
+            tmp6 = __lasx_xvmulwod_w_h(p16_1, tmp1);
+            const __m256i p_1 = __lasx_xvadd_w(tmp5, tmp6);
+            tmp1 = __lasx_xvreplgr2vr_h(ls2);
+            tmp5 = __lasx_xvmulwev_w_h(p16_2, tmp1);
+            tmp6 = __lasx_xvmulwod_w_h(p16_2, tmp1);
+            const __m256i p_2 = __lasx_xvadd_w(tmp5, tmp6);
+            sumi1 = __lasx_xvadd_w(p_1, sumi1);
+            sumi2 = __lasx_xvadd_w(p_2, sumi2);
+        }
+        accum = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[ibl].d)*y[ibl].d),
+                __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accum);
+    }
+
+    *s = hsum_float_8(accum);
+
+#else
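+    // scalar reference: 6-bit block scales rebuilt from scales_l nibbles plus two bits of scales_h, offset by -32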
+    float sumf = 0;
+    for (int ibl = 0; ibl < nb; ++ibl) {
+        const float d4d8 = GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d;
+        uint16_t h = x[ibl].scales_h;
+        const uint8_t * qs = x[ibl].qs;
+        const int8_t  * q8 = y[ibl].qs;
+        for (int ib = 0; ib < QK_K/32; ib += 2) {
+            const uint8_t ls1 = (x[ibl].scales_l[ib/2] & 0xf) | ((h << 4) & 0x30);
+            const uint8_t ls2 = (x[ibl].scales_l[ib/2] >>  4) | ((h << 2) & 0x30);
+            h >>= 4;
+            const float d1 = d4d8*(ls1 - 32);
+            const float d2 = d4d8*(ls2 - 32);
+            int sumi1 = 0, sumi2 = 0;
+            for (int j = 0; j < 16; ++j) {
+                sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf];
+                sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >>  4];
+            }
+            sumf += d1 * (sumi1 + sumi2);
+            qs += 16;
+            q8 += 32;
+            sumi1 = sumi2 = 0;
+            for (int j = 0; j < 16; ++j) {
+                sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf];
+                sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >>  4];
+            }
+            sumf += d2 * (sumi1 + sumi2);
+            qs += 16;
+            q8 += 32;
+        }
+    }
+    *s = sumf;
+#endif
+}
+
+// ============================ 4-bit non-linear quants
+
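+// thin wrappers over the reference quantizers; k is the row length and must be a multiple of the block size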
+void quantize_row_iq4_nl(const float * restrict x, void * restrict y, int64_t k) {
+    assert(k % QK4_NL == 0);
+    quantize_row_iq4_nl_ref(x, y, k);
+}
+
+void quantize_row_iq4_xs(const float * restrict x, void * restrict y, int64_t k) {
+    assert(k % QK_K == 0);
+    quantize_iq4_xs(x, y, 1, k, NULL);
+}

+ 89 - 0
llama/ggml-cpu-quants.h

@@ -0,0 +1,89 @@
+/**
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+
+#include "ggml.h"
+
+// GGML CPU internal header
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// Quantization
+void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q5_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q5_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+
+void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q5_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+
+void quantize_row_tq1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_tq2_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+
+void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+
+// Dot product
+void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
+void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
+void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
+void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq2_s_q8_K  (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq1_s_q8_K  (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq1_m_q8_K  (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq3_s_q8_K  (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
+#ifdef __cplusplus
+}
+#endif

+ 14054 - 0
llama/ggml-cpu.c

@@ -0,0 +1,14054 @@
+/**
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#define _CRT_SECURE_NO_DEPRECATE // Disables "unsafe" warnings on Windows
+#define _USE_MATH_DEFINES // For M_PI on MSVC
+
+#include "ggml-backend-impl.h"
+#include "ggml-backend.h"
+#include "ggml-cpu-aarch64.h"
+#include "ggml-cpu-impl.h"
+#include "ggml-cpu.h"
+#include "ggml-impl.h"
+#include "ggml-quants.h"
+#include "ggml-cpu-quants.h"
+#include "ggml-threading.h"
+#include "amx.h"
+#include "ggml.h"
+
+#if defined(_MSC_VER) || defined(__MINGW32__)
+#include <malloc.h> // using malloc.h with MSC/MINGW
+#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
+#include <alloca.h>
+#endif
+
+#include <assert.h>
+#include <errno.h>
+#include <time.h>
+#include <math.h>
+#include <stdlib.h>
+#include <string.h>
+#include <stdint.h>
+#include <inttypes.h>
+#include <stdio.h>
+#include <float.h>
+#include <limits.h>
+#include <stdarg.h>
+#include <signal.h>
+#if defined(__gnu_linux__)
+#include <syscall.h>
+#endif
+
+#ifdef GGML_USE_OPENMP
+#include <omp.h>
+#endif
+
+#if defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_MATMUL_INT8)
+#undef GGML_USE_LLAMAFILE
+#endif
+
+#ifdef GGML_USE_LLAMAFILE
+#include "llamafile/sgemm.h"
+#endif
+
+#if defined(_MSC_VER)
+// disable "possible loss of data" to avoid hundreds of casts
+// we should just be careful :)
+#pragma warning(disable: 4244 4267)
+
+// disable POSIX deprecation warnings
+// these functions are never going away, anyway
+#pragma warning(disable: 4996)
+
+// unreachable code because of multiple instances of code after GGML_ABORT
+#pragma warning(disable: 4702)
+#endif
+
+// Note: once threading moves into a separate C++ file, this will use
+// std::hardware_destructive_interference_size instead of a hardcoded value,
+// and C++ attribute syntax for the alignment.
+#define GGML_CACHE_LINE  64
+
+#if defined(__clang__) || defined(__GNUC__)
+#define GGML_CACHE_ALIGN __attribute__((aligned(GGML_CACHE_LINE)))
+#endif
+
+#if defined(__has_feature)
+#if __has_feature(thread_sanitizer)
+#define GGML_TSAN_ENABLED 1
+#endif
+#else  // __has_feature
+#if defined(__SANITIZE_THREAD__)
+#define GGML_TSAN_ENABLED 1
+#endif
+#endif // __has_feature
+
+#define UNUSED GGML_UNUSED
+#define SWAP(x, y, T) do { T SWAP = x; (x) = y; (y) = SWAP; } while (0)
+
+#if defined(GGML_USE_ACCELERATE)
+#include <Accelerate/Accelerate.h>
+#endif
+
+// floating point type used to accumulate sums
+typedef double ggml_float;
+
+#define GGML_GELU_FP16
+#define GGML_GELU_QUICK_FP16
+
+#define GGML_SOFT_MAX_UNROLL 4
+#define GGML_VEC_DOT_UNROLL  2
+#define GGML_VEC_MAD_UNROLL  32
+
+//
+// global data
+//
+
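+// each table below is indexed by the raw 16-bit fp16 bit pattern, hence 1 << 16 entries of 2 bytes = 128 KB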
+// precomputed gelu table for f16 (128 KB)
+static ggml_fp16_t ggml_table_gelu_f16[1 << 16];
+
+// precomputed quick gelu table for f16 (128 KB)
+static ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16];
+
+#if defined(__ARM_ARCH)
+struct ggml_arm_arch_features_type {
+    int has_neon;
+    int has_dotprod;
+    int has_i8mm;
+    int has_sve;
+    int sve_cnt;
+} ggml_arm_arch_features = {-1, -1, -1, -1, 0};
+#endif
+
+#if defined(_WIN32)
+
+#define WIN32_LEAN_AND_MEAN
+#ifndef NOMINMAX
+    #define NOMINMAX
+#endif
+#include <windows.h>
+
+#if !defined(__clang__)
+#define GGML_CACHE_ALIGN __declspec(align(GGML_CACHE_LINE))
+
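+// without clang there is no <stdatomic.h> here: emulate the subset used below with
+// Win32 Interlocked ops, which act as full barriers regardless of the requested order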
+typedef volatile LONG atomic_int;
+typedef atomic_int atomic_bool;
+typedef atomic_int atomic_flag;
+
+#define ATOMIC_FLAG_INIT 0
+
+typedef enum {
+    memory_order_relaxed,
+    memory_order_consume,
+    memory_order_acquire,
+    memory_order_release,
+    memory_order_acq_rel,
+    memory_order_seq_cst
+} memory_order;
+
+static void atomic_store(atomic_int * ptr, LONG val) {
+    InterlockedExchange(ptr, val);
+}
+static void atomic_store_explicit(atomic_int * ptr, LONG val, memory_order mo) {
+    // TODO: add support for explicit memory order
+    InterlockedExchange(ptr, val);
+}
+static LONG atomic_load(atomic_int * ptr) {
+    return InterlockedCompareExchange(ptr, 0, 0);
+}
+static LONG atomic_load_explicit(atomic_int * ptr, memory_order mo) {
+    // TODO: add support for explicit memory order
+    return InterlockedCompareExchange(ptr, 0, 0);
+}
+static LONG atomic_fetch_add(atomic_int * ptr, LONG inc) {
+    return InterlockedExchangeAdd(ptr, inc);
+}
+static LONG atomic_fetch_add_explicit(atomic_int * ptr, LONG inc, memory_order mo) {
+    // TODO: add support for explicit memory order
+    return InterlockedExchangeAdd(ptr, inc);
+}
+static atomic_bool atomic_flag_test_and_set(atomic_flag * ptr) {
+    return InterlockedExchange(ptr, 1);
+}
+static void atomic_flag_clear(atomic_flag * ptr) {
+    InterlockedExchange(ptr, 0);
+}
+static void atomic_thread_fence(memory_order mo) {
+    MemoryBarrier();
+}
+#else // clang
+#include <stdatomic.h>
+#endif
+
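+// minimal pthread shim over Win32 threads (create/join/yield only)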
+typedef HANDLE pthread_t;
+
+typedef DWORD thread_ret_t;
+static int pthread_create(pthread_t * out, void * unused, thread_ret_t(*func)(void *), void * arg) {
+    (void) unused;
+    HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL);
+    if (handle == NULL) {
+        return EAGAIN;
+    }
+
+    *out = handle;
+    return 0;
+}
+
+static int pthread_join(pthread_t thread, void * unused) {
+    (void) unused;
+    int ret = (int) WaitForSingleObject(thread, INFINITE);
+    CloseHandle(thread);
+    return ret;
+}
+
+static int sched_yield(void) {
+    Sleep(0);
+    return 0;
+}
+#else
+
+#include <pthread.h>
+#include <stdatomic.h>
+#include <sched.h>
+#if defined(__FreeBSD__)
+#include <pthread_np.h>
+#endif
+
+typedef void * thread_ret_t;
+
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <unistd.h>
+
+#endif
+
+typedef pthread_t ggml_thread_t;
+
+#ifdef GGML_USE_CPU_HBM
+#include <hbwmalloc.h>
+#endif
+
+#if defined(__APPLE__)
+#include <unistd.h>
+#include <mach/mach.h>
+#include <TargetConditionals.h>
+#endif
+
+//
+// cache line
+//
+
+#if defined(__cpp_lib_hardware_interference_size)
+#define CACHE_LINE_SIZE hardware_destructive_interference_size
+#else
+#if defined(__POWER9_VECTOR__)
+#define CACHE_LINE_SIZE 128
+#else
+#define CACHE_LINE_SIZE 64
+#endif
+#endif
+
+static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
+
+static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc);
+static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc);
+static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc);
+
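+// per-type CPU dispatch table: from_float quantizes one row of floats, vec_dot is the dot-product
+// kernel, vec_dot_type is the quantization expected of the second operand, and nrows is how many
+// rows a single vec_dot call handles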
+static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
+    [GGML_TYPE_F32] = {
+        .vec_dot                  = (ggml_vec_dot_t) ggml_vec_dot_f32,
+        .vec_dot_type             = GGML_TYPE_F32,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_F16] = {
+        .from_float               = (ggml_from_float_t) ggml_fp32_to_fp16_row,
+        .vec_dot                  = (ggml_vec_dot_t) ggml_vec_dot_f16,
+        .vec_dot_type             = GGML_TYPE_F16,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_Q4_0] = {
+        .from_float               = quantize_row_q4_0,
+        .vec_dot                  = ggml_vec_dot_q4_0_q8_0,
+        .vec_dot_type             = GGML_TYPE_Q8_0,
+#if defined (__ARM_FEATURE_MATMUL_INT8)
+        .nrows                    = 2,
+#else
+        .nrows                    = 1,
+#endif
+    },
+    [GGML_TYPE_Q4_1] = {
+        .from_float               = quantize_row_q4_1,
+        .vec_dot                  = ggml_vec_dot_q4_1_q8_1,
+        .vec_dot_type             = GGML_TYPE_Q8_1,
+#if defined (__ARM_FEATURE_MATMUL_INT8)
+        .nrows                    = 2,
+#else
+        .nrows                    = 1,
+#endif
+    },
+    [GGML_TYPE_Q5_0] = {
+        .from_float               = quantize_row_q5_0,
+        .vec_dot                  = ggml_vec_dot_q5_0_q8_0,
+        .vec_dot_type             = GGML_TYPE_Q8_0,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_Q5_1] = {
+        .from_float               = quantize_row_q5_1,
+        .vec_dot                  = ggml_vec_dot_q5_1_q8_1,
+        .vec_dot_type             = GGML_TYPE_Q8_1,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_Q8_0] = {
+        .from_float               = quantize_row_q8_0,
+        .from_float_to_mat        = quantize_mat_q8_0,
+        .vec_dot                  = ggml_vec_dot_q8_0_q8_0,
+        .vec_dot_type             = GGML_TYPE_Q8_0,
+#if defined (__ARM_FEATURE_MATMUL_INT8)
+        .nrows                    = 2,
+#else
+        .nrows                    = 1,
+#endif
+    },
+    [GGML_TYPE_Q8_1] = {
+        .from_float               = quantize_row_q8_1,
+        .vec_dot_type             = GGML_TYPE_Q8_1,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_Q2_K] = {
+        .from_float               = quantize_row_q2_K,
+        .vec_dot                  = ggml_vec_dot_q2_K_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_Q3_K] = {
+        .from_float               = quantize_row_q3_K,
+        .vec_dot                  = ggml_vec_dot_q3_K_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_Q4_K] = {
+        .from_float               = quantize_row_q4_K,
+        .vec_dot                  = ggml_vec_dot_q4_K_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_Q5_K] = {
+        .from_float               = quantize_row_q5_K,
+        .vec_dot                  = ggml_vec_dot_q5_K_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_Q6_K] = {
+        .from_float               = quantize_row_q6_K,
+        .vec_dot                  = ggml_vec_dot_q6_K_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_IQ2_XXS] = {
+        .from_float               = NULL,
+        .vec_dot                  = ggml_vec_dot_iq2_xxs_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_IQ2_XS] = {
+        .from_float               = NULL,
+        .vec_dot                  = ggml_vec_dot_iq2_xs_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_IQ3_XXS] = {
+        // NOTE: from_float for iq3 and iq2_s was removed because these quants require initialization in ggml_quantize_init
+        //.from_float               = quantize_row_iq3_xxs,
+        .vec_dot                  = ggml_vec_dot_iq3_xxs_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_IQ3_S] = {
+        //.from_float               = quantize_row_iq3_s,
+        .vec_dot                  = ggml_vec_dot_iq3_s_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_IQ2_S] = {
+        //.from_float               = quantize_row_iq2_s,
+        .vec_dot                  = ggml_vec_dot_iq2_s_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_IQ1_S] = {
+        .from_float               = NULL,
+        .vec_dot                  = ggml_vec_dot_iq1_s_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_IQ1_M] = {
+        .from_float               = NULL,
+        .vec_dot                  = ggml_vec_dot_iq1_m_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_IQ4_NL] = {
+        .from_float               = quantize_row_iq4_nl,
+        .vec_dot                  = ggml_vec_dot_iq4_nl_q8_0,
+        .vec_dot_type             = GGML_TYPE_Q8_0,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_IQ4_XS] = {
+        .from_float               = quantize_row_iq4_xs,
+        .vec_dot                  = ggml_vec_dot_iq4_xs_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_Q8_K] = {
+        .from_float               = quantize_row_q8_K,
+    },
+    [GGML_TYPE_BF16] = {
+        .from_float               = (ggml_from_float_t) ggml_fp32_to_bf16_row,
+        .vec_dot                  = (ggml_vec_dot_t) ggml_vec_dot_bf16,
+        .vec_dot_type             = GGML_TYPE_BF16,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_Q4_0_4_4] = {
+        .from_float               = NULL,
+        .vec_dot                  = NULL,
+        .vec_dot_type             = GGML_TYPE_Q8_0,
+        .nrows                    = 1,
+        .ncols                    = 4,
+        .gemv                     = ggml_gemv_q4_0_4x4_q8_0,
+        .gemm                     = ggml_gemm_q4_0_4x4_q8_0,
+    },
+    [GGML_TYPE_Q4_0_4_8] = {
+        .from_float               = NULL,
+        .vec_dot                  = NULL,
+        .vec_dot_type             = GGML_TYPE_Q8_0,
+        .nrows                    = 1,
+        .ncols                    = 4,
+        .gemv                     = ggml_gemv_q4_0_4x8_q8_0,
+        .gemm                     = ggml_gemm_q4_0_4x8_q8_0,
+    },
+    [GGML_TYPE_Q4_0_8_8] = {
+        .from_float               = NULL,
+        .vec_dot                  = NULL,
+        .vec_dot_type             = GGML_TYPE_Q8_0,
+        .nrows                    = 1,
+        .ncols                    = 8,
+        .gemv                     = ggml_gemv_q4_0_8x8_q8_0,
+        .gemm                     = ggml_gemm_q4_0_8x8_q8_0,
+    },
+    [GGML_TYPE_TQ1_0] = {
+        .from_float               = quantize_row_tq1_0,
+        .vec_dot                  = ggml_vec_dot_tq1_0_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_TQ2_0] = {
+        .from_float               = quantize_row_tq2_0,
+        .vec_dot                  = ggml_vec_dot_tq2_0_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_IQ4_NL_4_4] = {
+        .from_float               = NULL,
+        .vec_dot                  = NULL,
+        .vec_dot_type             = GGML_TYPE_Q8_0,
+        .nrows                    = 1,
+        .ncols                    = 4,
+        .gemv                     = ggml_gemv_iq4_nl_4x4_q8_0,
+        .gemm                     = ggml_gemm_iq4_nl_4x4_q8_0,
+    },
+};
+
+const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type) {
+    return &type_traits_cpu[type];
+}
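+
+// Illustrative sketch (not part of the vendored sources): how the traits
+// table drives a quantized dot product. The activations are first converted
+// to the companion type required by the kernel (vec_dot_type), then vec_dot
+// is invoked. This assumes the ggml_from_float_t signature
+// (const float *, void *, int64_t), that n is a multiple of the Q4_0/Q8_0
+// block size, and that q8_scratch is large enough for the quantized row:
+//
+//     static void example_dot_q4_0(int n, float * s, const void * q4_row,
+//                                  const float * f32_row, void * q8_scratch) {
+//         const struct ggml_type_traits_cpu * tt = ggml_get_type_traits_cpu(GGML_TYPE_Q4_0);
+//         // quantize the f32 activations to tt->vec_dot_type (Q8_0 here)
+//         ggml_get_type_traits_cpu(tt->vec_dot_type)->from_float(f32_row, q8_scratch, n);
+//         // single-row kernel: the strides bs/bx/by are unused when nrc == 1
+//         tt->vec_dot(n, s, 0, q4_row, 0, q8_scratch, 0, 1);
+//     }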
+
+//
+// simd mappings
+//
+
+// we define a common set of C macros which map to specific intrinsics based on the current architecture
+// we then implement the fundamental computation operations below using only these macros
+// adding support for new architectures requires defining the corresponding SIMD macros
+//
+// GGML_F32_STEP / GGML_F16_STEP
+//   number of elements to process in a single step
+//
+// GGML_F32_EPR / GGML_F16_EPR
+//   number of elements to fit in a single register
+//
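+// e.g. on AVX: GGML_F32_EPR = 8 (one __m256 holds 8 floats) and
+// GGML_F32_STEP = 32, so GGML_F32_ARR = 32/8 = 4 registers are processed
+// per iteration of the inner loops below
+//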
+
+#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA)
+
+#define GGML_SIMD
+
+// F32 NEON
+
+#define GGML_F32_STEP 16
+#define GGML_F32_EPR  4
+
+#define GGML_F32x4              float32x4_t
+#define GGML_F32x4_ZERO         vdupq_n_f32(0.0f)
+#define GGML_F32x4_SET1(x)      vdupq_n_f32(x)
+#define GGML_F32x4_LOAD         vld1q_f32
+#define GGML_F32x4_STORE        vst1q_f32
+#define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c)
+#define GGML_F32x4_ADD          vaddq_f32
+#define GGML_F32x4_MUL          vmulq_f32
+#define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
+#define GGML_F32x4_REDUCE(res, x)                  \
+{                                                  \
+    int offset = GGML_F32_ARR >> 1;                \
+    for (int i = 0; i < offset; ++i) {             \
+        (x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \
+    }                                              \
+    offset >>= 1;                                  \
+    for (int i = 0; i < offset; ++i) {             \
+        (x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \
+    }                                              \
+    offset >>= 1;                                  \
+    for (int i = 0; i < offset; ++i) {             \
+        (x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \
+    }                                              \
+    (res) = GGML_F32x4_REDUCE_ONE((x)[0]);         \
+}
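+// with GGML_F32_ARR == 4 here (16/4), the offsets are 2, 1, 0: the first
+// pass folds x[2..3] into x[0..1], the second folds x[1] into x[0], the
+// third is a no-op, and REDUCE_ONE sums the four lanes of x[0]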
+
+#define GGML_F32_VEC        GGML_F32x4
+#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
+
+// F16 NEON
+
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+    #define GGML_F16_STEP 32
+    #define GGML_F16_EPR  8
+
+    #define GGML_F16x8              float16x8_t
+    #define GGML_F16x8_ZERO         vdupq_n_f16(0.0f)
+    #define GGML_F16x8_SET1(x)      vdupq_n_f16(x)
+    #define GGML_F16x8_LOAD(x)      vld1q_f16((const ggml_fp16_internal_t *)(x))
+    #define GGML_F16x8_STORE        vst1q_f16
+    #define GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c)
+    #define GGML_F16x8_ADD          vaddq_f16
+    #define GGML_F16x8_MUL          vmulq_f16
+    #define GGML_F16x8_REDUCE(res, x)                               \
+    do {                                                            \
+        int offset = GGML_F16_ARR >> 1;                             \
+        for (int i = 0; i < offset; ++i) {                          \
+            (x)[i] = vaddq_f16((x)[i], (x)[offset+i]);              \
+        }                                                           \
+        offset >>= 1;                                               \
+        for (int i = 0; i < offset; ++i) {                          \
+            (x)[i] = vaddq_f16((x)[i], (x)[offset+i]);              \
+        }                                                           \
+        offset >>= 1;                                               \
+        for (int i = 0; i < offset; ++i) {                          \
+            (x)[i] = vaddq_f16((x)[i], (x)[offset+i]);              \
+        }                                                           \
+        const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 ((x)[0])); \
+        const float32x4_t t1 = vcvt_f32_f16(vget_high_f16((x)[0])); \
+        (res) = (ggml_float) vaddvq_f32(vaddq_f32(t0, t1));         \
+    } while (0)
+
+    #define GGML_F16_VEC                GGML_F16x8
+    #define GGML_F16_VEC_ZERO           GGML_F16x8_ZERO
+    #define GGML_F16_VEC_SET1           GGML_F16x8_SET1
+    #define GGML_F16_VEC_LOAD(p, i)     GGML_F16x8_LOAD(p)
+    #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((ggml_fp16_internal_t *)(p), (r)[i])
+    #define GGML_F16_VEC_FMA            GGML_F16x8_FMA
+    #define GGML_F16_VEC_ADD            GGML_F16x8_ADD
+    #define GGML_F16_VEC_MUL            GGML_F16x8_MUL
+    #define GGML_F16_VEC_REDUCE         GGML_F16x8_REDUCE
+#else
+    // if FP16 vector arithmetic is not supported, we use FP32 instead
+    // and take advantage of the vcvt_ functions to convert to/from FP16
+
+    #define GGML_F16_STEP 16
+    #define GGML_F16_EPR  4
+
+    #define GGML_F32Cx4              float32x4_t
+    #define GGML_F32Cx4_ZERO         vdupq_n_f32(0.0f)
+    #define GGML_F32Cx4_SET1(x)      vdupq_n_f32(x)
+    #define GGML_F32Cx4_LOAD(x)      vcvt_f32_f16(vld1_f16((const ggml_fp16_internal_t *)(x)))
+    #define GGML_F32Cx4_STORE(x, y)  vst1_f16(x, vcvt_f16_f32(y))
+    #define GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c)
+    #define GGML_F32Cx4_ADD          vaddq_f32
+    #define GGML_F32Cx4_MUL          vmulq_f32
+    #define GGML_F32Cx4_REDUCE       GGML_F32x4_REDUCE
+
+    #define GGML_F16_VEC                GGML_F32Cx4
+    #define GGML_F16_VEC_ZERO           GGML_F32Cx4_ZERO
+    #define GGML_F16_VEC_SET1           GGML_F32Cx4_SET1
+    #define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx4_LOAD(p)
+    #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((ggml_fp16_internal_t *)(p), r[i])
+    #define GGML_F16_VEC_FMA            GGML_F32Cx4_FMA
+    #define GGML_F16_VEC_ADD            GGML_F32Cx4_ADD
+    #define GGML_F16_VEC_MUL            GGML_F32Cx4_MUL
+    #define GGML_F16_VEC_REDUCE         GGML_F32Cx4_REDUCE
+#endif
+
+#elif defined(__AVX512F__)
+
+#define GGML_SIMD
+
+// F32 AVX512
+
+#define GGML_F32_STEP 64
+#define GGML_F32_EPR  16
+
+#define GGML_F32x16         __m512
+#define GGML_F32x16_ZERO    _mm512_setzero_ps()
+#define GGML_F32x16_SET1(x) _mm512_set1_ps(x)
+#define GGML_F32x16_LOAD    _mm512_loadu_ps
+#define GGML_F32x16_STORE   _mm512_storeu_ps
+// _mm512_fmadd_ps is defined in AVX512F so no guard is required
+#define GGML_F32x16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a)
+#define GGML_F32x16_ADD     _mm512_add_ps
+#define GGML_F32x16_MUL     _mm512_mul_ps
+#define GGML_F32x16_REDUCE(res, x)                                    \
+do {                                                                  \
+    int offset = GGML_F32_ARR >> 1;                                   \
+    for (int i = 0; i < offset; ++i) {                                \
+        x[i] = _mm512_add_ps(x[i], x[offset+i]);                      \
+    }                                                                 \
+    offset >>= 1;                                                     \
+    for (int i = 0; i < offset; ++i) {                                \
+        x[i] = _mm512_add_ps(x[i], x[offset+i]);                      \
+    }                                                                 \
+    offset >>= 1;                                                     \
+    for (int i = 0; i < offset; ++i) {                                \
+        x[i] = _mm512_add_ps(x[i], x[offset+i]);                      \
+    }                                                                 \
+    res = (ggml_float) _mm512_reduce_add_ps(x[0]);                    \
+} while (0)
+
+// TODO: is this optimal?
+
+#define GGML_F32_VEC        GGML_F32x16
+#define GGML_F32_VEC_ZERO   GGML_F32x16_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x16_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x16_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x16_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x16_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x16_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x16_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x16_REDUCE
+
+// F16 AVX512
+
+#define GGML_F16_STEP 64
+#define GGML_F16_EPR  16
+
+// AVX512 has an FP16 extension (AVX512_FP16), but it is not assumed to be available here, so FP32 is used instead
+
+#define GGML_F32Cx16             __m512
+#define GGML_F32Cx16_ZERO        _mm512_setzero_ps()
+#define GGML_F32Cx16_SET1(x)     _mm512_set1_ps(x)
+
+// unlike the _mm256_cvt intrinsics, which require F16C, _mm512_cvt is defined in AVX512F,
+// so no F16C guard is required
+#define GGML_F32Cx16_LOAD(x)     _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x)))
+#define GGML_F32Cx16_STORE(x, y) _mm256_storeu_si256((__m256i *)(x), _mm512_cvtps_ph(y, 0))
+
+#define GGML_F32Cx16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a)
+#define GGML_F32Cx16_ADD         _mm512_add_ps
+#define GGML_F32Cx16_MUL         _mm512_mul_ps
+#define GGML_F32Cx16_REDUCE(res, x)                               \
+do {                                                              \
+    int offset = GGML_F32_ARR >> 1;                               \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = _mm512_add_ps(x[i], x[offset+i]);                  \
+    }                                                             \
+    offset >>= 1;                                                 \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = _mm512_add_ps(x[i], x[offset+i]);                  \
+    }                                                             \
+    offset >>= 1;                                                 \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = _mm512_add_ps(x[i], x[offset+i]);                  \
+    }                                                             \
+    res = (ggml_float) _mm512_reduce_add_ps(x[0]);                \
+} while (0)
+
+#define GGML_F16_VEC                GGML_F32Cx16
+#define GGML_F16_VEC_ZERO           GGML_F32Cx16_ZERO
+#define GGML_F16_VEC_SET1           GGML_F32Cx16_SET1
+#define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx16_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx16_STORE(p, r[i])
+#define GGML_F16_VEC_FMA            GGML_F32Cx16_FMA
+#define GGML_F16_VEC_ADD            GGML_F32Cx16_ADD
+#define GGML_F16_VEC_MUL            GGML_F32Cx16_MUL
+
+#define GGML_F16_VEC_REDUCE         GGML_F32Cx16_REDUCE
+#elif defined(__AVX__)
+
+#define GGML_SIMD
+
+// F32 AVX
+
+#define GGML_F32_STEP 32
+#define GGML_F32_EPR  8
+
+#define GGML_F32x8         __m256
+#define GGML_F32x8_ZERO    _mm256_setzero_ps()
+#define GGML_F32x8_SET1(x) _mm256_set1_ps(x)
+#define GGML_F32x8_LOAD    _mm256_loadu_ps
+#define GGML_F32x8_STORE   _mm256_storeu_ps
+#if defined(__FMA__)
+    #define GGML_F32x8_FMA(a, b, c) _mm256_fmadd_ps(b, c, a)
+#else
+    #define GGML_F32x8_FMA(a, b, c) _mm256_add_ps(_mm256_mul_ps(b, c), a)
+#endif
+#define GGML_F32x8_ADD     _mm256_add_ps
+#define GGML_F32x8_MUL     _mm256_mul_ps
+#define GGML_F32x8_REDUCE(res, x)                                 \
+do {                                                              \
+    int offset = GGML_F32_ARR >> 1;                               \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = _mm256_add_ps(x[i], x[offset+i]);                  \
+    }                                                             \
+    offset >>= 1;                                                 \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = _mm256_add_ps(x[i], x[offset+i]);                  \
+    }                                                             \
+    offset >>= 1;                                                 \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = _mm256_add_ps(x[i], x[offset+i]);                  \
+    }                                                             \
+    const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]),    \
+                                 _mm256_extractf128_ps(x[0], 1)); \
+    const __m128 t1 = _mm_hadd_ps(t0, t0);                        \
+    res = (ggml_float) _mm_cvtss_f32(_mm_hadd_ps(t1, t1));        \
+} while (0)
+// TODO: is this optimal?
+
+#define GGML_F32_VEC        GGML_F32x8
+#define GGML_F32_VEC_ZERO   GGML_F32x8_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x8_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x8_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x8_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x8_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x8_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x8_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE
+
+// F16 AVX
+
+#define GGML_F16_STEP 32
+#define GGML_F16_EPR  8
+
+// F16 arithmetic is not supported by AVX, so we use F32 instead
+
+#define GGML_F32Cx8             __m256
+#define GGML_F32Cx8_ZERO        _mm256_setzero_ps()
+#define GGML_F32Cx8_SET1(x)     _mm256_set1_ps(x)
+
+#if defined(__F16C__)
+// the _mm256_cvt intrinsics require F16C
+#define GGML_F32Cx8_LOAD(x)     _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
+#define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0))
+#else
+static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
+    float tmp[8];
+
+    for (int i = 0; i < 8; i++) {
+        tmp[i] = GGML_FP16_TO_FP32(x[i]);
+    }
+
+    return _mm256_loadu_ps(tmp);
+}
+static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
+    float arr[8];
+
+    _mm256_storeu_ps(arr, y);
+
+    for (int i = 0; i < 8; i++)
+        x[i] = GGML_FP32_TO_FP16(arr[i]);
+}
+#define GGML_F32Cx8_LOAD(x)     __avx_f32cx8_load(x)
+#define GGML_F32Cx8_STORE(x, y) __avx_f32cx8_store(x, y)
+#endif
+
+#define GGML_F32Cx8_FMA         GGML_F32x8_FMA
+#define GGML_F32Cx8_ADD         _mm256_add_ps
+#define GGML_F32Cx8_MUL         _mm256_mul_ps
+#define GGML_F32Cx8_REDUCE      GGML_F32x8_REDUCE
+
+#define GGML_F16_VEC                GGML_F32Cx8
+#define GGML_F16_VEC_ZERO           GGML_F32Cx8_ZERO
+#define GGML_F16_VEC_SET1           GGML_F32Cx8_SET1
+#define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx8_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i])
+#define GGML_F16_VEC_FMA            GGML_F32Cx8_FMA
+#define GGML_F16_VEC_ADD            GGML_F32Cx8_ADD
+#define GGML_F16_VEC_MUL            GGML_F32Cx8_MUL
+#define GGML_F16_VEC_REDUCE         GGML_F32Cx8_REDUCE
+
+#elif defined(__POWER9_VECTOR__)
+
+#define GGML_SIMD
+
+// F32 POWER9
+
+#define GGML_F32_STEP 32
+#define GGML_F32_EPR  4
+
+#define GGML_F32x4              vector float
+#define GGML_F32x4_ZERO         0.0f
+#define GGML_F32x4_SET1         vec_splats
+#define GGML_F32x4_LOAD(p)      vec_xl(0, p)
+#define GGML_F32x4_STORE(p, r)  vec_xst(r, 0, p)
+#define GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a)
+#define GGML_F32x4_ADD          vec_add
+#define GGML_F32x4_MUL          vec_mul
+#define GGML_F32x4_REDUCE(res, x)              \
+{                                              \
+    int offset = GGML_F32_ARR >> 1;            \
+    for (int i = 0; i < offset; ++i) {         \
+        x[i] = vec_add(x[i], x[offset+i]);     \
+    }                                          \
+    offset >>= 1;                              \
+    for (int i = 0; i < offset; ++i) {         \
+        x[i] = vec_add(x[i], x[offset+i]);     \
+    }                                          \
+    offset >>= 1;                              \
+    for (int i = 0; i < offset; ++i) {         \
+        x[i] = vec_add(x[i], x[offset+i]);     \
+    }                                          \
+    res = vec_extract(x[0], 0) +               \
+          vec_extract(x[0], 1) +               \
+          vec_extract(x[0], 2) +               \
+          vec_extract(x[0], 3);                \
+}
+
+#define GGML_F32_VEC        GGML_F32x4
+#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
+
+// F16 POWER9
+#define GGML_F16_STEP       GGML_F32_STEP
+#define GGML_F16_EPR        GGML_F32_EPR
+#define GGML_F16_VEC        GGML_F32x4
+#define GGML_F16_VEC_ZERO   GGML_F32x4_ZERO
+#define GGML_F16_VEC_SET1   GGML_F32x4_SET1
+#define GGML_F16_VEC_FMA    GGML_F32x4_FMA
+#define GGML_F16_VEC_ADD    GGML_F32x4_ADD
+#define GGML_F16_VEC_MUL    GGML_F32x4_MUL
+#define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE
+// Use vec_xl, not vec_ld, in case the load address is not aligned.
+#define GGML_F16_VEC_LOAD(p, i) (i & 0x1) ?                   \
+  vec_extract_fp32_from_shorth(vec_xl(0, p - GGML_F16_EPR)) : \
+  vec_extract_fp32_from_shortl(vec_xl(0, p))
+#define GGML_ENDIAN_BYTE(i) ((unsigned char *)&(uint16_t){1})[i]
+#define GGML_F16_VEC_STORE(p, r, i)                             \
+  if (i & 0x1)                                                  \
+    vec_xst(vec_pack_to_short_fp32(r[i - GGML_ENDIAN_BYTE(1)],  \
+                                   r[i - GGML_ENDIAN_BYTE(0)]), \
+            0, p - GGML_F16_EPR)
+
+#elif defined(__wasm_simd128__)
+
+#define GGML_SIMD
+
+// F32 WASM
+
+#define GGML_F32_STEP 16
+#define GGML_F32_EPR  4
+
+#define GGML_F32x4              v128_t
+#define GGML_F32x4_ZERO         wasm_f32x4_splat(0.0f)
+#define GGML_F32x4_SET1(x)      wasm_f32x4_splat(x)
+#define GGML_F32x4_LOAD         wasm_v128_load
+#define GGML_F32x4_STORE        wasm_v128_store
+#define GGML_F32x4_FMA(a, b, c) wasm_f32x4_add(wasm_f32x4_mul(b, c), a)
+#define GGML_F32x4_ADD          wasm_f32x4_add
+#define GGML_F32x4_MUL          wasm_f32x4_mul
+#define GGML_F32x4_REDUCE(res, x)                  \
+{                                                  \
+    int offset = GGML_F32_ARR >> 1;                \
+    for (int i = 0; i < offset; ++i) {             \
+        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
+    }                                              \
+    offset >>= 1;                                  \
+    for (int i = 0; i < offset; ++i) {             \
+        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
+    }                                              \
+    offset >>= 1;                                  \
+    for (int i = 0; i < offset; ++i) {             \
+        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
+    }                                              \
+    res = wasm_f32x4_extract_lane(x[0], 0) +       \
+          wasm_f32x4_extract_lane(x[0], 1) +       \
+          wasm_f32x4_extract_lane(x[0], 2) +       \
+          wasm_f32x4_extract_lane(x[0], 3);        \
+}
+
+#define GGML_F32_VEC        GGML_F32x4
+#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
+
+// F16 WASM
+
+#define GGML_F16_STEP 16
+#define GGML_F16_EPR  4
+
+inline static v128_t __wasm_f16x4_load(const ggml_fp16_t * p) {
+    float tmp[4];
+
+    tmp[0] = GGML_FP16_TO_FP32(p[0]);
+    tmp[1] = GGML_FP16_TO_FP32(p[1]);
+    tmp[2] = GGML_FP16_TO_FP32(p[2]);
+    tmp[3] = GGML_FP16_TO_FP32(p[3]);
+
+    return wasm_v128_load(tmp);
+}
+
+inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) {
+    float tmp[4];
+
+    wasm_v128_store(tmp, x);
+
+    p[0] = GGML_FP32_TO_FP16(tmp[0]);
+    p[1] = GGML_FP32_TO_FP16(tmp[1]);
+    p[2] = GGML_FP32_TO_FP16(tmp[2]);
+    p[3] = GGML_FP32_TO_FP16(tmp[3]);
+}
+
+#define GGML_F16x4             v128_t
+#define GGML_F16x4_ZERO        wasm_f32x4_splat(0.0f)
+#define GGML_F16x4_SET1(x)     wasm_f32x4_splat(x)
+#define GGML_F16x4_LOAD(x)     __wasm_f16x4_load(x)
+#define GGML_F16x4_STORE(x, y) __wasm_f16x4_store(x, y)
+#define GGML_F16x4_FMA         GGML_F32x4_FMA
+#define GGML_F16x4_ADD         wasm_f32x4_add
+#define GGML_F16x4_MUL         wasm_f32x4_mul
+#define GGML_F16x4_REDUCE(res, x)                  \
+{                                                  \
+    int offset = GGML_F16_ARR >> 1;                \
+    for (int i = 0; i < offset; ++i) {             \
+        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
+    }                                              \
+    offset >>= 1;                                  \
+    for (int i = 0; i < offset; ++i) {             \
+        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
+    }                                              \
+    offset >>= 1;                                  \
+    for (int i = 0; i < offset; ++i) {             \
+        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
+    }                                              \
+    res = wasm_f32x4_extract_lane(x[0], 0) +       \
+          wasm_f32x4_extract_lane(x[0], 1) +       \
+          wasm_f32x4_extract_lane(x[0], 2) +       \
+          wasm_f32x4_extract_lane(x[0], 3);        \
+}
+
+#define GGML_F16_VEC                GGML_F16x4
+#define GGML_F16_VEC_ZERO           GGML_F16x4_ZERO
+#define GGML_F16_VEC_SET1           GGML_F16x4_SET1
+#define GGML_F16_VEC_LOAD(p, i)     GGML_F16x4_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i) GGML_F16x4_STORE(p, r[i])
+#define GGML_F16_VEC_FMA            GGML_F16x4_FMA
+#define GGML_F16_VEC_ADD            GGML_F16x4_ADD
+#define GGML_F16_VEC_MUL            GGML_F16x4_MUL
+#define GGML_F16_VEC_REDUCE         GGML_F16x4_REDUCE
+
+#elif defined(__SSE3__)
+
+#define GGML_SIMD
+
+// F32 SSE
+
+#define GGML_F32_STEP 32
+#define GGML_F32_EPR  4
+
+#define GGML_F32x4         __m128
+#define GGML_F32x4_ZERO    _mm_setzero_ps()
+#define GGML_F32x4_SET1(x) _mm_set1_ps(x)
+#define GGML_F32x4_LOAD    _mm_loadu_ps
+#define GGML_F32x4_STORE   _mm_storeu_ps
+#if defined(__FMA__)
+    // TODO: Does this work?
+    #define GGML_F32x4_FMA(a, b, c) _mm_fmadd_ps(b, c, a)
+#else
+    #define GGML_F32x4_FMA(a, b, c) _mm_add_ps(_mm_mul_ps(b, c), a)
+#endif
+#define GGML_F32x4_ADD     _mm_add_ps
+#define GGML_F32x4_MUL     _mm_mul_ps
+#define GGML_F32x4_REDUCE(res, x)                                 \
+{                                                                 \
+    int offset = GGML_F32_ARR >> 1;                               \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = _mm_add_ps(x[i], x[offset+i]);                     \
+    }                                                             \
+    offset >>= 1;                                                 \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = _mm_add_ps(x[i], x[offset+i]);                     \
+    }                                                             \
+    offset >>= 1;                                                 \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = _mm_add_ps(x[i], x[offset+i]);                     \
+    }                                                             \
+    const __m128 t0 = _mm_hadd_ps(x[0], x[0]);                    \
+    res = (ggml_float) _mm_cvtss_f32(_mm_hadd_ps(t0, t0));        \
+}
+// TODO: is this optimal?
+
+#define GGML_F32_VEC        GGML_F32x4
+#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
+
+// F16 SSE
+
+#define GGML_F16_STEP 32
+#define GGML_F16_EPR  4
+
+static inline __m128 __sse_f16x4_load(ggml_fp16_t *x) {
+    float tmp[4];
+
+    tmp[0] = GGML_FP16_TO_FP32(x[0]);
+    tmp[1] = GGML_FP16_TO_FP32(x[1]);
+    tmp[2] = GGML_FP16_TO_FP32(x[2]);
+    tmp[3] = GGML_FP16_TO_FP32(x[3]);
+
+    return _mm_loadu_ps(tmp);
+}
+
+static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) {
+    float arr[4];
+
+    _mm_storeu_ps(arr, y);
+
+    x[0] = GGML_FP32_TO_FP16(arr[0]);
+    x[1] = GGML_FP32_TO_FP16(arr[1]);
+    x[2] = GGML_FP32_TO_FP16(arr[2]);
+    x[3] = GGML_FP32_TO_FP16(arr[3]);
+}
+
+#define GGML_F32Cx4             __m128
+#define GGML_F32Cx4_ZERO        _mm_setzero_ps()
+#define GGML_F32Cx4_SET1(x)     _mm_set1_ps(x)
+#define GGML_F32Cx4_LOAD(x)     __sse_f16x4_load(x)
+#define GGML_F32Cx4_STORE(x, y) __sse_f16x4_store(x, y)
+#define GGML_F32Cx4_FMA         GGML_F32x4_FMA
+#define GGML_F32Cx4_ADD         _mm_add_ps
+#define GGML_F32Cx4_MUL         _mm_mul_ps
+#define GGML_F32Cx4_REDUCE      GGML_F32x4_REDUCE
+
+#define GGML_F16_VEC                 GGML_F32Cx4
+#define GGML_F16_VEC_ZERO            GGML_F32Cx4_ZERO
+#define GGML_F16_VEC_SET1            GGML_F32Cx4_SET1
+#define GGML_F16_VEC_LOAD(p, i)      GGML_F32Cx4_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i)  GGML_F32Cx4_STORE(p, r[i])
+#define GGML_F16_VEC_FMA             GGML_F32Cx4_FMA
+#define GGML_F16_VEC_ADD             GGML_F32Cx4_ADD
+#define GGML_F16_VEC_MUL             GGML_F32Cx4_MUL
+#define GGML_F16_VEC_REDUCE          GGML_F32Cx4_REDUCE
+
+#elif defined(__loongarch_asx)
+
+#define GGML_SIMD
+
+// F32 LASX
+#define GGML_F32_STEP 32
+#define GGML_F32_EPR  8
+
+#define GGML_F32x8         __m256
+#define GGML_F32x8_ZERO    (__m256)__lasx_xvldi(0)
+#define GGML_F32x8_SET1(x) (__m256)__lasx_xvreplfr2vr_s((x))
+#define GGML_F32x8_LOAD(x) (__m256)__lasx_xvld((x), 0)
+#define GGML_F32x8_STORE(x,y)   __lasx_xvst((y), (x), 0)
+#define GGML_F32x8_FMA(a, b, c) __lasx_xvfmadd_s(b, c, a)
+#define GGML_F32x8_ADD     __lasx_xvfadd_s
+#define GGML_F32x8_MUL     __lasx_xvfmul_s
+#define GGML_F32x8_REDUCE(res, x)                                 \
+do {                                                              \
+    int offset = GGML_F32_ARR >> 1;                               \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = __lasx_xvfadd_s(x[i], x[offset+i]);                  \
+    }                                                             \
+    offset >>= 1;                                                 \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = __lasx_xvfadd_s(x[i], x[offset+i]);                  \
+    }                                                             \
+    offset >>= 1;                                                 \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = __lasx_xvfadd_s(x[i], x[offset+i]);                  \
+    }                                                             \
+    float *tmp_p = (float *)&x[0]; \
+    res = tmp_p[0] + tmp_p[1] + tmp_p[2] + tmp_p[3] + tmp_p[4] + tmp_p[5] + tmp_p[6] + tmp_p[7];  \
+} while (0)
+// TODO: is this optimal?
+
+#define GGML_F32_VEC        GGML_F32x8
+#define GGML_F32_VEC_ZERO   GGML_F32x8_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x8_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x8_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x8_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x8_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x8_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x8_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE
+
+// F16 LASX
+
+#define GGML_F16_STEP 32
+#define GGML_F16_EPR  8
+
+// F16 arithmetic is not supported by LASX, so we use F32 instead
+
+#define GGML_F32Cx8          __m256
+#define GGML_F32Cx8_ZERO    (__m256)__lasx_xvldi(0)
+#define GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x))
+
+static inline __m256 __lasx_f32cx8_load(const ggml_fp16_t * x) {
+    float tmp[8];
+
+    for (int i = 0; i < 8; i++) {
+        tmp[i] = GGML_FP16_TO_FP32(x[i]);
+    }
+
+    return (__m256)__lasx_xvld(tmp, 0);
+}
+static inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) {
+    float arr[8];
+
+    __lasx_xvst(y, arr, 0);
+
+    for (int i = 0; i < 8; i++) {
+        x[i] = GGML_FP32_TO_FP16(arr[i]);
+    }
+}
+#define GGML_F32Cx8_LOAD(x)     __lasx_f32cx8_load(x)
+#define GGML_F32Cx8_STORE(x, y) __lasx_f32cx8_store(x, y)
+
+#define GGML_F32Cx8_FMA         GGML_F32x8_FMA
+#define GGML_F32Cx8_ADD         __lasx_xvfadd_s
+#define GGML_F32Cx8_MUL         __lasx_xvfmul_s
+#define GGML_F32Cx8_REDUCE      GGML_F32x8_REDUCE
+
+#define GGML_F16_VEC                GGML_F32Cx8
+#define GGML_F16_VEC_ZERO           GGML_F32Cx8_ZERO
+#define GGML_F16_VEC_SET1           GGML_F32Cx8_SET1
+#define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx8_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i])
+#define GGML_F16_VEC_FMA            GGML_F32Cx8_FMA
+#define GGML_F16_VEC_ADD            GGML_F32Cx8_ADD
+#define GGML_F16_VEC_MUL            GGML_F32Cx8_MUL
+#define GGML_F16_VEC_REDUCE         GGML_F32Cx8_REDUCE
+
+#elif defined(__loongarch_sx)
+
+#define GGML_SIMD
+
+// F32 LSX
+
+#define GGML_F32_STEP 32
+#define GGML_F32_EPR  4
+
+#define GGML_F32x4         __m128
+#define GGML_F32x4_ZERO    __lsx_vldi(0)
+#define GGML_F32x4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
+#define GGML_F32x4_LOAD(x) __lsx_vld((x), 0)
+#define GGML_F32x4_STORE(x, y)   __lsx_vst((y), (x), 0)
+#define GGML_F32x4_FMA(a, b, c) __lsx_vfmadd_s(b, c, a)
+#define GGML_F32x4_ADD     __lsx_vfadd_s
+#define GGML_F32x4_MUL     __lsx_vfmul_s
+#define GGML_F32x4_REDUCE(res, x)                                                     \
+{                                                                                     \
+    int offset = GGML_F32_ARR >> 1;                                                   \
+    for (int i = 0; i < offset; ++i) {                                                \
+        x[i] = __lsx_vfadd_s(x[i], x[offset + i]);                                    \
+    }                                                                                 \
+    offset >>= 1;                                                                     \
+    for (int i = 0; i < offset; ++i) {                                                \
+        x[i] = __lsx_vfadd_s(x[i], x[offset + i]);                                    \
+    }                                                                                 \
+    offset >>= 1;                                                                     \
+    for (int i = 0; i < offset; ++i) {                                                \
+        x[i] = __lsx_vfadd_s(x[i], x[offset + i]);                                    \
+    }                                                                                 \
+    __m128i tmp     = __lsx_vsrli_d((__m128i) x[0], 32);                              \
+    tmp             = (__m128i) __lsx_vfadd_s((__m128) tmp, x[0]);                    \
+    tmp             = __lsx_vpickev_w(__lsx_vldi(0), tmp);                            \
+    const __m128 t0 = __lsx_vshuf4i_w(tmp, 0x88);                                     \
+    tmp             = __lsx_vsrli_d((__m128i) t0, 32);                                \
+    tmp             = (__m128i) __lsx_vfadd_s((__m128) tmp, t0);                      \
+    tmp             = __lsx_vpickev_w(__lsx_vldi(0), tmp);                            \
+    res             = (ggml_float) __lsx_vpickve2gr_w(__lsx_vshuf4i_w(tmp, 0x88), 0); \
+}
+
+#define GGML_F32_VEC        GGML_F32x4
+#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
+
+// F16 LSX
+
+#define GGML_F16_STEP 32
+#define GGML_F16_EPR  4
+
+static inline __m128 __lsx_f16x4_load(const ggml_fp16_t * x) {
+    float tmp[4];
+
+    tmp[0] = GGML_FP16_TO_FP32(x[0]);
+    tmp[1] = GGML_FP16_TO_FP32(x[1]);
+    tmp[2] = GGML_FP16_TO_FP32(x[2]);
+    tmp[3] = GGML_FP16_TO_FP32(x[3]);
+
+    return __lsx_vld(tmp, 0);
+}
+
+static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
+    float arr[4];
+
+    __lsx_vst(y, arr, 0);
+
+    x[0] = GGML_FP32_TO_FP16(arr[0]);
+    x[1] = GGML_FP32_TO_FP16(arr[1]);
+    x[2] = GGML_FP32_TO_FP16(arr[2]);
+    x[3] = GGML_FP32_TO_FP16(arr[3]);
+}
+
+#define GGML_F32Cx4             __m128
+#define GGML_F32Cx4_ZERO        __lsx_vldi(0)
+#define GGML_F32Cx4_SET1(x)     __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
+#define GGML_F32Cx4_LOAD(x)     __lsx_f16x4_load(x)
+#define GGML_F32Cx4_STORE(x, y) __lsx_f16x4_store(x, y)
+#define GGML_F32Cx4_FMA         GGML_F32x4_FMA
+#define GGML_F32Cx4_ADD         __lsx_vfadd_s
+#define GGML_F32Cx4_MUL         __lsx_vfmul_s
+#define GGML_F32Cx4_REDUCE      GGML_F32x4_REDUCE
+
+#define GGML_F16_VEC                 GGML_F32Cx4
+#define GGML_F16_VEC_ZERO            GGML_F32Cx4_ZERO
+#define GGML_F16_VEC_SET1            GGML_F32Cx4_SET1
+#define GGML_F16_VEC_LOAD(p, i)      GGML_F32Cx4_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i)  GGML_F32Cx4_STORE(p, r[i])
+#define GGML_F16_VEC_FMA             GGML_F32Cx4_FMA
+#define GGML_F16_VEC_ADD             GGML_F32Cx4_ADD
+#define GGML_F16_VEC_MUL             GGML_F32Cx4_MUL
+#define GGML_F16_VEC_REDUCE          GGML_F32Cx4_REDUCE
+
+#endif
+
+// GGML_F32_ARR / GGML_F16_ARR
+//   number of registers to use per step
+#ifdef GGML_SIMD
+#define GGML_F32_ARR (GGML_F32_STEP/GGML_F32_EPR)
+#define GGML_F16_ARR (GGML_F16_STEP/GGML_F16_EPR)
+#endif
+
+//
+// Threading defs
+//
+
+#if defined(_WIN32)
+
+typedef CONDITION_VARIABLE ggml_cond_t;
+typedef SRWLOCK            ggml_mutex_t;
+
+#define ggml_mutex_init(m)   InitializeSRWLock(m)
+#define ggml_mutex_destroy(m)
+#define ggml_mutex_lock(m)   AcquireSRWLockExclusive(m)
+#define ggml_mutex_unlock(m) ReleaseSRWLockExclusive(m)
+#define ggml_mutex_lock_shared(m)   AcquireSRWLockShared(m)
+#define ggml_mutex_unlock_shared(m) ReleaseSRWLockShared(m)
+
+#define ggml_cond_init(c)    InitializeConditionVariable(c)
+#define ggml_cond_destroy(c)
+#define ggml_cond_wait(c, m) SleepConditionVariableSRW(c, m, INFINITE, CONDITION_VARIABLE_LOCKMODE_SHARED)
+#define ggml_cond_broadcast(c) WakeAllConditionVariable(c)
+
+#define ggml_thread_create pthread_create
+#define ggml_thread_join   pthread_join
+
+#else
+
+typedef pthread_cond_t     ggml_cond_t;
+typedef pthread_mutex_t    ggml_mutex_t;
+
+#define ggml_mutex_init(m)          pthread_mutex_init(m, NULL)
+#define ggml_mutex_destroy(m)       pthread_mutex_destroy(m)
+#define ggml_mutex_lock(m)          pthread_mutex_lock(m)
+#define ggml_mutex_unlock(m)        pthread_mutex_unlock(m)
+#define ggml_mutex_lock_shared(m)   pthread_mutex_lock(m)
+#define ggml_mutex_unlock_shared(m) pthread_mutex_unlock(m)
+
+#define ggml_lock_init(x)    UNUSED(x)
+#define ggml_lock_destroy(x) UNUSED(x)
+#if defined(__x86_64__) || (defined(_MSC_VER) && defined(_M_AMD64))
+#define ggml_lock_lock(x)    _mm_pause()
+#else
+#define ggml_lock_lock(x)    UNUSED(x)
+#endif
+#define ggml_lock_unlock(x)  UNUSED(x)
+
+#define GGML_LOCK_INITIALIZER 0
+#define ggml_cond_init(c)      pthread_cond_init(c, NULL)
+#define ggml_cond_destroy(c)   pthread_cond_destroy(c)
+#define ggml_cond_wait(c, m)   pthread_cond_wait(c, m)
+#define ggml_cond_broadcast(c) pthread_cond_broadcast(c)
+
+#define ggml_thread_create pthread_create
+#define ggml_thread_join   pthread_join
+
+#endif
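+
+// Typical use of these wrappers (illustrative sketch, not the actual
+// scheduler code): a worker sleeps on the condition variable until new work
+// is published, re-checking the predicate after every wakeup:
+//
+//     ggml_mutex_lock(&threadpool->mutex);
+//     while (!new_work_available(threadpool)) {   // hypothetical predicate
+//         ggml_cond_wait(&threadpool->cond, &threadpool->mutex);
+//     }
+//     ggml_mutex_unlock(&threadpool->mutex);
+//
+// the producer takes the mutex, publishes the work, then calls
+// ggml_cond_broadcast(&threadpool->cond) to wake all waiters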
+
+// Threadpool def
+struct ggml_threadpool {
+    ggml_mutex_t mutex;       // mutex for cond.var
+    ggml_cond_t  cond;        // cond.var for waiting for new work
+
+    struct ggml_cgraph * cgraph;
+    struct ggml_cplan  * cplan;
+
+    // synchronization primitives
+    atomic_int n_graph;       // incremented when there is work to be done (i.e., once per graph)
+    atomic_int GGML_CACHE_ALIGN n_barrier;
+    atomic_int GGML_CACHE_ALIGN n_barrier_passed;
+    atomic_int current_chunk; // index of the chunk currently being processed during mat-mul; shared between all threads
+
+    // these are atomic as an annotation for thread-sanitizer
+    atomic_bool stop;         // Used for stopping the threadpool altogether
+    atomic_bool pause;        // Used for pausing the threadpool or individual threads
+    atomic_bool abort;        // Used for aborting processing of a graph
+
+    struct ggml_compute_state * workers;   // per thread state
+    int          n_threads_max; // number of threads in the pool
+    atomic_int   n_threads_cur; // number of threads used in the current graph
+
+    int32_t      prio;        // Scheduling priority
+    uint32_t     poll;        // Polling level (0 - no polling)
+
+    enum ggml_status ec;
+};
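+
+// Illustrative sketch of the counter barrier these fields support
+// (simplified assumption of the mechanism; omits the polling path):
+//
+//     int passed = atomic_load(&tp->n_barrier_passed);
+//     if (atomic_fetch_add(&tp->n_barrier, 1) == n_threads - 1) {
+//         atomic_store(&tp->n_barrier, 0);           // last arrival: reset
+//         atomic_fetch_add(&tp->n_barrier_passed, 1); // release the others
+//     } else {
+//         while (atomic_load(&tp->n_barrier_passed) == passed) { /* spin */ }
+//     }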
+
+// Per-thread state
+struct ggml_compute_state {
+#ifndef GGML_USE_OPENMP
+    ggml_thread_t thrd;
+    bool cpumask[GGML_MAX_N_THREADS];
+    int  last_graph;
+    bool pending;
+#endif
+    struct ggml_threadpool * threadpool;
+    int ith;
+};
+
+//
+// fundamental operations
+//
+
+inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
+inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
+inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
+inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
+inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
+inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i] + y[i]; }
+inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float   v) { for (int i = 0; i < n; ++i) z[i]  = x[i] + v;    }
+inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i] += x[i];        }
+inline static void ggml_vec_acc1_f32(const int n, float * y, const float   v)                  { for (int i = 0; i < n; ++i) y[i] += v;           }
+inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i] - y[i]; }
+inline static void ggml_vec_set_f32 (const int n, float * x, const float   v)                  { for (int i = 0; i < n; ++i) x[i]  = v;           }
+inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i]  = x[i];        }
+inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i]  = -x[i];       }
+inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]*y[i];   }
+inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]/y[i];   }
+
+static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+#if defined(GGML_SIMD)
+    float sumf = 0.0f;
+    const int np = (n & ~(GGML_F32_STEP - 1));
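+    // np: n rounded down to a multiple of GGML_F32_STEP; e.g. with AVX
+    // (STEP == 32) and n == 100, np == 96 and the last 4 elements are
+    // handled by the scalar leftover loop below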
+
+    GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+
+    GGML_F32_VEC ax[GGML_F32_ARR];
+    GGML_F32_VEC ay[GGML_F32_ARR];
+
+    for (int i = 0; i < np; i += GGML_F32_STEP) {
+        for (int j = 0; j < GGML_F32_ARR; j++) {
+            ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
+            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+
+            sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]);
+        }
+    }
+
+    // reduce sum0..sum3 to sum0
+    GGML_F32_VEC_REDUCE(sumf, sum);
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        sumf += x[i]*y[i];
+    }
+#else
+    // scalar
+    ggml_float sumf = 0.0;
+    for (int i = 0; i < n; ++i) {
+        sumf += (ggml_float)(x[i]*y[i]);
+    }
+#endif
+
+    *s = sumf;
+}
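+
+// e.g. ggml_vec_dot_f32(3, &s, 0, (float[]){1, 2, 3}, 0, (float[]){4, 5, 6}, 0, 1)
+// leaves s == 32.0f (1*4 + 2*5 + 3*6)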
+
+static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+    int i = 0;
+    ggml_float sumf = 0;
+
+#if defined(__AVX512BF16__)
+    __m512 c1 = _mm512_setzero_ps();
+    __m512 c2 = _mm512_setzero_ps();
+    for (; i + 64 <= n; i += 64) {
+        c1 = _mm512_dpbf16_ps(c1, m512bh(_mm512_loadu_si512((x + i))),
+                             m512bh(_mm512_loadu_si512((y + i))));
+        c2 = _mm512_dpbf16_ps(c2, m512bh(_mm512_loadu_si512((x + i + 32))),
+                             m512bh(_mm512_loadu_si512((y + i + 32))));
+    }
+    sumf += (ggml_float)_mm512_reduce_add_ps(c1);
+    sumf += (ggml_float)_mm512_reduce_add_ps(c2);
+
+#elif defined(__AVX512F__)
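+// bf16 stores the upper 16 bits of an IEEE f32, so widening to f32 is a
+// zero-extend of each 16-bit value followed by a left shift of 16: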
+#define LOAD(p) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(p))), 16))
+    __m512 c1 = _mm512_setzero_ps();
+    __m512 c2 = _mm512_setzero_ps();
+    for (; i + 32 <= n; i += 32) {
+        c1 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
+        c2 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c2);
+    }
+    sumf += (ggml_float)_mm512_reduce_add_ps(c1);
+    sumf += (ggml_float)_mm512_reduce_add_ps(c2);
+
+#undef LOAD
+#elif defined(__AVX2__) || defined(__AVX__)
+#if defined(__AVX2__)
+#define LOAD(p) _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16))
+#else
+#define LOAD(p) _mm256_castsi256_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16)), (_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_bsrli_si128(_mm_loadu_si128((const __m128i *)(p)), 8)), 16)), 1))
+#endif
+    __m256 c1 = _mm256_setzero_ps();
+    __m256 c2 = _mm256_setzero_ps();
+    __m256 c3 = _mm256_setzero_ps();
+    __m256 c4 = _mm256_setzero_ps();
+    for (; i + 32 <= n; i += 32) {
+        c1 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
+        c2 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 8), LOAD(y + i + 8)), c2);
+        c3 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c3);
+        c4 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 24), LOAD(y + i + 24)), c4);
+    }
+    __m128 g;
+    c1 = _mm256_add_ps(_mm256_add_ps(c1, c3),
+                       _mm256_add_ps(c2, c4));
+    g = _mm_add_ps(_mm256_extractf128_ps(c1, 1),
+                   _mm256_castps256_ps128(c1));
+    g = _mm_add_ps(g, _mm_movehl_ps(g, g));
+    g = _mm_add_ss(g, _mm_movehdup_ps(g));
+    sumf += (ggml_float)_mm_cvtss_f32(g);
+
+#undef LOAD
+#endif
+
+    for (; i < n; ++i) {
+        sumf += (ggml_float)(GGML_BF16_TO_FP32(x[i]) *
+                             GGML_BF16_TO_FP32(y[i]));
+    }
+    *s = sumf;
+}
+
+static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    ggml_float sumf = 0.0;
+
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F16_STEP - 1));
+
+    GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };
+
+    GGML_F16_VEC ax[GGML_F16_ARR];
+    GGML_F16_VEC ay[GGML_F16_ARR];
+
+    for (int i = 0; i < np; i += GGML_F16_STEP) {
+        for (int j = 0; j < GGML_F16_ARR; j++) {
+            ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
+            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+
+            sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]);
+        }
+    }
+
+    // reduce sum0..sum3 to sum0
+    GGML_F16_VEC_REDUCE(sumf, sum);
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
+    }
+#else
+    for (int i = 0; i < n; ++i) {
+        sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
+    }
+#endif
+
+    *s = sumf;
+}
+
+// compute GGML_VEC_DOT_UNROLL dot products at once
+// xs - x row stride in bytes
+inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
+    ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
+
+    ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL];
+
+    for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
+        x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
+    }
+
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F16_STEP - 1));
+
+    GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
+
+    GGML_F16_VEC ax[GGML_F16_ARR];
+    GGML_F16_VEC ay[GGML_F16_ARR];
+
+    for (int i = 0; i < np; i += GGML_F16_STEP) {
+        for (int j = 0; j < GGML_F16_ARR; j++) {
+            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+
+            for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
+                ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j);
+
+                sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]);
+            }
+        }
+    }
+
+    // reduce sum0..sum3 to sum0
+    for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
+        GGML_F16_VEC_REDUCE(sumf[k], sum[k]);
+    }
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
+            sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]));
+        }
+    }
+#else
+    for (int i = 0; i < n; ++i) {
+        for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
+            sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]));
+        }
+    }
+#endif
+
+    for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
+        s[i] = sumf[i];
+    }
+}
+
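+// computes y[i] += x[i]*v (an axpy-style multiply-accumulate)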
+inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) {
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F32_STEP - 1));
+
+    GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
+
+    GGML_F32_VEC ax[GGML_F32_ARR];
+    GGML_F32_VEC ay[GGML_F32_ARR];
+
+    for (int i = 0; i < np; i += GGML_F32_STEP) {
+        for (int j = 0; j < GGML_F32_ARR; j++) {
+            ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
+            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+            ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx);
+
+            GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
+        }
+    }
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        y[i] += x[i]*v;
+    }
+#else
+    // scalar
+    for (int i = 0; i < n; ++i) {
+        y[i] += x[i]*v;
+    }
+#endif
+}
+
+inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, const ggml_fp16_t * restrict x, const float v) {
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F16_STEP - 1));
+
+    GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
+
+    GGML_F16_VEC ax[GGML_F16_ARR];
+    GGML_F16_VEC ay[GGML_F16_ARR];
+
+    for (int i = 0; i < np; i += GGML_F16_STEP) {
+        for (int j = 0; j < GGML_F16_ARR; j++) {
+            ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
+            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+            ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
+
+            GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
+        }
+    }
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
+    }
+#else
+    // scalar
+    for (int i = 0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
+    }
+#endif
+}
+
+// xs and vs are byte strides of x and v
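+// only the first element of each v row, v[k][0], is used as the scalar
+// multiplier for the corresponding x row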
+inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) {
+
+    const float * restrict x[GGML_VEC_MAD_UNROLL];
+    const float * restrict v[GGML_VEC_MAD_UNROLL];
+
+    for (int i = 0; i < GGML_VEC_MAD_UNROLL; ++i) {
+        x[i] = (const float *) ((const char *) xv + i*xs);
+        v[i] = (const float *) ((const char *) vv + i*vs);
+    }
+
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F32_STEP - 1));
+
+    GGML_F32_VEC vx[GGML_VEC_MAD_UNROLL];
+
+    for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+        vx[k] = GGML_F32_VEC_SET1(v[k][0]);
+    }
+
+    GGML_F32_VEC ax[GGML_VEC_MAD_UNROLL][GGML_F32_ARR];
+    GGML_F32_VEC ay[GGML_F32_ARR];
+
+    for (int i = 0; i < np; i += GGML_F32_STEP) {
+        for (int j = 0; j < GGML_F32_ARR; j++) {
+            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+
+            for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+                ax[k][j] = GGML_F32_VEC_LOAD(x[k] + i + j*GGML_F32_EPR);
+                ay[j] = GGML_F32_VEC_FMA(ay[j], ax[k][j], vx[k]);
+            }
+
+            GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
+        }
+    }
+
+    // leftovers
+    for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+        for (int i = np; i < n; ++i) {
+            y[i] += x[k][i]*v[k][0];
+        }
+    }
+#else
+    // scalar
+    for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+        for (int i = 0; i < n; ++i) {
+            y[i] += x[k][i]*v[k][0];
+        }
+    }
+#endif
+}
+
+//inline static void ggml_vec_scale_f32(const int n, float * y, const float   v) { for (int i = 0; i < n; ++i) y[i] *= v;          }
+inline static void ggml_vec_scale_f32(const int n, float * y, const float   v) {
+#if defined(GGML_USE_ACCELERATE)
+    vDSP_vsmul(y, 1, &v, y, 1, n);
+#elif defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F32_STEP - 1));
+
+    GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
+
+    GGML_F32_VEC ay[GGML_F32_ARR];
+
+    for (int i = 0; i < np; i += GGML_F32_STEP) {
+        for (int j = 0; j < GGML_F32_ARR; j++) {
+            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+            ay[j] = GGML_F32_VEC_MUL(ay[j], vx);
+
+            GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
+        }
+    }
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        y[i] *= v;
+    }
+#else
+    // scalar
+    for (int i = 0; i < n; ++i) {
+        y[i] *= v;
+    }
+#endif
+}
+
+inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F16_STEP - 1));
+
+    GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
+
+    GGML_F16_VEC ay[GGML_F16_ARR];
+
+    for (int i = 0; i < np; i += GGML_F16_STEP) {
+        for (int j = 0; j < GGML_F16_ARR; j++) {
+            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+            ay[j] = GGML_F16_VEC_MUL(ay[j], vx);
+
+            GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
+        }
+    }
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
+    }
+#else
+    // scalar
+    for (int i = 0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
+    }
+#endif
+}
+
+inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s);   }
+inline static void ggml_vec_sqr_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i];   }
+inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
+inline static void ggml_vec_log_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]);  }
+inline static void ggml_vec_sin_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sinf(x[i]);  }
+inline static void ggml_vec_cos_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = cosf(x[i]);  }
+inline static void ggml_vec_abs_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); }
+inline static void ggml_vec_sgn_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); }
+inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; }
+inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]);  }
+inline static void ggml_vec_elu_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expm1f(x[i]); }
+inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
+inline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); }
+inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = 1.f / (1.f + expf(-x[i])); }
+// TODO: optimize performance
+inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
+inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
+inline static void ggml_vec_exp_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = expf(x[i]); }
+
+static const float GELU_COEF_A     = 0.044715f;
+static const float GELU_QUICK_COEF = -1.702f;
+static const float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;
+
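+// tanh approximation of GELU: 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3)))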
+inline static float ggml_gelu_f32(float x) {
+    return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+}
+
+inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    const uint16_t * i16 = (const uint16_t *) x;
+    for (int i = 0; i < n; ++i) {
+        y[i] = ggml_table_gelu_f16[i16[i]];
+    }
+}
+
+#ifdef GGML_GELU_FP16
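+// table-lookup variant: ggml_table_gelu_f16 is expected to hold a precomputed
+// GELU value for every fp16 bit pattern, so the hot path is one fp32->fp16
+// conversion plus a 64K-entry load; |x| >= 10 is handled explicitly since
+// GELU is effectively 0 (x <= -10) or x (x >= 10) there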
+inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
+    uint16_t t;
+    for (int i = 0; i < n; ++i) {
+        if (x[i] <= -10.0f) {
+            y[i] = 0.0f;
+        } else if (x[i] >= 10.0f) {
+            y[i] = x[i];
+        } else {
+            ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
+            memcpy(&t, &fp16, sizeof(uint16_t));
+            y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]);
+        }
+    }
+}
+#else
+inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = ggml_gelu_f32(x[i]);
+    }
+}
+#endif
+
+inline static float ggml_gelu_quick_f32(float x) {
+    return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x)));
+}
+
+//inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+//    const uint16_t * i16 = (const uint16_t *) x;
+//    for (int i = 0; i < n; ++i) {
+//        y[i] = ggml_table_gelu_quick_f16[i16[i]];
+//    }
+//}
+
+#ifdef GGML_GELU_QUICK_FP16
+inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
+    uint16_t t;
+    for (int i = 0; i < n; ++i) {
+        ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
+        memcpy(&t, &fp16, sizeof(uint16_t));
+        y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_quick_f16[t]);
+    }
+}
+#else
+inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = ggml_gelu_quick_f32(x[i]);
+    }
+}
+#endif
+
+// Sigmoid Linear Unit (SiLU) function
+inline static float ggml_silu_f32(float x) {
+    return x/(1.0f + expf(-x));
+}
+
+#if __FINITE_MATH_ONLY__
+#error "some routines in ggml.c require non-finite math arithmetics -- pass -fno-finite-math-only to the compiler to fix"
+#error "ref: https://github.com/ggerganov/llama.cpp/pull/7154#issuecomment-2143844461"
+#endif
+
+#if defined(__ARM_NEON) && defined(__aarch64__)
+
+// adapted from arm limited optimized routine
+// the maximum error is 1.45358 plus 0.5 ulps
+// numbers above 88.38 will flush to infinity
+// numbers beneath -103.97 will flush to zero
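+// the same scheme is used by the AVX512/AVX2/SSE2 variants below: round
+// x/ln(2) to an integer n with the 0x1.8p23 rounding trick, evaluate a
+// polynomial approximation of exp(r) on the reduced argument r = x - n*ln(2),
+// and reconstruct exp(x) = 2^n * exp(r) by adding n to the float exponent bits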
+inline static float32x4_t ggml_v_expf(float32x4_t x) {
+    const float32x4_t r = vdupq_n_f32(0x1.8p23f);
+    const float32x4_t z = vfmaq_f32(r, x, vdupq_n_f32(0x1.715476p+0f));
+    const float32x4_t n = vsubq_f32(z, r);
+    const float32x4_t b = vfmsq_f32(vfmsq_f32(x, n, vdupq_n_f32(0x1.62e4p-1f)), n,
+                                    vdupq_n_f32(0x1.7f7d1cp-20f));
+    const uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_f32(z), 23);
+    const float32x4_t k = vreinterpretq_f32_u32(vaddq_u32(e, vreinterpretq_u32_f32(vdupq_n_f32(1))));
+    const uint32x4_t c = vcagtq_f32(n, vdupq_n_f32(126));
+    const float32x4_t u = vmulq_f32(b, b);
+    const float32x4_t j = vfmaq_f32(
+        vmulq_f32(vdupq_n_f32(0x1.ffffecp-1f), b),
+        vfmaq_f32(vfmaq_f32(vdupq_n_f32(0x1.fffdb6p-2f), vdupq_n_f32(0x1.555e66p-3f), b),
+                  vfmaq_f32(vdupq_n_f32(0x1.573e2ep-5f), vdupq_n_f32(0x1.0e4020p-7f), b), u), u);
+    if (!vpaddd_u64(vreinterpretq_u64_u32(c)))
+        return vfmaq_f32(k, j, k);
+    const uint32x4_t d = vandq_u32(vclezq_f32(n), vdupq_n_u32(0x82000000));
+    const float32x4_t s1 = vreinterpretq_f32_u32(vaddq_u32(d, vdupq_n_u32(0x7f000000)));
+    const float32x4_t s2 = vreinterpretq_f32_u32(vsubq_u32(e, d));
+    return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1),
+                     vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j)));
+}
+
+// computes silu x/(1+exp(-x)) for a single-precision vector
+inline static float32x4_t ggml_v_silu(float32x4_t x) {
+    const float32x4_t one = vdupq_n_f32(1.0f);
+    const float32x4_t zero = vdupq_n_f32(0.0f);
+    const float32x4_t neg_x = vsubq_f32(zero, x);
+    const float32x4_t exp_neg_x = ggml_v_expf(neg_x);
+    const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x);
+    return vdivq_f32(x, one_plus_exp_neg_x);
+}
+
+#elif defined(__AVX512F__) && defined(__AVX512DQ__)
+
+// adapted from arm limited optimized routine
+// the maximum error is 1.45358 plus 0.5 ulps
+// numbers above 88.38 will flush to infinity
+// numbers beneath -103.97 will flush to zero
+inline static __m512 ggml_v_expf(__m512 x) {
+  const __m512 r = _mm512_set1_ps(0x1.8p23f);
+  const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
+  const __m512 n = _mm512_sub_ps(z, r);
+  const __m512 b =
+      _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
+                       _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
+  const __mmask16 d =
+      _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
+  const __m512 u = _mm512_mul_ps(b, b);
+  const __m512 j = _mm512_fmadd_ps(
+      _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
+                                      _mm512_set1_ps(0x1.573e2ep-5f)),
+                      u,
+                      _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
+                                      _mm512_set1_ps(0x1.fffdb6p-2f))),
+      u,
+      _mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F)));
+  const __m512 res = _mm512_scalef_ps(j, n);
+  if (_mm512_kortestz(d, d))
+    return res;
+  const __m512 zero = _mm512_setzero_ps();
+  const __m512 alt = _mm512_mask_blend_ps(
+      _mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero);
+  return _mm512_mask_blend_ps(d, res, alt);
+}
+
+// computes silu x/(1+exp(-x)) for a single-precision vector
+inline static __m512 ggml_v_silu(__m512 x) {
+    const __m512 one = _mm512_set1_ps(1);
+    const __m512 zero = _mm512_setzero_ps();
+    const __m512 neg_x = _mm512_sub_ps(zero, x);
+    const __m512 exp_neg_x = ggml_v_expf(neg_x);
+    const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x);
+    return _mm512_div_ps(x, one_plus_exp_neg_x);
+}
+
+#elif defined(__AVX2__) && defined(__FMA__)
+
+// adapted from arm limited optimized routine
+// the maximum error is 1.45358 plus 0.5 ulps
+// numbers above 88.38 will flush to infinity
+// numbers beneath -103.97 will flush to zero
+inline static __m256 ggml_v_expf(__m256 x) {
+  const __m256 r = _mm256_set1_ps(0x1.8p23f);
+  const __m256 z = _mm256_fmadd_ps(x, _mm256_set1_ps(0x1.715476p+0f), r);
+  const __m256 n = _mm256_sub_ps(z, r);
+  const __m256 b = _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.7f7d1cp-20f),
+                                    _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.62e4p-1f), x));
+  const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23);
+  const __m256 k = _mm256_castsi256_ps(
+      _mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1))));
+  const __m256i c = _mm256_castps_si256(
+      _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
+                    _mm256_set1_ps(126), _CMP_GT_OQ));
+  const __m256 u = _mm256_mul_ps(b, b);
+  const __m256 j = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), b,
+                                                                   _mm256_set1_ps(0x1.573e2ep-5f)), u,
+                                                   _mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), b,
+                                                                   _mm256_set1_ps(0x1.fffdb6p-2f))),
+                                   u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b));
+  if (!_mm256_movemask_ps(_mm256_castsi256_ps(c)))
+    return _mm256_fmadd_ps(j, k, k);
+  const __m256i g = _mm256_and_si256(
+      _mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)),
+      _mm256_set1_epi32(0x82000000u));
+  const __m256 s1 =
+      _mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u)));
+  const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g));
+  const __m256i d = _mm256_castps_si256(
+      _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
+                    _mm256_set1_ps(192), _CMP_GT_OQ));
+  return _mm256_or_ps(
+      _mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)),
+      _mm256_andnot_ps(
+          _mm256_castsi256_ps(d),
+          _mm256_or_ps(
+              _mm256_and_ps(_mm256_castsi256_ps(c),
+                            _mm256_mul_ps(_mm256_fmadd_ps(s2, j, s2), s1)),
+              _mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k)))));
+}
+
+// computes silu x/(1+exp(-x)) for a single-precision vector
+inline static __m256 ggml_v_silu(__m256 x) {
+    const __m256 one = _mm256_set1_ps(1);
+    const __m256 zero = _mm256_setzero_ps();
+    const __m256 neg_x = _mm256_sub_ps(zero, x);
+    const __m256 exp_neg_x = ggml_v_expf(neg_x);
+    const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x);
+    return _mm256_div_ps(x, one_plus_exp_neg_x);
+}
+
+#elif defined(__SSE2__) // __AVX2__ / __ARM_NEON
+
+#if defined(__FMA__)
+#define MADD128(x, y, z) _mm_fmadd_ps(x, y, z)
+#define NMADD128(x, y, z) _mm_fnmadd_ps(x, y, z)
+#else
+#define MADD128(x, y, z) _mm_add_ps(_mm_mul_ps(x, y), z)
+#define NMADD128(x, y, z) _mm_sub_ps(z, _mm_mul_ps(x, y))
+#endif
+
+// adapted from arm limited optimized routine
+// the maximum error is 1.45358 plus 0.5 ulps
+// numbers above 88.38 will flush to infinity
+// numbers beneath -103.97 will flush to zero
+inline static __m128 ggml_v_expf(__m128 x) {
+    const __m128 r = _mm_set1_ps(0x1.8p23f);
+    const __m128 z = MADD128(x, _mm_set1_ps(0x1.715476p+0f), r);
+    const __m128 n = _mm_sub_ps(z, r);
+    const __m128 b =
+        NMADD128(n, _mm_set1_ps(0x1.7f7d1cp-20f), NMADD128(n, _mm_set1_ps(0x1.62e4p-1f), x));
+    const __m128i e = _mm_slli_epi32(_mm_castps_si128(z), 23);
+    const __m128 k = _mm_castsi128_ps(_mm_add_epi32(e, _mm_castps_si128(_mm_set1_ps(1))));
+    const __m128i c =
+        _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(126)));
+    const __m128 u = _mm_mul_ps(b, b);
+    const __m128 j =
+        MADD128(MADD128(MADD128(_mm_set1_ps(0x1.0e4020p-7f), b, _mm_set1_ps(0x1.573e2ep-5f)), u,
+                        MADD128(_mm_set1_ps(0x1.555e66p-3f), b, _mm_set1_ps(0x1.fffdb6p-2f))),
+                u, _mm_mul_ps(_mm_set1_ps(0x1.ffffecp-1f), b));
+    if (!_mm_movemask_epi8(c))
+        return MADD128(j, k, k);
+    const __m128i g = _mm_and_si128(_mm_castps_si128(_mm_cmple_ps(n, _mm_setzero_ps())),
+                                    _mm_set1_epi32(0x82000000u));
+    const __m128 s1 = _mm_castsi128_ps(_mm_add_epi32(g, _mm_set1_epi32(0x7f000000u)));
+    const __m128 s2 = _mm_castsi128_ps(_mm_sub_epi32(e, g));
+    const __m128i d =
+        _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(192)));
+    return _mm_or_ps(
+        _mm_and_ps(_mm_castsi128_ps(d), _mm_mul_ps(s1, s1)),
+        _mm_andnot_ps(_mm_castsi128_ps(d),
+                      _mm_or_ps(_mm_and_ps(_mm_castsi128_ps(c), _mm_mul_ps(MADD128(s2, j, s2), s1)),
+                                _mm_andnot_ps(_mm_castsi128_ps(c), MADD128(k, j, k)))));
+}
+
+// computes silu x/(1+exp(-x)) for a single-precision vector
+inline static __m128 ggml_v_silu(__m128 x) {
+    const __m128 one = _mm_set1_ps(1);
+    const __m128 zero = _mm_setzero_ps();
+    const __m128 neg_x = _mm_sub_ps(zero, x);
+    const __m128 exp_neg_x = ggml_v_expf(neg_x);
+    const __m128 one_plus_exp_neg_x = _mm_add_ps(one, exp_neg_x);
+    return _mm_div_ps(x, one_plus_exp_neg_x);
+}
+
+#endif // __ARM_NEON / __AVX2__ / __SSE2__
+
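+// element-wise silu over a float array: uses the widest SIMD path selected at
+// compile time above, with a scalar tail for the remaining elements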
+static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
+    int i = 0;
+#if defined(__AVX512F__) && defined(__AVX512DQ__)
+    for (; i + 15 < n; i += 16) {
+        _mm512_storeu_ps(y + i, ggml_v_silu(_mm512_loadu_ps(x + i)));
+    }
+#elif defined(__AVX2__) && defined(__FMA__)
+    for (; i + 7 < n; i += 8) {
+        _mm256_storeu_ps(y + i, ggml_v_silu(_mm256_loadu_ps(x + i)));
+    }
+#elif defined(__SSE2__)
+    for (; i + 3 < n; i += 4) {
+        _mm_storeu_ps(y + i, ggml_v_silu(_mm_loadu_ps(x + i)));
+    }
+#elif defined(__ARM_NEON) && defined(__aarch64__)
+    for (; i + 3 < n; i += 4) {
+        vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i)));
+    }
+#endif
+    for (; i < n; ++i) {
+        y[i] = ggml_silu_f32(x[i]);
+    }
+}
+
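+// writes y[i] = exp(x[i] - max) and returns the sum of the exponentials;
+// with max = max_i x[i] this is the numerically stable core of softmax,
+// and a caller can normalize y by the returned sum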
+static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
+    int i = 0;
+    ggml_float sum = 0;
+#if defined(__AVX512F__) && defined(__AVX512DQ__)
+    for (; i + 15 < n; i += 16) {
+        __m512 val = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i),
+                                               _mm512_set1_ps(max)));
+        _mm512_storeu_ps(y + i, val);
+        sum += (ggml_float)_mm512_reduce_add_ps(val);
+    }
+#elif defined(__AVX2__) && defined(__FMA__)
+    for (; i + 7 < n; i += 8) {
+        __m256 val = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i),
+                                               _mm256_set1_ps(max)));
+        _mm256_storeu_ps(y + i, val);
+        __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1),
+                                 _mm256_castps256_ps128(val));
+        val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
+        val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
+        sum += (ggml_float)_mm_cvtss_f32(val2);
+    }
+#elif defined(__SSE2__)
+    for (; i + 3 < n; i += 4) {
+        __m128 val = ggml_v_expf(_mm_sub_ps(_mm_loadu_ps(x + i),
+                                            _mm_set1_ps(max)));
+        _mm_storeu_ps(y + i, val);
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+        val = _mm_add_ps(val, _mm_movehl_ps(val, val));
+        val = _mm_add_ss(val, _mm_movehdup_ps(val));
+#else
+        __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));
+        val = _mm_add_ps(val, tmp);
+        tmp = _mm_movehl_ps(tmp, val);
+        val = _mm_add_ss(val, tmp);
+#endif
+        sum += (ggml_float)_mm_cvtss_f32(val);
+    }
+#elif defined(__ARM_NEON) && defined(__aarch64__)
+    for (; i + 3 < n; i += 4) {
+        float32x4_t val = ggml_v_expf(vsubq_f32(vld1q_f32(x + i),
+                                                vdupq_n_f32(max)));
+        vst1q_f32(y + i, val);
+        sum += (ggml_float)vaddvq_f32(val);
+    }
+#endif
+    for (; i < n; ++i) {
+        float val = expf(x[i] - max);
+        sum += (ggml_float)val;
+        y[i] = val;
+    }
+    return sum;
+}
+
+static ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max) {
+    // log(soft_max) = log(soft_max_i / soft_max_sum) = log(soft_max_i) - log(soft_max_sum) = (logit_i - max) - log(soft_max_sum)
+
+    int i = 0;
+    ggml_float sum = 0;
+    for (; i < n; ++i) {
+        float val = x[i] - max;
+        y[i] = val;
+        sum += (ggml_float)expf(val);
+    }
+    return (ggml_float)logf(sum);
+}
+
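+// with s = sigmoid(x): d/dx silu(x) = s + x*s*(1 - s) = s*(1 + x*(1 - s))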
+inline static float ggml_silu_backward_f32(float x, float dy) {
+    const float s = 1.0f/(1.0f + expf(-x));
+    return dy*s*(1.0f + x*(1.0f - s));
+}
+
+inline static void ggml_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) {
+    for (int i = 0; i < n; ++i) {
+        dx[i] = ggml_silu_backward_f32(x[i], dy[i]);
+    }
+}
+
+inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
+#ifndef GGML_USE_ACCELERATE
+    ggml_float sum = 0.0;
+    for (int i = 0; i < n; ++i) {
+        sum += (ggml_float)x[i];
+    }
+    *s = sum;
+#else
+    vDSP_sve(x, 1, s, n);
+#endif
+}
+
+inline static void ggml_vec_sum_f32_ggf(const int n, ggml_float * s, const float * x) {
+    ggml_float sum = 0.0;
+    for (int i = 0; i < n; ++i) {
+        sum += (ggml_float)x[i];
+    }
+    *s = sum;
+}
+
+inline static void ggml_vec_sum_f16_ggf(const int n, float * s, const ggml_fp16_t * x) {
+    float sum = 0.0f;
+    for (int i = 0; i < n; ++i) {
+        sum += GGML_FP16_TO_FP32(x[i]);
+    }
+    *s = sum;
+}
+
+inline static void ggml_vec_sum_bf16_ggf(const int n, float * s, const ggml_bf16_t * x) {
+    float sum = 0.0f;
+    for (int i = 0; i < n; ++i) {
+        sum += GGML_BF16_TO_FP32(x[i]);
+    }
+    *s = sum;
+}
+
+inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
+#ifndef GGML_USE_ACCELERATE
+    float max = -INFINITY;
+    for (int i = 0; i < n; ++i) {
+        max = MAX(max, x[i]);
+    }
+    *s = max;
+#else
+    vDSP_maxv(x, 1, s, n);
+#endif
+}
+
+inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) {
+    ggml_vec_norm_f32(n, s, x);
+    *s = 1.f/(*s);
+}
+
+inline static void ggml_vec_argmax_f32(const int n, int * s, const float * x) {
+    float max = -INFINITY;
+    int idx = 0;
+    for (int i = 0; i < n; ++i) {
+        max = MAX(max, x[i]);
+        if (max == x[i]) { idx = i; }
+    }
+    *s = idx;
+}
+
+// Helpers for polling loops
+#if defined(__aarch64__) && ( defined(__clang__) || defined(__GNUC__) )
+static inline void ggml_thread_cpu_relax(void) {
+    __asm__ volatile("yield" ::: "memory");
+}
+#elif defined(__x86_64__)
+static inline void ggml_thread_cpu_relax(void) {
+    _mm_pause();
+}
+#else
+static inline void ggml_thread_cpu_relax(void) {;}
+#endif
+
+//
+// NUMA support
+//
+
+#define GGML_NUMA_MAX_NODES 8
+#define GGML_NUMA_MAX_CPUS 512
+
+struct ggml_numa_node {
+    uint32_t cpus[GGML_NUMA_MAX_CPUS]; // hardware threads on this node
+    uint32_t n_cpus;
+};
+
+struct ggml_numa_nodes {
+    enum ggml_numa_strategy numa_strategy;
+    struct ggml_numa_node nodes[GGML_NUMA_MAX_NODES];
+    uint32_t n_nodes;
+    uint32_t total_cpus; // hardware threads on system
+    uint32_t current_node; // node on which the main process is executing
+#if defined(__gnu_linux__)
+    cpu_set_t cpuset; // cpuset from numactl
+#else
+    uint32_t cpuset; // no NUMA support outside of Linux at this time. Use a portable datatype
+#endif
+};
+
+//
+// ggml state
+//
+
+struct ggml_state {
+    struct ggml_numa_nodes numa;
+};
+
+static struct ggml_state g_state = {0};
+
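+// centralized counting barrier: each thread increments n_barrier on arrival;
+// the last arrival resets the counter and bumps n_barrier_passed, which the
+// waiting threads spin on (see ggml_thread_cpu_relax) until it changes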
+void ggml_barrier(struct ggml_threadpool * tp) {
+    int n_threads = atomic_load_explicit(&tp->n_threads_cur, memory_order_relaxed);
+    if (n_threads == 1) {
+        return;
+    }
+
+#ifdef GGML_USE_OPENMP
+    #pragma omp barrier
+#else
+    int n_passed = atomic_load_explicit(&tp->n_barrier_passed, memory_order_relaxed);
+
+    // enter barrier (full seq-cst fence)
+    int n_barrier = atomic_fetch_add_explicit(&tp->n_barrier, 1, memory_order_seq_cst);
+
+    if (n_barrier == (n_threads - 1)) {
+        // last thread
+        atomic_store_explicit(&tp->n_barrier, 0, memory_order_relaxed);
+
+        // exit barrier (full seq-cst fence)
+        atomic_fetch_add_explicit(&tp->n_barrier_passed, 1, memory_order_seq_cst);
+        return;
+    }
+
+    // wait for other threads
+    while (atomic_load_explicit(&tp->n_barrier_passed, memory_order_relaxed) == n_passed) {
+        ggml_thread_cpu_relax();
+    }
+
+    // exit barrier (full seq-cst fence)
+    // TSAN doesn't support standalone fences yet, so we use a dummy read-modify-write instead
+    #ifdef GGML_TSAN_ENABLED
+    atomic_fetch_add_explicit(&tp->n_barrier_passed, 0, memory_order_seq_cst);
+    #else
+    atomic_thread_fence(memory_order_seq_cst);
+    #endif
+#endif
+}
+
+#if defined(__gnu_linux__)
+static cpu_set_t ggml_get_numa_affinity(void) {
+    cpu_set_t cpuset;
+    pthread_t thread;
+    thread = pthread_self();
+    CPU_ZERO(&cpuset);
+    pthread_getaffinity_np(thread, sizeof(cpu_set_t), &cpuset);
+    return cpuset;
+}
+#else
+static uint32_t ggml_get_numa_affinity(void) {
+    return 0; // no NUMA support
+}
+#endif
+
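+// Linux-only: discovers the NUMA topology by probing sysfs
+// (/sys/devices/system/node/node%u, /sys/devices/system/cpu/cpu%u) and records
+// the node the calling process is currently running on via getcpu()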
+void ggml_numa_init(enum ggml_numa_strategy numa_flag) {
+    if (g_state.numa.n_nodes > 0) {
+        fprintf(stderr, "ggml_numa_init: NUMA already initialized\n");
+
+        return;
+    }
+
+#if defined(__gnu_linux__)
+    struct stat st;
+    char path[256];
+    int rv;
+
+    // set numa scheme
+    g_state.numa.numa_strategy = numa_flag;
+
+    GGML_PRINT_DEBUG("numa strategy %u\n",g_state.numa.numa_strategy);
+
+    g_state.numa.cpuset = ggml_get_numa_affinity();
+
+    // enumerate nodes
+    while (g_state.numa.n_nodes < GGML_NUMA_MAX_NODES) {
+        rv = snprintf(path, sizeof(path), "/sys/devices/system/node/node%u", g_state.numa.n_nodes);
+        GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path));
+        if (stat(path, &st) != 0) { break; }
+        ++g_state.numa.n_nodes;
+    }
+
+    // enumerate CPUs
+    while (g_state.numa.total_cpus < GGML_NUMA_MAX_CPUS) {
+        rv = snprintf(path, sizeof(path), "/sys/devices/system/cpu/cpu%u", g_state.numa.total_cpus);
+        GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path));
+        if (stat(path, &st) != 0) { break; }
+        ++g_state.numa.total_cpus;
+    }
+
+    GGML_PRINT_DEBUG("found %u numa nodes, %u CPUs\n", g_state.numa.n_nodes, g_state.numa.total_cpus);
+
+    // figure out which node we're on
+    unsigned int current_cpu;
+    int getcpu_ret = 0;
+#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 33) || defined(__COSMOPOLITAN__)
+    getcpu_ret = getcpu(&current_cpu, &g_state.numa.current_node);
+#else
+    // old glibc doesn't have a wrapper for this call. Fall back on direct syscall
+#   if !defined(SYS_getcpu) && defined(SYS_get_cpu)
+#       define SYS_getcpu SYS_get_cpu // some older glibc versions use this name
+#   endif
+    getcpu_ret = syscall(SYS_getcpu, &current_cpu, &g_state.numa.current_node);
+#endif
+
+    if (g_state.numa.n_nodes < 1 || g_state.numa.total_cpus < 1 || getcpu_ret != 0) {
+        g_state.numa.n_nodes = 0;
+        return;
+    }
+
+    GGML_PRINT_DEBUG("found our process on numa node %u, CPU %u\n", g_state.numa.current_node, current_cpu);
+
+    for (uint32_t n = 0; n < g_state.numa.n_nodes; ++n) {
+        struct ggml_numa_node * node = &g_state.numa.nodes[n];
+        GGML_PRINT_DEBUG("CPUs on node %u:", n);
+        node->n_cpus = 0;
+        for (uint32_t c = 0; c < g_state.numa.total_cpus; ++c) {
+            rv = snprintf(path, sizeof(path), "/sys/devices/system/node/node%u/cpu%u", n, c);
+            GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path));
+            if (stat(path, &st) == 0) {
+                node->cpus[node->n_cpus++] = c;
+                GGML_PRINT_DEBUG(" %u", c);
+            }
+        }
+        GGML_PRINT_DEBUG("\n");
+    }
+
+    if (ggml_is_numa()) {
+        FILE *fptr = fopen("/proc/sys/kernel/numa_balancing", "r");
+        if (fptr != NULL) {
+            char buf[42];
+            if (fgets(buf, sizeof(buf), fptr) && strncmp(buf, "0\n", sizeof(buf)) != 0) {
+                GGML_LOG_WARN("/proc/sys/kernel/numa_balancing is enabled, this has been observed to impair performance\n");
+            }
+            fclose(fptr);
+        }
+    }
+#else
+    UNUSED(numa_flag);
+    // TODO
+#endif
+}
+
+bool ggml_is_numa(void) {
+    return g_state.numa.n_nodes > 1;
+}
+
+#if defined(__ARM_ARCH)
+
+#if defined(__linux__) && defined(__aarch64__)
+#include <sys/auxv.h>
+#elif defined(__APPLE__)
+#include <sys/sysctl.h>
+#endif
+
+#if !defined(HWCAP2_I8MM)
+#define HWCAP2_I8MM 0
+#endif
+
+static void ggml_init_arm_arch_features(void) {
+#if defined(__linux__) && defined(__aarch64__)
+    uint32_t hwcap = getauxval(AT_HWCAP);
+    uint32_t hwcap2 = getauxval(AT_HWCAP2);
+
+    ggml_arm_arch_features.has_neon = !!(hwcap & HWCAP_ASIMD);
+    ggml_arm_arch_features.has_dotprod = !!(hwcap & HWCAP_ASIMDDP);
+    ggml_arm_arch_features.has_i8mm = !!(hwcap2 & HWCAP2_I8MM);
+    ggml_arm_arch_features.has_sve  = !!(hwcap & HWCAP_SVE);
+
+#if defined(__ARM_FEATURE_SVE)
+    ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL);
+#endif
+#elif defined(__APPLE__)
+    int oldp = 0;
+    size_t size = sizeof(oldp);
+    if (sysctlbyname("hw.optional.AdvSIMD", &oldp, &size, NULL, 0) != 0) {
+        oldp = 0;
+    }
+    ggml_arm_arch_features.has_neon = oldp;
+
+    if (sysctlbyname("hw.optional.arm.FEAT_DotProd", &oldp, &size, NULL, 0) != 0) {
+        oldp = 0;
+    }
+    ggml_arm_arch_features.has_dotprod = oldp;
+
+    if (sysctlbyname("hw.optional.arm.FEAT_I8MM", &oldp, &size, NULL, 0) != 0) {
+        oldp = 0;
+    }
+    ggml_arm_arch_features.has_i8mm = oldp;
+
+    ggml_arm_arch_features.has_sve = 0;
+    ggml_arm_arch_features.sve_cnt = 0;
+#else
+// Run-time CPU feature detection is not implemented for this platform; fall back to compile-time checks
+#if defined(__ARM_NEON)
+    ggml_arm_arch_features.has_neon = 1;
+#else
+    ggml_arm_arch_features.has_neon = 0;
+#endif
+
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    ggml_arm_arch_features.has_i8mm = 1;
+#else
+    ggml_arm_arch_features.has_i8mm = 0;
+#endif
+
+#if defined(__ARM_FEATURE_SVE)
+    ggml_arm_arch_features.has_sve = 1;
+    ggml_arm_arch_features.sve_cnt = 16;
+#else
+    ggml_arm_arch_features.has_sve = 0;
+    ggml_arm_arch_features.sve_cnt = 0;
+#endif
+#endif
+}
+#endif
+
+struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) {
+    GGML_ASSERT(!ggml_get_no_alloc(ctx));
+
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
+
+    ggml_set_i32(result, value);
+
+    return result;
+}
+
+struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value) {
+    GGML_ASSERT(!ggml_get_no_alloc(ctx));
+
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
+
+    ggml_set_f32(result, value);
+
+    return result;
+}
+
+struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
+    const int n     = ggml_nrows(tensor);
+    const int nc    = tensor->ne[0];
+    const size_t n1 = tensor->nb[1];
+
+    char * const data = tensor->data;
+
+    switch (tensor->type) {
+        case GGML_TYPE_I8:
+            {
+                assert(tensor->nb[0] == sizeof(int8_t));
+                for (int i = 0; i < n; i++) {
+                    ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value);
+                }
+            } break;
+        case GGML_TYPE_I16:
+            {
+                assert(tensor->nb[0] == sizeof(int16_t));
+                for (int i = 0; i < n; i++) {
+                    ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value);
+                }
+            } break;
+        case GGML_TYPE_I32:
+            {
+                assert(tensor->nb[0] == sizeof(int32_t));
+                for (int i = 0; i < n; i++) {
+                    ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value);
+                }
+            } break;
+        case GGML_TYPE_F16:
+            {
+                assert(tensor->nb[0] == sizeof(ggml_fp16_t));
+                for (int i = 0; i < n; i++) {
+                    ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
+                }
+            } break;
+        case GGML_TYPE_BF16:
+            {
+                assert(tensor->nb[0] == sizeof(ggml_bf16_t));
+                for (int i = 0; i < n; i++) {
+                    ggml_vec_set_bf16(nc, (ggml_bf16_t *)(data + i*n1), GGML_FP32_TO_BF16(value));
+                }
+            } break;
+        case GGML_TYPE_F32:
+            {
+                assert(tensor->nb[0] == sizeof(float));
+                for (int i = 0; i < n; i++) {
+                    ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
+                }
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+
+    return tensor;
+}
+
+struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
+    const int n     = ggml_nrows(tensor);
+    const int nc    = tensor->ne[0];
+    const size_t n1 = tensor->nb[1];
+
+    char * const data = tensor->data;
+
+    switch (tensor->type) {
+        case GGML_TYPE_I8:
+            {
+                assert(tensor->nb[0] == sizeof(int8_t));
+                for (int i = 0; i < n; i++) {
+                    ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value);
+                }
+            } break;
+        case GGML_TYPE_I16:
+            {
+                assert(tensor->nb[0] == sizeof(int16_t));
+                for (int i = 0; i < n; i++) {
+                    ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value);
+                }
+            } break;
+        case GGML_TYPE_I32:
+            {
+                assert(tensor->nb[0] == sizeof(int32_t));
+                for (int i = 0; i < n; i++) {
+                    ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value);
+                }
+            } break;
+        case GGML_TYPE_F16:
+            {
+                assert(tensor->nb[0] == sizeof(ggml_fp16_t));
+                for (int i = 0; i < n; i++) {
+                    ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
+                }
+            } break;
+        case GGML_TYPE_BF16:
+            {
+                assert(tensor->nb[0] == sizeof(ggml_bf16_t));
+                for (int i = 0; i < n; i++) {
+                    ggml_vec_set_bf16(nc, (ggml_bf16_t *)(data + i*n1), GGML_FP32_TO_BF16(value));
+                }
+            } break;
+        case GGML_TYPE_F32:
+            {
+                assert(tensor->nb[0] == sizeof(float));
+                for (int i = 0; i < n; i++) {
+                    ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
+                }
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+
+    return tensor;
+}
+
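+// the 1d accessors below interpret i as a flat index in logical element order;
+// for non-contiguous tensors they unravel it and defer to the *_nd variants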
+int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
+    if (!ggml_is_contiguous(tensor)) {
+        int64_t id[4] = { 0, 0, 0, 0 };
+        ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
+        return ggml_get_i32_nd(tensor, id[0], id[1], id[2], id[3]);
+    }
+    switch (tensor->type) {
+        case GGML_TYPE_I8:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
+                return ((int8_t *)(tensor->data))[i];
+            }
+        case GGML_TYPE_I16:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
+                return ((int16_t *)(tensor->data))[i];
+            }
+        case GGML_TYPE_I32:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
+                return ((int32_t *)(tensor->data))[i];
+            }
+        case GGML_TYPE_F16:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
+                return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
+            }
+        case GGML_TYPE_BF16:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
+                return GGML_BF16_TO_FP32(((ggml_bf16_t *)(tensor->data))[i]);
+            }
+        case GGML_TYPE_F32:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(float));
+                return ((float *)(tensor->data))[i];
+            }
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
+    if (!ggml_is_contiguous(tensor)) {
+        int64_t id[4] = { 0, 0, 0, 0 };
+        ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
+        ggml_set_i32_nd(tensor, id[0], id[1], id[2], id[3], value);
+        return;
+    }
+    switch (tensor->type) {
+        case GGML_TYPE_I8:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
+                ((int8_t *)(tensor->data))[i] = value;
+            } break;
+        case GGML_TYPE_I16:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
+                ((int16_t *)(tensor->data))[i] = value;
+            } break;
+        case GGML_TYPE_I32:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
+                ((int32_t *)(tensor->data))[i] = value;
+            } break;
+        case GGML_TYPE_F16:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
+                ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value);
+            } break;
+        case GGML_TYPE_BF16:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
+                ((ggml_bf16_t *)(tensor->data))[i] = GGML_FP32_TO_BF16(value);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(float));
+                ((float *)(tensor->data))[i] = value;
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3) {
+    void * data   = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
+    switch (tensor->type) {
+        case GGML_TYPE_I8:
+            return ((int8_t *) data)[0];
+        case GGML_TYPE_I16:
+            return ((int16_t *) data)[0];
+        case GGML_TYPE_I32:
+            return ((int32_t *) data)[0];
+        case GGML_TYPE_F16:
+            return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
+        case GGML_TYPE_BF16:
+            return GGML_BF16_TO_FP32(((ggml_bf16_t *) data)[0]);
+        case GGML_TYPE_F32:
+            return ((float *) data)[0];
+        default:
+            GGML_ABORT("fatal error");
+    }
+}
+
+void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value) {
+    void * data   = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
+    switch (tensor->type) {
+        case GGML_TYPE_I8:
+            {
+                ((int8_t *)(data))[0] = value;
+            } break;
+        case GGML_TYPE_I16:
+            {
+                ((int16_t *)(data))[0] = value;
+            } break;
+        case GGML_TYPE_I32:
+            {
+                ((int32_t *)(data))[0] = value;
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
+            } break;
+        case GGML_TYPE_BF16:
+            {
+                ((ggml_bf16_t *)(data))[0] = GGML_FP32_TO_BF16(value);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ((float *)(data))[0] = value;
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
+    if (!ggml_is_contiguous(tensor)) {
+        int64_t id[4] = { 0, 0, 0, 0 };
+        ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
+        return ggml_get_f32_nd(tensor, id[0], id[1], id[2], id[3]);
+    }
+    switch (tensor->type) {
+        case GGML_TYPE_I8:
+            {
+                return ((int8_t *)(tensor->data))[i];
+            }
+        case GGML_TYPE_I16:
+            {
+                return ((int16_t *)(tensor->data))[i];
+            }
+        case GGML_TYPE_I32:
+            {
+                return ((int32_t *)(tensor->data))[i];
+            }
+        case GGML_TYPE_F16:
+            {
+                return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
+            }
+        case GGML_TYPE_BF16:
+            {
+                return GGML_BF16_TO_FP32(((ggml_bf16_t *)(tensor->data))[i]);
+            }
+        case GGML_TYPE_F32:
+            {
+                return ((float *)(tensor->data))[i];
+            }
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
+    if (!ggml_is_contiguous(tensor)) {
+        int64_t id[4] = { 0, 0, 0, 0 };
+        ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
+        ggml_set_f32_nd(tensor, id[0], id[1], id[2], id[3], value);
+        return;
+    }
+    switch (tensor->type) {
+        case GGML_TYPE_I8:
+            {
+                ((int8_t *)(tensor->data))[i] = value;
+            } break;
+        case GGML_TYPE_I16:
+            {
+                ((int16_t *)(tensor->data))[i] = value;
+            } break;
+        case GGML_TYPE_I32:
+            {
+                ((int32_t *)(tensor->data))[i] = value;
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value);
+            } break;
+        case GGML_TYPE_BF16:
+            {
+                ((ggml_bf16_t *)(tensor->data))[i] = GGML_FP32_TO_BF16(value);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ((float *)(tensor->data))[i] = value;
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3) {
+    void * data   = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
+    switch (tensor->type) {
+        case GGML_TYPE_I8:
+            return ((int8_t *) data)[0];
+        case GGML_TYPE_I16:
+            return ((int16_t *) data)[0];
+        case GGML_TYPE_I32:
+            return ((int32_t *) data)[0];
+        case GGML_TYPE_F16:
+            return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
+        case GGML_TYPE_BF16:
+            return GGML_BF16_TO_FP32(((ggml_bf16_t *) data)[0]);
+        case GGML_TYPE_F32:
+            return ((float *) data)[0];
+        default:
+            GGML_ABORT("fatal error");
+    }
+}
+
+void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value) {
+    void * data   = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
+    switch (tensor->type) {
+        case GGML_TYPE_I8:
+            {
+                ((int8_t *)(data))[0] = value;
+            } break;
+        case GGML_TYPE_I16:
+            {
+                ((int16_t *)(data))[0] = value;
+            } break;
+        case GGML_TYPE_I32:
+            {
+                ((int32_t *)(data))[0] = value;
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
+            } break;
+        case GGML_TYPE_BF16:
+            {
+                ((ggml_bf16_t *)(data))[0] = GGML_FP32_TO_BF16(value);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ((float *)(data))[0] = value;
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+// ggml_compute_forward_dup
+
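+// fast path: when src and dst share the same type and are both contiguous, the
+// copy reduces to one memcpy per thread over its slice of the element range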
+static void ggml_compute_forward_dup_same_cont(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
+    GGML_ASSERT(src0->type == dst->type);
+
+    const size_t nb0 = ggml_type_size(src0->type);
+
+    const int ith = params->ith; // thread index
+    const int nth = params->nth; // number of threads
+
+    // parallelize by elements
+    const int ne = ggml_nelements(dst);
+    const int dr = (ne + nth - 1) / nth;
+    const int ie0 = dr * ith;
+    const int ie1 = MIN(ie0 + dr, ne);
+
+    if (ie0 < ie1) {
+        memcpy(
+            ((char *)  dst->data + ie0*nb0),
+            ((char *) src0->data + ie0*nb0),
+            (ie1 - ie0) * nb0);
+    }
+}
+
+static void ggml_compute_forward_dup_f16(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    const int ith = params->ith; // thread index
+    const int nth = params->nth; // number of threads
+
+    // parallelize by rows
+    const int nr = ne01;
+    // number of rows per thread
+    const int dr = (nr + nth - 1) / nth;
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    if (src0->type == dst->type &&
+        ne00 == ne0 &&
+        nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
+        // copy by rows
+        const size_t rs = ne00*nb00;
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                    memcpy(
+                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
+                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
+                        rs);
+                }
+            }
+        }
+        return;
+    }
+
+    // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
+
+    if (ggml_is_contiguous(dst)) {
+        if (nb00 == sizeof(ggml_fp16_t)) {
+            if (dst->type == GGML_TYPE_F16) {
+                size_t id = 0;
+                const size_t rs = ne00 * nb00;
+                char * dst_ptr = (char *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += rs * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+                            memcpy(dst_ptr + id, src0_ptr, rs);
+                            id += rs;
+                        }
+                        id += rs * (ne01 - ir1);
+                    }
+                }
+            } else if (dst->type == GGML_TYPE_F32) {
+                size_t id = 0;
+                float * dst_ptr = (float *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                dst_ptr[id] = GGML_FP16_TO_FP32(src0_ptr[i00]);
+                                id++;
+                            }
+                        }
+                        id += ne00 * (ne01 - ir1);
+                    }
+                }
+            } else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
+                ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
+                float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
+
+                size_t id = 0;
+                size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
+                char * dst_ptr = (char *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += rs * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                src0_f32[i00] = GGML_FP16_TO_FP32(src0_ptr[i00]);
+                            }
+
+                            quantize_row_q(src0_f32, dst_ptr + id, ne00);
+                            id += rs;
+                        }
+                        id += rs * (ne01 - ir1);
+                    }
+                }
+            } else {
+                GGML_ABORT("fatal error"); // TODO: implement
+            }
+        } else {
+            //printf("%s: this is not optimal - fix me\n", __func__);
+
+            if (dst->type == GGML_TYPE_F32) {
+                size_t id = 0;
+                float * dst_ptr = (float *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
+                                id++;
+                            }
+                        }
+                        id += ne00 * (ne01 - ir1);
+                    }
+                }
+            } else if (dst->type == GGML_TYPE_F16) {
+                size_t id = 0;
+                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = *src0_ptr;
+                                id++;
+                            }
+                        }
+                        id += ne00 * (ne01 - ir1);
+                    }
+                }
+            } else {
+                GGML_ABORT("fatal error"); // TODO: implement
+            }
+        }
+        return;
+    }
+
+    // dst counters
+    int64_t i10 = 0;
+    int64_t i11 = 0;
+    int64_t i12 = 0;
+    int64_t i13 = 0;
+
+    if (dst->type == GGML_TYPE_F16) {
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
+
+                        memcpy(dst_ptr, src0_ptr, sizeof(ggml_fp16_t));
+
+                        // wrap the dst counters at the dst dims (ne0..ne3), matching the f32 branch below
+                        if (++i10 == ne0) {
+                            i10 = 0;
+                            if (++i11 == ne1) {
+                                i11 = 0;
+                                if (++i12 == ne2) {
+                                    i12 = 0;
+                                    if (++i13 == ne3) {
+                                        i13 = 0;
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+                i10 += ne00 * (ne01 - ir1);
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    } else if (dst->type == GGML_TYPE_F32) {
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
+
+                        *(float *) dst_ptr = GGML_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);
+
+                        if (++i10 == ne0) {
+                            i10 = 0;
+                            if (++i11 == ne1) {
+                                i11 = 0;
+                                if (++i12 == ne2) {
+                                    i12 = 0;
+                                    if (++i13 == ne3) {
+                                        i13 = 0;
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+                i10 += ne00 * (ne01 - ir1);
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    } else {
+        GGML_ABORT("fatal error"); // TODO: implement
+    }
+}
+
+static void ggml_compute_forward_dup_bf16(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    const int ith = params->ith; // thread index
+    const int nth = params->nth; // number of threads
+
+    // parallelize by rows
+    const int nr = ne01;
+    // number of rows per thread
+    const int dr = (nr + nth - 1) / nth;
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    if (src0->type == dst->type &&
+        ne00 == ne0 &&
+        nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
+        // copy by rows
+        const size_t rs = ne00*nb00;
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                    memcpy(
+                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
+                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
+                        rs);
+                }
+            }
+        }
+        return;
+    }
+
+    // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
+
+    if (ggml_is_contiguous(dst)) {
+        if (nb00 == sizeof(ggml_bf16_t)) {
+            if (dst->type == GGML_TYPE_BF16) {
+                size_t id = 0;
+                const size_t rs = ne00 * nb00;
+                char * dst_ptr = (char *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += rs * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+                            memcpy(dst_ptr + id, src0_ptr, rs);
+                            id += rs;
+                        }
+                        id += rs * (ne01 - ir1);
+                    }
+                }
+            } else if (dst->type == GGML_TYPE_F16) {
+                size_t id = 0;
+                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(src0_ptr[i00]));
+                                id++;
+                            }
+                        }
+                        id += ne00 * (ne01 - ir1);
+                    }
+                }
+            } else if (dst->type == GGML_TYPE_F32) {
+                size_t id = 0;
+                float * dst_ptr = (float *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                dst_ptr[id] = GGML_BF16_TO_FP32(src0_ptr[i00]);
+                                id++;
+                            }
+                        }
+                        id += ne00 * (ne01 - ir1);
+                    }
+                }
+            } else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
+                ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
+                float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
+
+                size_t id = 0;
+                size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
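+                // rs is the dst row size in bytes: ne00 elements form
+                // ne00/ggml_blck_size(dst->type) blocks of nb0 bytes each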
+                char * dst_ptr = (char *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += rs * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                src0_f32[i00] = GGML_BF16_TO_FP32(src0_ptr[i00]);
+                            }
+
+                            quantize_row_q(src0_f32, dst_ptr + id, ne00);
+                            id += rs;
+                        }
+                        id += rs * (ne01 - ir1);
+                    }
+                }
+            } else {
+                GGML_ABORT("fatal error"); // TODO: implement
+            }
+        } else {
+            //printf("%s: this is not optimal - fix me\n", __func__);
+
+            if (dst->type == GGML_TYPE_F32) {
+                size_t id = 0;
+                float * dst_ptr = (float *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = GGML_BF16_TO_FP32(*src0_ptr);
+                                id++;
+                            }
+                        }
+                        id += ne00 * (ne01 - ir1);
+                    }
+                }
+            } else if (dst->type == GGML_TYPE_BF16) {
+                size_t id = 0;
+                ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = *src0_ptr;
+                                id++;
+                            }
+                        }
+                        id += ne00 * (ne01 - ir1);
+                    }
+                }
+            } else if (dst->type == GGML_TYPE_F16) {
+                size_t id = 0;
+                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*src0_ptr));
+                                id++;
+                            }
+                        }
+                        id += ne00 * (ne01 - ir1);
+                    }
+                }
+            } else {
+                GGML_ABORT("fatal error"); // TODO: implement
+            }
+        }
+        return;
+    }
+
+    // dst counters
+    int64_t i10 = 0;
+    int64_t i11 = 0;
+    int64_t i12 = 0;
+    int64_t i13 = 0;
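+    // i10..i13 advance through dst like an odometer: i10 is bumped once per
+    // element written and each counter carries into the next dimension when it
+    // reaches the corresponding dst extent, so the k-th element of src0 lands
+    // at the k-th element of dst in row-major order even when the shapes differ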
+
+    if (dst->type == GGML_TYPE_BF16) {
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
+
+                        memcpy(dst_ptr, src0_ptr, sizeof(ggml_bf16_t));
+
+                        if (++i10 == ne0) {
+                            i10 = 0;
+                            if (++i11 == ne1) {
+                                i11 = 0;
+                                if (++i12 == ne2) {
+                                    i12 = 0;
+                                    if (++i13 == ne3) {
+                                        i13 = 0;
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+                i10 += ne00 * (ne01 - ir1);
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    } else if (dst->type == GGML_TYPE_F16) {
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
+
+                        *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr));
+
+                        if (++i10 == ne0) {
+                            i10 = 0;
+                            if (++i11 == ne1) {
+                                i11 = 0;
+                                if (++i12 == ne2) {
+                                    i12 = 0;
+                                    if (++i13 == ne3) {
+                                        i13 = 0;
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+                i10 += ne00 * (ne01 - ir1);
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    } else if (dst->type == GGML_TYPE_F32) {
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
+
+                        *(float *) dst_ptr = GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr);
+
+                        if (++i10 == ne0) {
+                            i10 = 0;
+                            if (++i11 == ne1) {
+                                i11 = 0;
+                                if (++i12 == ne2) {
+                                    i12 = 0;
+                                    if (++i13 == ne3) {
+                                        i13 = 0;
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+                i10 += ne00 * (ne01 - ir1);
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    } else {
+        GGML_ABORT("fatal error"); // TODO: implement
+    }
+}
+
+static void ggml_compute_forward_dup_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    const int ith = params->ith; // thread index
+    const int nth = params->nth; // number of threads
+
+    // parallelize by rows
+    const int nr = ne01;
+    // number of rows per thread
+    const int dr = (nr + nth - 1) / nth;
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    if (src0->type == dst->type &&
+        ne00 == ne0 &&
+        nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
+        // copy by rows
+        const size_t rs = ne00*nb00;
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                    memcpy(
+                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
+                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
+                        rs);
+                }
+            }
+        }
+        return;
+    }
+
+    if (ggml_is_contiguous(dst)) {
+        // TODO: simplify
+        if (nb00 == sizeof(float)) {
+            if (dst->type == GGML_TYPE_F32) {
+                size_t id = 0;
+                const size_t rs = ne00 * nb00;
+                char * dst_ptr = (char *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += rs * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+                            memcpy(dst_ptr + id, src0_ptr, rs);
+                            id += rs;
+                        }
+                        id += rs * (ne01 - ir1);
+                    }
+                }
+            } else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
+                ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
+
+                size_t id = 0;
+                size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
+                char * dst_ptr = (char *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += rs * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+                            quantize_row_q(src0_ptr, dst_ptr + id, ne00);
+                            id += rs;
+                        }
+                        id += rs * (ne01 - ir1);
+                    }
+                }
+            } else {
+                GGML_ABORT("fatal error"); // TODO: implement
+            }
+        } else {
+            //printf("%s: this is not optimal - fix me\n", __func__);
+
+            if (dst->type == GGML_TYPE_F32) {
+                size_t id = 0;
+                float * dst_ptr = (float *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = *src0_ptr;
+                                id++;
+                            }
+                        }
+                        id += ne00 * (ne01 - ir1);
+                    }
+                }
+            } else if (dst->type == GGML_TYPE_F16) {
+                size_t id = 0;
+                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
+                                id++;
+                            }
+                        }
+                        id += ne00 * (ne01 - ir1);
+                    }
+                }
+            } else if (dst->type == GGML_TYPE_BF16) {
+                size_t id = 0;
+                ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = GGML_FP32_TO_BF16(*src0_ptr);
+                                id++;
+                            }
+                        }
+                        id += ne00 * (ne01 - ir1);
+                    }
+                }
+            } else {
+                GGML_ABORT("fatal error"); // TODO: implement
+            }
+        }
+
+        return;
+    }
+
+    // dst counters
+
+    int64_t i10 = 0;
+    int64_t i11 = 0;
+    int64_t i12 = 0;
+    int64_t i13 = 0;
+
+    if (dst->type == GGML_TYPE_F32) {
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
+
+                        memcpy(dst_ptr, src0_ptr, sizeof(float));
+
+                        if (++i10 == ne0) {
+                            i10 = 0;
+                            if (++i11 == ne1) {
+                                i11 = 0;
+                                if (++i12 == ne2) {
+                                    i12 = 0;
+                                    if (++i13 == ne3) {
+                                        i13 = 0;
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+                i10 += ne00 * (ne01 - ir1);
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    } else if (dst->type == GGML_TYPE_F16) {
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
+
+                        *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
+
+                        if (++i10 == ne0) {
+                            i10 = 0;
+                            if (++i11 == ne1) {
+                                i11 = 0;
+                                if (++i12 == ne2) {
+                                    i12 = 0;
+                                    if (++i13 == ne3) {
+                                        i13 = 0;
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+                i10 += ne00 * (ne01 - ir1);
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    } else if (dst->type == GGML_TYPE_BF16) {
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
+
+                        *(ggml_bf16_t *) dst_ptr = GGML_FP32_TO_BF16(*(const float *) src0_ptr);
+
+                        if (++i10 == ne0) {
+                            i10 = 0;
+                            if (++i11 == ne1) {
+                                i11 = 0;
+                                if (++i12 == ne2) {
+                                    i12 = 0;
+                                    if (++i13 == ne3) {
+                                        i13 = 0;
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+                i10 += ne00 * (ne01 - ir1);
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    } else {
+        GGML_ABORT("fatal error"); // TODO: implement
+    }
+}
+
+// A simplified version of ggml_compute_forward_dup that does no float upcasting and just uses plain old memcpy.
+static void ggml_compute_forward_dup_bytes(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+    GGML_ASSERT(src0->type == dst->type);
+
+    GGML_TENSOR_UNARY_OP_LOCALS;
+
+    if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {
+        ggml_compute_forward_dup_same_cont(params, dst);
+        return;
+    }
+
+    const size_t type_size = ggml_type_size(src0->type);
+    const int ith = params->ith; // thread index
+    const int nth = params->nth; // number of threads
+
+    // parallelize by rows
+    const int nr = ne01;
+    // number of rows per thread
+    const int dr = (nr + nth - 1) / nth;
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    if (src0->type == dst->type &&
+        ne00 == ne0 &&
+        nb00 == type_size && nb0 == type_size) {
+        // copy by rows
+        const size_t rs = ne00 * type_size;
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                    memcpy(
+                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
+                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
+                        rs);
+                }
+            }
+        }
+        return;
+    }
+
+    if (ggml_is_contiguous(dst)) {
+        size_t id = 0;
+        char * dst_ptr = (char *) dst->data;
+        const size_t rs = ne00 * type_size;
+
+        if (nb00 == type_size) {
+            // src0 is contiguous in the first dimension, copy by rows
+            for (int64_t i03 = 0; i03 < ne03; i03++) {
+                for (int64_t i02 = 0; i02 < ne02; i02++) {
+                    id += rs * ir0;
+                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                        const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+                        memcpy(dst_ptr + id, src0_ptr, rs);
+                        id += rs;
+                    }
+                    id += rs * (ne01 - ir1);
+                }
+            }
+        } else {
+            //printf("%s: this is not optimal - fix me\n", __func__);
+
+            for (int64_t i03 = 0; i03 < ne03; i03++) {
+                for (int64_t i02 = 0; i02 < ne02; i02++) {
+                    id += rs * ir0;
+                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                        for (int64_t i00 = 0; i00 < ne00; i00++) {
+                            const char * src0_ptr = (char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;
+                            memcpy(dst_ptr + id, src0_ptr, type_size);
+
+                            id += type_size;
+                        }
+                    }
+                    id += rs * (ne01 - ir1);
+                }
+            }
+        }
+
+        return;
+    }
+
+    // dst counters
+
+    int64_t i10 = 0;
+    int64_t i11 = 0;
+    int64_t i12 = 0;
+    int64_t i13 = 0;
+
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            i10 += ne00 * ir0;
+            while (i10 >= ne0) {
+                i10 -= ne0;
+                if (++i11 == ne1) {
+                    i11 = 0;
+                    if (++i12 == ne2) {
+                        i12 = 0;
+                        if (++i13 == ne3) {
+                            i13 = 0;
+                        }
+                    }
+                }
+            }
+            for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                for (int64_t i00 = 0; i00 < ne00; i00++) {
+                    const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                          char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
+
+                    memcpy(dst_ptr, src0_ptr, type_size);
+
+                    if (++i10 == ne0) {
+                        i10 = 0;
+                        if (++i11 == ne1) {
+                            i11 = 0;
+                            if (++i12 == ne2) {
+                                i12 = 0;
+                                if (++i13 == ne3) {
+                                    i13 = 0;
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+            i10 += ne00 * (ne01 - ir1);
+            while (i10 >= ne0) {
+                i10 -= ne0;
+                if (++i11 == ne1) {
+                    i11 = 0;
+                    if (++i12 == ne2) {
+                        i12 = 0;
+                        if (++i13 == ne3) {
+                            i13 = 0;
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_dup(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (src0->type == dst->type) {
+        ggml_compute_forward_dup_bytes(params, dst);
+        return;
+    }
+
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_dup_f16(params, dst);
+            } break;
+        case GGML_TYPE_BF16:
+            {
+                ggml_compute_forward_dup_bf16(params, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_dup_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
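+
+// note: GGML_OP_CPY and GGML_OP_CONT are routed through this same dup kernel,
+// so the type pairs accepted above are also the conversions available to
+// cpy/cont on the CPU backend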
+
+// ggml_compute_forward_add
+
+static void ggml_compute_forward_add_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr  = ggml_nrows(src0);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT( nb0 == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    if (nb10 == sizeof(float)) {
+        for (int ir = ir0; ir < ir1; ++ir) {
+            // src1 is broadcastable across src0 and dst in i1, i2, i3
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+            const int64_t nr0 = ne00 / ne10;
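+            // src1 may be narrower than src0 in dim 0: the same src1 row is
+            // added to nr0 = ne00/ne10 consecutive chunks of the dst row,
+            // i.e. src1 is repeat-broadcast along the row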
+
+            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+            float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
+
+            for (int64_t r = 0; r < nr0; ++r) {
+#ifdef GGML_USE_ACCELERATE
+                vDSP_vadd(src0_ptr + r*ne10, 1, src1_ptr, 1, dst_ptr + r*ne10, 1, ne10);
+#else
+                ggml_vec_add_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
+#endif
+            }
+        }
+    } else {
+        // src1 is not contiguous
+        for (int ir = ir0; ir < ir1; ++ir) {
+            // src1 is broadcastable across src0 and dst in i1, i2, i3
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+
+            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+
+            for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                const int64_t i10 = i0 % ne10;
+                float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
+
+                dst_ptr[i0] = src0_ptr[i0] + *src1_ptr;
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_add_f16_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr  = ggml_nrows(src0);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    if (dst->type == GGML_TYPE_F32) {
+        GGML_ASSERT( nb0 == sizeof(float));
+    }
+    else {
+        GGML_ASSERT(dst->type  == GGML_TYPE_F16);
+        GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    }
+
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    if (nb10 == sizeof(float)) {
+        if (dst->type == GGML_TYPE_F16) {
+            for (int ir = ir0; ir < ir1; ++ir) {
+                // src0, src1 and dst are same shape => same indices
+                const int i3 = ir/(ne2*ne1);
+                const int i2 = (ir - i3*ne2*ne1)/ne1;
+                const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+                ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1);
+                ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+                float *       src1_ptr = (float *)       ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
+
+                for (int i = 0; i < ne0; i++) {
+                    dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]);
+                }
+            }
+        } else {
+            for (int ir = ir0; ir < ir1; ++ir) {
+                // src0, src1 and dst are same shape => same indices
+                const int i3 = ir/(ne2*ne1);
+                const int i2 = (ir - i3*ne2*ne1)/ne1;
+                const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+                float *       dst_ptr  = (float *)       ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1);
+                ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+                float *       src1_ptr = (float *)       ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
+
+                for (int i = 0; i < ne0; i++) {
+                    dst_ptr[i] = GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i];
+                }
+            }
+        }
+    }
+    else {
+        // src1 is not contiguous
+        GGML_ABORT("fatal error");
+    }
+}
+
+static void ggml_compute_forward_add_bf16_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr  = ggml_nrows(src0);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(src0->type == GGML_TYPE_BF16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    if (dst->type == GGML_TYPE_F32) {
+        GGML_ASSERT( nb0 == sizeof(float));
+    }
+    else {
+        GGML_ASSERT(dst->type  == GGML_TYPE_BF16);
+        GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
+    }
+
+    GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    if (nb10 == sizeof(float)) {
+        if (dst->type == GGML_TYPE_BF16) {
+            for (int ir = ir0; ir < ir1; ++ir) {
+                // src0, src1 and dst are same shape => same indices
+                const int i3 = ir/(ne2*ne1);
+                const int i2 = (ir - i3*ne2*ne1)/ne1;
+                const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+                ggml_bf16_t * dst_ptr  = (ggml_bf16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1);
+                ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+                float *       src1_ptr = (float *)       ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
+
+                for (int i = 0; i < ne0; i++) {
+                    dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + src1_ptr[i]);
+                }
+            }
+        } else {
+            for (int ir = ir0; ir < ir1; ++ir) {
+                // src0, src1 and dst are same shape => same indices
+                const int i3 = ir/(ne2*ne1);
+                const int i2 = (ir - i3*ne2*ne1)/ne1;
+                const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+                float *       dst_ptr  = (float *)       ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1);
+                ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+                float *       src1_ptr = (float *)       ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
+
+                for (int i = 0; i < ne0; i++) {
+                    dst_ptr[i] = GGML_BF16_TO_FP32(src0_ptr[i]) + src1_ptr[i];
+                }
+            }
+        }
+    }
+    else {
+        // src1 is not contiguous
+        GGML_ABORT("fatal error");
+    }
+}
+
+static void ggml_compute_forward_add_f16_f16(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr  = ggml_nrows(src0);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
+
+    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    if (nb10 == sizeof(ggml_fp16_t)) {
+        for (int ir = ir0; ir < ir1; ++ir) {
+            // src0, src1 and dst are same shape => same indices
+            const int i3 = ir/(ne2*ne1);
+            const int i2 = (ir - i3*ne2*ne1)/ne1;
+            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+            ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1);
+            ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+            ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
+
+            for (int i = 0; i < ne0; i++) {
+                dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(src1_ptr[i]));
+            }
+        }
+    }
+    else {
+        // src1 is not contiguous
+        GGML_ABORT("fatal error");
+    }
+}
+
+static void ggml_compute_forward_add_bf16_bf16(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr  = ggml_nrows(src0);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(src0->type == GGML_TYPE_BF16);
+    GGML_ASSERT(src1->type == GGML_TYPE_BF16);
+    GGML_ASSERT(dst->type  == GGML_TYPE_BF16);
+
+    GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    if (nb10 == sizeof(ggml_bf16_t)) {
+        for (int ir = ir0; ir < ir1; ++ir) {
+            // src0, src1 and dst are same shape => same indices
+            const int i3 = ir/(ne2*ne1);
+            const int i2 = (ir - i3*ne2*ne1)/ne1;
+            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+            ggml_bf16_t * dst_ptr  = (ggml_bf16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1);
+            ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+            ggml_bf16_t * src1_ptr = (ggml_bf16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
+
+            for (int i = 0; i < ne0; i++) {
+                dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + GGML_BF16_TO_FP32(src1_ptr[i]));
+            }
+        }
+    }
+    else {
+        // src1 is not contiguous
+        GGML_ABORT("fatal error");
+    }
+}
+
+static void ggml_compute_forward_add_q_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    const int nr  = ggml_nrows(src0);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const enum ggml_type type = src0->type;
+    const enum ggml_type dtype = dst->type;
+    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
+    ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dtype)->from_float;
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == ggml_type_size(type));
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(ggml_is_quantized(src0->type));
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    float * wdata = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
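+    // per-thread scratch row: each thread dequantizes into its own slice of
+    // params->wdata, padded by CACHE_LINE_SIZE_F32 floats so neighbouring
+    // threads don't false-share a cache line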
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 indices
+        const int i03 = ir/(ne02*ne01);
+        const int i02 = (ir - i03*ne02*ne01)/ne01;
+        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        // src1 and dst are same shape as src0 => same indices
+        const int i13 = i03;
+        const int i12 = i02;
+        const int i11 = i01;
+
+        const int i3 = i03;
+        const int i2 = i02;
+        const int i1 = i01;
+
+        void  * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
+        float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
+        void  * dst_row  = (void *) ((char *)  dst->data + ( i1*nb1  +  i2*nb2  +  i3*nb3));
+
+        assert(ne00 % 32 == 0);
+
+        // dequantize row from src0 to temp buffer
+        dequantize_row_q(src0_row, wdata, ne00);
+        // add src1
+        ggml_vec_acc_f32(ne00, wdata, src1_row);
+        // quantize row to dst
+        if (quantize_row_q != NULL) {
+            quantize_row_q(wdata, dst_row, ne00);
+        } else {
+            memcpy(dst_row, wdata, ne0*nb0);
+        }
+    }
+}
+
+static void ggml_compute_forward_add(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                if (src1->type == GGML_TYPE_F32) {
+                    ggml_compute_forward_add_f32(params, dst);
+                }
+                else {
+                    GGML_ABORT("fatal error");
+                }
+            } break;
+        case GGML_TYPE_F16:
+            {
+                if (src1->type == GGML_TYPE_F16) {
+                    ggml_compute_forward_add_f16_f16(params, dst);
+                }
+                else if (src1->type == GGML_TYPE_F32) {
+                    ggml_compute_forward_add_f16_f32(params, dst);
+                }
+                else {
+                    GGML_ABORT("fatal error");
+                }
+            } break;
+        case GGML_TYPE_BF16:
+            {
+                if (src1->type == GGML_TYPE_BF16) {
+                    ggml_compute_forward_add_bf16_bf16(params, dst);
+                }
+                else if (src1->type == GGML_TYPE_F32) {
+                    ggml_compute_forward_add_bf16_f32(params, dst);
+                }
+                else {
+                    GGML_ABORT("fatal error");
+                }
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+        case GGML_TYPE_TQ1_0:
+        case GGML_TYPE_TQ2_0:
+        case GGML_TYPE_IQ2_XXS:
+        case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ1_M:
+        case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_IQ3_S:
+        case GGML_TYPE_IQ2_S:
+        case GGML_TYPE_Q4_0_4_4:
+        case GGML_TYPE_Q4_0_4_8:
+        case GGML_TYPE_Q4_0_8_8:
+            {
+                ggml_compute_forward_add_q_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_add1
+
+static void ggml_compute_forward_add1_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_scalar(src1));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr  = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    GGML_ASSERT( nb0 == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 and dst are same shape => same indices
+        const int i3 = ir/(ne2*ne1);
+        const int i2 = (ir - i3*ne2*ne1)/ne1;
+        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+#ifdef GGML_USE_ACCELERATE
+        UNUSED(ggml_vec_add1_f32);
+
+        vDSP_vadd(
+                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
+                (float *) ((char *) src1->data), 0,
+                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ), 1,
+                ne0);
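+        // the zero stride on the src1 operand makes vDSP re-read the same
+        // scalar for every output element, i.e. a broadcast add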
+#else
+        ggml_vec_add1_f32(ne0,
+                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ),
+                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
+               *(float *) src1->data);
+#endif
+    }
+}
+
+static void ggml_compute_forward_add1_f16_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_scalar(src1));
+
+    // scalar to add
+    const float v = *(float *) src1->data;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr  = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
+
+    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 and dst are same shape => same indices
+        const int i3 = ir/(ne2*ne1);
+        const int i2 = (ir - i3*ne2*ne1)/ne1;
+        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
+        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+        for (int i = 0; i < ne0; i++) {
+            dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + v);
+        }
+    }
+}
+
+static void ggml_compute_forward_add1_f16_f16(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_scalar(src1));
+
+    // scalar to add
+    const float v = GGML_FP16_TO_FP32(*(ggml_fp16_t *) src1->data);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr  = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
+
+    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 and dst are same shape => same indices
+        const int i3 = ir/(ne2*ne1);
+        const int i2 = (ir - i3*ne2*ne1)/ne1;
+        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
+        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+        for (int i = 0; i < ne0; i++) {
+            dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + v);
+        }
+    }
+}
+
+static void ggml_compute_forward_add1_q_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_scalar(src1));
+
+    // scalar to add
+    const float v = *(float *) src1->data;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr  = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    const enum ggml_type type = src0->type;
+    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
+    ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(type)->from_float;
+
+    // we don't support permuted src0
+    GGML_ASSERT(nb00 == ggml_type_size(type));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(ggml_is_quantized(src0->type));
+    GGML_ASSERT(dst->type == src0->type);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 and dst are same shape => same indices
+        const int i3 = ir/(ne2*ne1);
+        const int i2 = (ir - i3*ne2*ne1)/ne1;
+        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        void  * src0_row = (void *) ((char *) src0->data + (i1*nb01 + i2*nb02 + i3*nb03));
+        void  * dst_row  = (void *) ((char *)  dst->data + (i1*nb1  + i2*nb2  + i3*nb3 ));
+
+        assert(ne0 % 32 == 0);
+
+        // dequantize row from src0 to temp buffer
+        dequantize_row_q(src0_row, wdata, ne0);
+        // add the scalar value from src1
+        ggml_vec_acc1_f32(ne0, wdata, v);
+        // quantize row to dst
+        quantize_row_q(wdata, dst_row, ne0);
+    }
+}
+
+static void ggml_compute_forward_add1_bf16_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_scalar(src1));
+
+    // scalar to add
+    const float v = *(float *) src1->data;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr  = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    GGML_ASSERT(src0->type == GGML_TYPE_BF16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type  == GGML_TYPE_BF16);
+
+    GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 and dst are same shape => same indices
+        const int i3 = ir/(ne2*ne1);
+        const int i2 = (ir - i3*ne2*ne1)/ne1;
+        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        ggml_bf16_t * dst_ptr  = (ggml_bf16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
+        ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+        for (int i = 0; i < ne0; i++) {
+            dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
+        }
+    }
+}
+
+static void ggml_compute_forward_add1_bf16_bf16(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_scalar(src1));
+
+    // scalar to add
+    const float v = GGML_BF16_TO_FP32(*(ggml_bf16_t *) src1->data);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr  = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    GGML_ASSERT(src0->type == GGML_TYPE_BF16);
+    GGML_ASSERT(src1->type == GGML_TYPE_BF16);
+    GGML_ASSERT(dst->type  == GGML_TYPE_BF16);
+
+    GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 and dst are same shape => same indices
+        const int i3 = ir/(ne2*ne1);
+        const int i2 = (ir - i3*ne2*ne1)/ne1;
+        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        ggml_bf16_t * dst_ptr  = (ggml_bf16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
+        ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+        for (int i = 0; i < ne0; i++) {
+            dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
+        }
+    }
+}
+
+static void ggml_compute_forward_add1(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_add1_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                if (src1->type == GGML_TYPE_F16) {
+                    ggml_compute_forward_add1_f16_f16(params, dst);
+                }
+                else if (src1->type == GGML_TYPE_F32) {
+                    ggml_compute_forward_add1_f16_f32(params, dst);
+                }
+                else {
+                    GGML_ABORT("fatal error");
+                }
+            } break;
+        case GGML_TYPE_BF16:
+            {
+                if (src1->type == GGML_TYPE_BF16) {
+                    ggml_compute_forward_add1_bf16_bf16(params, dst);
+                }
+                else if (src1->type == GGML_TYPE_F32) {
+                    ggml_compute_forward_add1_bf16_f32(params, dst);
+                }
+                else {
+                    GGML_ABORT("fatal error");
+                }
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+        case GGML_TYPE_TQ1_0:
+        case GGML_TYPE_TQ2_0:
+        case GGML_TYPE_IQ2_XXS:
+        case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ1_M:
+        case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_IQ3_S:
+        case GGML_TYPE_IQ2_S:
+        case GGML_TYPE_Q4_0_4_4:
+        case GGML_TYPE_Q4_0_4_8:
+        case GGML_TYPE_Q4_0_8_8:
+            {
+                ggml_compute_forward_add1_q_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
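+
+// A minimal usage sketch for the op above (assuming the public ggml API, where
+// ggml_add1(ctx, a, b) adds the scalar tensor b to every element of a):
+//
+//   struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
+//   struct ggml_tensor * y = ggml_add1(ctx, x, ggml_new_f32(ctx, 1.0f));  // y[i] = x[i] + 1.0f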
+
+// ggml_compute_forward_acc
+
+static void ggml_compute_forward_acc_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
+
+    // view src0 and dst with these strides and data offset in bytes during acc
+    // nb0 is implicitly element_size because src0 and dst are contiguous
+    size_t nb1     = ((int32_t *) dst->op_params)[0];
+    size_t nb2     = ((int32_t *) dst->op_params)[1];
+    size_t nb3     = ((int32_t *) dst->op_params)[2];
+    size_t offset  = ((int32_t *) dst->op_params)[3];
+    bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
+
+    if (!inplace) {
+        if (params->ith == 0) {
+            // memcpy needs to be synchronized across threads to avoid race conditions:
+            // => run it on thread 0 only and let every thread wait on the barrier below
+            memcpy(
+                ((char *)  dst->data),
+                ((char *) src0->data),
+                ggml_nbytes(dst));
+        }
+        ggml_barrier(params->threadpool);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src1);
+    const int nc = src1->ne[0];
+
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb)
+
+    // src0 and dst as viewed during acc
+    const size_t nb0 = ggml_element_size(src0);
+
+    const size_t nb00 = nb0;
+    const size_t nb01 = nb1;
+    const size_t nb02 = nb2;
+    const size_t nb03 = nb3;
+
+    GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb0  + (ne11 == 0 ? 0 : ne11-1)*nb1  + (ne12 == 0 ? 0 : ne12-1)*nb2  + (ne13 == 0 ? 0 : ne13-1)*nb3  < ggml_nbytes(dst));
+    GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb00 + (ne11 == 0 ? 0 : ne11-1)*nb01 + (ne12 == 0 ? 0 : ne12-1)*nb02 + (ne13 == 0 ? 0 : ne13-1)*nb03 < ggml_nbytes(src0));
+
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 and dst are viewed with shape of src1 and offset
+        // => same indices
+        const int i3 = ir/(ne12*ne11);
+        const int i2 = (ir - i3*ne12*ne11)/ne11;
+        const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
+
+#ifdef GGML_USE_ACCELERATE
+        vDSP_vadd(
+                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset), 1,
+                (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
+                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1  + offset), 1, nc);
+#else
+        ggml_vec_add_f32(nc,
+                (float *) ((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + offset),
+                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset),
+                (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
+#endif
+    }
+}
+
+static void ggml_compute_forward_acc(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_acc_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+        case GGML_TYPE_BF16:
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+        case GGML_TYPE_TQ1_0:
+        case GGML_TYPE_TQ2_0:
+        case GGML_TYPE_IQ2_XXS:
+        case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ1_M:
+        case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_IQ3_S:
+        case GGML_TYPE_IQ2_S:
+        case GGML_TYPE_Q4_0_4_4:
+        case GGML_TYPE_Q4_0_4_8:
+        case GGML_TYPE_Q4_0_8_8:
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
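+
+// Semantics sketch (assuming the public ggml API): ggml_acc(ctx, a, b, nb1, nb2, nb3, offset)
+// yields a copy of a with b added into the strided sub-view that starts at the given
+// byte offset, i.e. view(a) += b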
+
+// ggml_compute_forward_sub
+
+static void ggml_compute_forward_sub_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    assert(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr  = ggml_nrows(src0);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT( nb0 == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    if (nb10 == sizeof(float)) {
+        for (int ir = ir0; ir < ir1; ++ir) {
+            // src1 is broadcastable across src0 and dst in i1, i2, i3
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+            const int64_t nr0 = ne00 / ne10;
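+            // broadcast example: src0 ne = [8,4,2,1] with src1 ne = [4,1,1,1]
+            // gives nr0 = 8/4 = 2 inner repeats per row, and the modulo above
+            // wraps i11/i12/i13 back to 0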
+
+            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+            float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
+
+            for (int64_t r = 0; r < nr0; ++r) {
+#ifdef GGML_USE_ACCELERATE
+                vDSP_vsub(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10);
+#else
+                ggml_vec_sub_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
+#endif
+            }
+        }
+    } else {
+        // src1 is not contiguous
+        for (int ir = ir0; ir < ir1; ++ir) {
+            // src1 is broadcastable across src0 and dst in i1, i2, i3
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+
+            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+
+            for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                const int64_t i10 = i0 % ne10;
+                float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
+
+                dst_ptr[i0] = src0_ptr[i0] - *src1_ptr;
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_sub(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_sub_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_mul
+
+static void ggml_compute_forward_mul_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t nr = ggml_nrows(src0);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT( nb0 == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    if (nb10 == sizeof(float)) {
+        for (int64_t ir = ith; ir < nr; ir += nth) {
+            // src0 and dst are same shape => same indices
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+            const int64_t nr0 = ne00 / ne10;
+
+            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+            float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
+
+            for (int64_t r = 0 ; r < nr0; ++r) {
+#ifdef GGML_USE_ACCELERATE
+                UNUSED(ggml_vec_mul_f32);
+
+                vDSP_vmul(src0_ptr + r*ne10, 1, src1_ptr, 1, dst_ptr + r*ne10, 1, ne10);
+#else
+                ggml_vec_mul_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
+#endif
+            }
+        }
+    } else {
+        // src1 is not contiguous
+        for (int64_t ir = ith; ir < nr; ir += nth) {
+            // src0 and dst are same shape => same indices
+            // src1 is broadcastable across src0 and dst in i1, i2, i3
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+
+            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+
+            for (int64_t i0 = 0; i0 < ne00; ++i0) {
+                const int64_t i10 = i0 % ne10;
+                float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
+
+                dst_ptr[i0] = src0_ptr[i0] * (*src1_ptr);
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_mul(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src1->type == GGML_TYPE_F32 && "only f32 src1 supported for now");
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_mul_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_div
+
+static void ggml_compute_forward_div_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t nr = ggml_nrows(src0);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT( nb0 == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    if (nb10 == sizeof(float)) {
+        for (int64_t ir = ith; ir < nr; ir += nth) {
+            // src0 and dst are same shape => same indices
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+            const int64_t nr0 = ne00 / ne10;
+
+            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+            float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
+
+            for (int64_t r = 0; r < nr0; ++r) {
+#ifdef GGML_USE_ACCELERATE
+                UNUSED(ggml_vec_div_f32);
+
+                vDSP_vdiv(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10);
+#else
+                ggml_vec_div_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
+#endif
+            }
+        }
+    } else {
+        // src1 is not contiguous
+        for (int64_t ir = ith; ir < nr; ir += nth) {
+            // src0 and dst are same shape => same indices
+            // src1 is broadcastable across src0 and dst in i1, i2, i3
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+
+            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+
+            for (int64_t i0 = 0; i0 < ne00; ++i0) {
+                const int64_t i10 = i0 % ne10;
+                float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
+
+                dst_ptr[i0] = src0_ptr[i0] / (*src1_ptr);
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_div(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_div_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_sqr
+
+static void ggml_compute_forward_sqr_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n     = ggml_nrows(src0);
+    const int nc    = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_sqr_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_sqr(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_sqr_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_sqrt
+
+static void ggml_compute_forward_sqrt_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_sqrt_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_sqrt(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_sqrt_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_log
+
+static void ggml_compute_forward_log_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    GGML_ASSERT( dst->nb[0] == sizeof(float));
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_log_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_log(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_log_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_sin
+
+static void ggml_compute_forward_sin_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    GGML_ASSERT( dst->nb[0] == sizeof(float));
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_sin_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_sin(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_sin_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_cos
+
+static void ggml_compute_forward_cos_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    GGML_ASSERT( dst->nb[0] == sizeof(float));
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_cos_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_cos(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_cos_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_sum
+
+static void ggml_compute_forward_sum_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_scalar(dst));
+    assert(src0->nb[0] == sizeof(float));
+
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
+
+    ggml_float sum     = 0;
+    ggml_float row_sum = 0;
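+    // accumulate in ggml_float (typically double) to limit rounding error when
+    // summing across many rows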
+
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = 0; i01 < ne01; i01++) {
+                ggml_vec_sum_f32_ggf(ne00,
+                        &row_sum,
+                        (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
+                sum += row_sum;
+            }
+        }
+    }
+    ((float *) dst->data)[0] = sum;
+}
+
+static void ggml_compute_forward_sum_f16(
+    const struct ggml_compute_params * params,
+          struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_scalar(dst));
+
+    assert(src0->nb[0] == sizeof(ggml_fp16_t));
+
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
+
+    float sum = 0;
+    float row_sum = 0;
+
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = 0; i01 < ne01; i01++) {
+                ggml_vec_sum_f16_ggf(ne00,
+                    &row_sum,
+                    (ggml_fp16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));
+                sum += row_sum;
+            }
+        }
+    }
+    ((ggml_fp16_t *) dst->data)[0] = GGML_FP32_TO_FP16(sum);
+}
+
+static void ggml_compute_forward_sum_bf16(
+    const struct ggml_compute_params * params,
+          struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_scalar(dst));
+
+    assert(src0->nb[0] == sizeof(ggml_bf16_t));
+
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
+
+    float sum = 0;
+    float row_sum = 0;
+
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = 0; i01 < ne01; i01++) {
+                ggml_vec_sum_bf16_ggf(ne00,
+                    &row_sum,
+                    (ggml_bf16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));
+                sum += row_sum;
+            }
+        }
+    }
+    ((ggml_bf16_t *) dst->data)[0] = GGML_FP32_TO_BF16(sum);
+}
+
+static void ggml_compute_forward_sum(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_sum_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_sum_f16(params, dst);
+            } break;
+        case GGML_TYPE_BF16:
+            {
+                ggml_compute_forward_sum_bf16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_sum_rows
+
+static void ggml_compute_forward_sum_rows_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(dst->nb[0] == sizeof(float));
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    GGML_ASSERT(ne0 == 1);
+    GGML_ASSERT(ne1 == ne01);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne3 == ne03);
+
+    for (int64_t i3 = 0; i3 < ne03; i3++) {
+        for (int64_t i2 = 0; i2 < ne02; i2++) {
+            for (int64_t i1 = 0; i1 < ne01; i1++) {
+                float * src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03);
+                float * dst_row = (float *) ((char *) dst->data  + i1*nb1  + i2*nb2  + i3*nb3);
+                float row_sum = 0;
+                ggml_vec_sum_f32(ne00, &row_sum, src_row);
+                dst_row[0] = row_sum;
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_sum_rows(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_sum_rows_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_mean
+
+static void ggml_compute_forward_mean_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(src0->nb[0] == sizeof(float));
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    assert(ne0 == 1);
+    assert(ne1 == ne01);
+    assert(ne2 == ne02);
+    assert(ne3 == ne03);
+
+    UNUSED(ne0);
+    UNUSED(ne1);
+    UNUSED(ne2);
+    UNUSED(ne3);
+
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = 0; i01 < ne01; i01++) {
+                ggml_vec_sum_f32(ne00,
+                        (float *) ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
+                        (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
+
+                *(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3) /= (float) ne00;
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_mean(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_mean_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_argmax
+
+static void ggml_compute_forward_argmax_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(src0->nb[0] == sizeof(float));
+    assert(dst->nb[0] == sizeof(float));
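+    // dst holds int32 row indices; the float-size assert above is still valid
+    // because sizeof(int32_t) == sizeof(float)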
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+
+    const size_t nb01 = src0->nb[1];
+    const size_t nb0 = dst->nb[0];
+
+    for (int64_t i1 = 0; i1 < ne01; i1++) {
+        float * src = (float *) ((char *) src0->data + i1*nb01);
+        int32_t * dst_ = (int32_t *) ((char *)  dst->data + i1*nb0);
+        int v = 0;
+        ggml_vec_argmax_f32(ne00, &v, src);
+        dst_[0] = v;
+    }
+}
+
+static void ggml_compute_forward_argmax(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_argmax_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_count_equal
+
+static void ggml_compute_forward_count_equal_i32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    GGML_ASSERT(src0->type == GGML_TYPE_I32);
+    GGML_ASSERT(src1->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_are_same_shape(src0, src1));
+    GGML_ASSERT(ggml_is_scalar(dst));
+    GGML_ASSERT(dst->type == GGML_TYPE_I64);
+
+    const int64_t nr = ggml_nrows(src0);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    int64_t * sums = (int64_t *) params->wdata;
+    int64_t sum_thread = 0;
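+    // each thread accumulates into its own sum_thread; threads other than 0
+    // publish their partials in sums[ith], and thread 0 reduces them after the barrier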
+
+    // rows per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 =  ir                        / (ne02*ne01);
+        const int64_t i02 = (ir - i03*ne02*ne01)       /       ne01;
+        const int64_t i01 =  ir - i03*ne02*ne01 - i02*ne01;
+
+        const char * data0 = (const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01;
+        const char * data1 = (const char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11;
+
+        for (int64_t i00 = 0; i00 < ne00; ++i00) {
+            const int32_t val0 = *((const int32_t *) (data0 + i00*nb00));
+            const int32_t val1 = *((const int32_t *) (data1 + i00*nb10));
+
+            sum_thread += val0 == val1;
+        }
+    }
+    if (ith != 0) {
+        sums[ith] = sum_thread;
+    }
+    ggml_barrier(params->threadpool);
+
+    if (ith != 0) {
+        return;
+    }
+
+    for (int ith_other = 1; ith_other < nth; ++ith_other) {
+        sum_thread += sums[ith_other];
+    }
+    *((int64_t *) dst->data) = sum_thread;
+}
+
+static void ggml_compute_forward_count_equal(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_I32:
+            {
+                ggml_compute_forward_count_equal_i32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_repeat
+
+static void ggml_compute_forward_repeat_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    GGML_ASSERT(ggml_can_repeat(src0, dst));
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    // guaranteed to be an integer due to the check in ggml_can_repeat
+    const int nr0 = (int)(ne0/ne00);
+    const int nr1 = (int)(ne1/ne01);
+    const int nr2 = (int)(ne2/ne02);
+    const int nr3 = (int)(ne3/ne03);
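+    // e.g. repeating a [2,3,1,1] tensor into a [4,9,1,1] dst gives nr0 = 2 and
+    // nr1 = 3: each source row is tiled twice along dim 0 and three times along dim 1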
+
+    // TODO: support for transposed / permuted tensors
+    GGML_ASSERT(nb0  == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // TODO: maybe this is not optimal?
+    for                         (int i3 = 0; i3 < nr3;  i3++) {
+        for                     (int k3 = 0; k3 < ne03; k3++) {
+            for                 (int i2 = 0; i2 < nr2;  i2++) {
+                for             (int k2 = 0; k2 < ne02; k2++) {
+                    for         (int i1 = 0; i1 < nr1;  i1++) {
+                        for     (int k1 = 0; k1 < ne01; k1++) {
+                            for (int i0 = 0; i0 < nr0;  i0++) {
+                                ggml_vec_cpy_f32(ne00,
+                                        (float *) ((char *)  dst->data + (i3*ne03 + k3)*nb3  + (i2*ne02 + k2)*nb2  + (i1*ne01 + k1)*nb1  + (i0*ne00)*nb0),
+                                        (float *) ((char *) src0->data + (          k3)*nb03 + (          k2)*nb02 + (          k1)*nb01));
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_repeat_f16(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    GGML_ASSERT(ggml_can_repeat(src0, dst));
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    // guaranteed to be an integer due to the check in ggml_can_repeat
+    const int nr0 = (int)(ne0/ne00);
+    const int nr1 = (int)(ne1/ne01);
+    const int nr2 = (int)(ne2/ne02);
+    const int nr3 = (int)(ne3/ne03);
+
+    // TODO: support for transposed / permuted tensors
+    GGML_ASSERT(nb0  == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    // TODO: maybe this is not optimal?
+    for                         (int i3 = 0; i3 < nr3;  i3++) {
+        for                     (int k3 = 0; k3 < ne03; k3++) {
+            for                 (int i2 = 0; i2 < nr2;  i2++) {
+                for             (int k2 = 0; k2 < ne02; k2++) {
+                    for         (int i1 = 0; i1 < nr1;  i1++) {
+                        for     (int k1 = 0; k1 < ne01; k1++) {
+                            for (int i0 = 0; i0 < nr0;  i0++) {
+                                ggml_fp16_t * y = (ggml_fp16_t *) ((char *)  dst->data + (i3*ne03 + k3)*nb3  + (i2*ne02 + k2)*nb2  + (i1*ne01 + k1)*nb1  + (i0*ne00)*nb0);
+                                ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + (          k3)*nb03 + (          k2)*nb02 + (          k1)*nb01);
+                                // ggml_vec_cpy_f16(ne00, y, x)
+                                for (int i = 0; i < ne00; ++i) {
+                                    y[i]  = x[i];
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_repeat(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+        case GGML_TYPE_BF16:
+        case GGML_TYPE_I16:
+            {
+                ggml_compute_forward_repeat_f16(params, dst);
+            } break;
+        case GGML_TYPE_F32:
+        case GGML_TYPE_I32:
+            {
+                ggml_compute_forward_repeat_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_repeat_back
+
+static void ggml_compute_forward_repeat_back_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    GGML_ASSERT(ggml_can_repeat(dst, src0));
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    // guaranteed to be an integer due to the check in ggml_can_repeat
+    const int nr0 = (int)(ne00/ne0);
+    const int nr1 = (int)(ne01/ne1);
+    const int nr2 = (int)(ne02/ne2);
+    const int nr3 = (int)(ne03/ne3);
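+    // repeat_back is the adjoint of repeat: every repeated copy contributed to
+    // the forward result, so all nr0*nr1*nr2*nr3 copies are summed back here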
+
+    // TODO: support for transposed / permuted tensors
+    GGML_ASSERT(nb0  == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    if (ggml_is_contiguous(dst)) {
+        ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
+    } else {
+        for         (int k3 = 0; k3 < ne3; k3++) {
+            for     (int k2 = 0; k2 < ne2; k2++) {
+                for (int k1 = 0; k1 < ne1; k1++) {
+                    ggml_vec_set_f32(ne0,
+                        (float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3),
+                        0);
+                }
+            }
+        }
+    }
+
+    // TODO: maybe this is not optimal?
+    for                         (int i3 = 0; i3 < nr3; i3++) {
+        for                     (int k3 = 0; k3 < ne3; k3++) {
+            for                 (int i2 = 0; i2 < nr2; i2++) {
+                for             (int k2 = 0; k2 < ne2; k2++) {
+                    for         (int i1 = 0; i1 < nr1; i1++) {
+                        for     (int k1 = 0; k1 < ne1; k1++) {
+                            for (int i0 = 0; i0 < nr0; i0++) {
+                                ggml_vec_acc_f32(ne0,
+                                        (float *) ((char *)  dst->data + (         k3)*nb3  + (         k2)*nb2  + (         k1)*nb1),
+                                        (float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00));
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_repeat_back(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_repeat_back_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_concat
+
+static void ggml_compute_forward_concat_f32(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int32_t dim = ggml_get_op_params_i32(dst, 0);
+
+    GGML_ASSERT(dim >= 0 && dim < 4);
+
+    int64_t o[4] = {0, 0, 0, 0};
+    o[dim] = src0->ne[dim];
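+    // e.g. concatenating src0 ne = [4,4,2,1] and src1 ne = [4,4,3,1] along dim = 2
+    // gives dst ne = [4,4,5,1]: elements with i2 < 2 are read from src0, the rest
+    // from src1 at index i2 - o[2]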
+
+    const float * x;
+
+    // TODO: smarter multi-threading
+    for (int i3 = 0; i3 < ne3; i3++) {
+        for (int i2 = ith; i2 < ne2; i2 += nth) {
+            for (int i1 = 0; i1 < ne1; i1++) {
+                for (int i0 = 0; i0 < ne0; i0++) {
+                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+                        x = (const float *) ((const char *)src0->data + (i0       )*nb00 + (i1       )*nb01 + (i2       )*nb02 + (i3       )*nb03);
+                    } else {
+                        x = (const float *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
+                    }
+
+                    float * y = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
+
+                    *y = *x;
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_concat(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+        case GGML_TYPE_I32:
+            {
+                ggml_compute_forward_concat_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_abs
+
+static void ggml_compute_forward_abs_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_abs_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_abs(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_abs_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_sgn
+
+static void ggml_compute_forward_sgn_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_sgn_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_sgn(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_sgn_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_neg
+
+static void ggml_compute_forward_neg_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_neg_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_neg(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_neg_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_step
+
+static void ggml_compute_forward_step_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_step_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_step(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_step_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_tanh
+
+static void ggml_compute_forward_tanh_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_tanh_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_tanh(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_tanh_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_elu
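+//
+// elu(x) = x for x > 0, exp(x) - 1 otherwise (alpha = 1)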
+
+static void ggml_compute_forward_elu_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_elu_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_elu(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_elu_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_relu
+
+static void ggml_compute_forward_relu_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_relu_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_relu(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_relu_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_sigmoid
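+//
+// sigmoid(x) = 1 / (1 + exp(-x)), applied element-wise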
+
+static void ggml_compute_forward_sigmoid_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_sigmoid_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_sigmoid(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_sigmoid_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_gelu
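+//
+// the tanh approximation of GELU (the form ggml's f32 kernel is expected to use):
+//   gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))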
+
+static void ggml_compute_forward_gelu_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_vec_gelu_f32(nc,
+                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
+                (float *) ((char *) src0->data + i1*(src0->nb[1])));
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_gelu(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_gelu_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_gelu_quick
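+//
+// gelu_quick(x) = x * sigmoid(1.702 * x), a cheaper sigmoid-based approximation of GELU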
+
+static void ggml_compute_forward_gelu_quick_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_vec_gelu_quick_f32(nc,
+                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
+                (float *) ((char *) src0->data + i1*(src0->nb[1])));
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_gelu_quick(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_gelu_quick_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_silu
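+//
+// silu(x) = x * sigmoid(x) = x / (1 + exp(-x)), also known as swish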
+
+static void ggml_compute_forward_silu_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_vec_silu_f32(nc,
+                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
+                (float *) ((char *) src0->data + i1*(src0->nb[1])));
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
+            UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_silu(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_silu_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_leaky_relu
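+//
+// leaky_relu(x) = x for x > 0, negative_slope * x otherwise; the slope is
+// read from the op params with memcpy below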
+
+static void ggml_compute_forward_leaky_relu_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    float negative_slope;
+    memcpy(&negative_slope, dst->op_params, sizeof(float));
+
+    assert(dst->nb[0]  == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_leaky_relu_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
+    }
+}
+
+static void ggml_compute_forward_leaky_relu(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_leaky_relu_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_silu_back
+
+static void ggml_compute_forward_silu_back_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * grad = dst->src[1];
+
+    assert(ggml_is_contiguous_1(grad));
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+    assert(ggml_are_same_shape(src0, grad));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_vec_silu_backward_f32(nc,
+                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
+                (float *) ((char *) src0->data + i1*(src0->nb[1])),
+                (float *) ((char *) grad->data + i1*(grad->nb[1])));
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_silu_back(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_silu_back_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_hardswish
+
+static void ggml_compute_forward_hardswish_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_hardswish_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_hardswish(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_hardswish_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_hardsigmoid
+
+static void ggml_compute_forward_hardsigmoid_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_hardsigmoid_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_hardsigmoid(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_hardsigmoid_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_exp
+
+static void ggml_compute_forward_exp_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_exp_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_exp(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_exp_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_norm
+
+static void ggml_compute_forward_norm_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
+    GGML_ASSERT(eps > 0.0f);
+
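+    // standard layer norm over the first dimension:
+    //   y = (x - mean(x)) / sqrt(var(x) + eps)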
+    // TODO: optimize
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
+                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+                ggml_float sum = 0.0;
+                for (int64_t i00 = 0; i00 < ne00; i00++) {
+                    sum += (ggml_float)x[i00];
+                }
+
+                float mean = sum/ne00;
+
+                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+                ggml_float sum2 = 0.0;
+                for (int64_t i00 = 0; i00 < ne00; i00++) {
+                    float v = x[i00] - mean;
+                    y[i00] = v;
+                    sum2 += (ggml_float)(v*v);
+                }
+
+                float variance = sum2/ne00;
+                const float scale = 1.0f/sqrtf(variance + eps);
+
+                ggml_vec_scale_f32(ne00, y, scale);
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_norm(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_norm_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_rms_norm
+
+static void ggml_compute_forward_rms_norm_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
+    GGML_ASSERT(eps > 0.0f);
+
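+    // rms norm scales each row by its reciprocal root mean square:
+    //   y = x / sqrt(mean(x^2) + eps)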
+    // TODO: optimize
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
+                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+                ggml_float sum = 0.0;
+                for (int64_t i00 = 0; i00 < ne00; i00++) {
+                    sum += (ggml_float)(x[i00] * x[i00]);
+                }
+
+                const float mean = sum/ne00;
+
+                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+                memcpy(y, x, ne00 * sizeof(float));
+                // for (int i00 = 0; i00 < ne00; i00++) {
+                //     y[i00] = x[i00];
+                // }
+
+                const float scale = 1.0f/sqrtf(mean + eps);
+
+                ggml_vec_scale_f32(ne00, y, scale);
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_rms_norm(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rms_norm_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+static void ggml_compute_forward_rms_norm_back_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1));
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
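+    // given x = src0 and dz = d(loss)/d(rms_norm(x)) in src1, this computes
+    // dx = d(loss)/dx; the derivation is spelled out in the comment below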
+    // TODO: optimize
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
+                // src1 is same shape as src0 => same indices
+                const int64_t i11 = i01;
+                const int64_t i12 = i02;
+                const int64_t i13 = i03;
+
+                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+                const float * dz = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
+
+                ggml_float sum_xx  = 0.0;
+                ggml_float sum_xdz = 0.0;
+
+                for (int64_t i00 = 0; i00 < ne00; i00++) {
+                    sum_xx  += (ggml_float)(x[i00] * x[i00]);
+                    sum_xdz += (ggml_float)(x[i00] * dz[i00]);
+                }
+
+                //const float mean     = (float)(sum_xx)/ne00;
+                const float mean_eps = (float)(sum_xx)/ne00 + eps;
+                const float sum_eps  = (float)(sum_xx) + eps*ne00;
+                //const float mean_xdz = (float)(sum_xdz)/ne00;
+                // we could cache rms from the forward pass to improve performance;
+                // to do this, implement ggml_rms and compose ggml_rms_norm using ggml_rms.
+                //const float rms      = sqrtf(mean_eps);
+                const float rrms     = 1.0f / sqrtf(mean_eps);
+                //const float scale    = -rrms/(ne00 * mean_eps); // -1/(n*rms**3)
+
+                {
+                    // z = rms_norm(x)
+                    //
+                    // rms_norm(src0) =
+                    //     scale(
+                    //         src0,
+                    //         div(
+                    //             1,
+                    //             sqrt(
+                    //                 add(
+                    //                     scale(
+                    //                         sum(
+                    //                             sqr(
+                    //                                 src0)),
+                    //                         (1.0/N)),
+                    //                     eps))));
+
+                    // postorder:
+                    // ## op    args         grad
+                    // 00 param src0         grad[#00]
+                    // 01 const 1
+                    // 02 sqr   (#00)        grad[#02]
+                    // 03 sum   (#02)        grad[#03]
+                    // 04 const 1/N
+                    // 05 scale (#03, #04)   grad[#05]
+                    // 06 const eps
+                    // 07 add   (#05, #06)   grad[#07]
+                    // 08 sqrt  (#07)        grad[#08]
+                    // 09 div   (#01,#08)    grad[#09]
+                    // 10 scale (#00,#09)    grad[#10]
+                    //
+                    // backward pass, given grad[#10]
+                    // #10: scale
+                    // grad[#00] += scale(grad[#10],#09)
+                    // grad[#09] += sum(mul(grad[#10],#00))
+                    // #09: div
+                    // grad[#08] += neg(mul(grad[#09], div(#09,#08)))
+                    // #08: sqrt
+                    // grad[#07] += mul(grad[#08], div(0.5, #08))
+                    // #07: add
+                    // grad[#05] += grad[#07]
+                    // #05: scale
+                    // grad[#03] += scale(grad[#05],#04)
+                    // #03: sum
+                    // grad[#02] += repeat(grad[#03], #02)
+                    // #02:
+                    // grad[#00] += scale(mul(#00, grad[#02]), 2.0)
+                    //
+                    // substitute and simplify:
+                    // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0)
+                    // grad[#02] = repeat(grad[#03], #02)
+                    // grad[#02] = repeat(scale(grad[#05],#04), #02)
+                    // grad[#02] = repeat(scale(grad[#07],#04), #02)
+                    // grad[#02] = repeat(scale(mul(grad[#08], div(0.5, #08)),#04), #02)
+                    // grad[#02] = repeat(scale(mul(neg(mul(grad[#09], div(#09,#08))), div(0.5, #08)),#04), #02)
+                    // grad[#02] = repeat(scale(mul(neg(mul(sum(mul(grad[#10],#00)), div(#09,#08))), div(0.5, #08)),#04), #02)
+                    // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(#09,#08) * div(0.5, #08) * (1/N)), #02)
+                    // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(div(#01,#08),#08) * div(0.5, #08) * (1/N)), #02)
+                    // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#08*#08) * div(0.5, #08) * (1/N)), #02)
+                    // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)
+                    // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0)
+                    // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)), 2.0)
+                    // grad[#00] = scale(grad(#10), #09) + scale(scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N))), 2.0)
+                    // grad[#00] = scale(grad(#10), #09) + scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(1,#08) * (1/N)))
+                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N))
+                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N))
+                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,mean_eps*rms) * (-1/N))
+                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*mean_eps))
+                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*(sum_xx/N+eps)))
+                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*sum_xx+rms*N*eps))
+                    // grad[#00] = scale(dz, rrms) + scale(x, sum(mul(dz,x)) * div(-1,rms*N*mean_eps))
+                    // grad[#00] = scale(dz, rrms) + scale(x, sum_xdz * div(-1,rms*N*mean_eps))
+                    // a = b*c + d*e
+                    // a = b*c*f/f + d*e*f/f
+                    // a = (b*c*f + d*e*f)*(1/f)
+                    // a = (b*c*(1/c) + d*e*(1/c))*(1/(1/c))
+                    // a = (b + d*e/c)*c
+                    // b = dz, c = rrms, d = x, e = sum_xdz * div(-1,rms*N*mean_eps)
+                    // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)/rrms)*rrms
+                    // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)*rms)*rrms
+                    // a = (dz + x*sum_xdz * div(-rms,rms*N*mean_eps))*rrms
+                    // a = (dz + x*sum_xdz * div(-1,N*mean_eps))*rrms
+                    // a = (dz + x*div(-sum_xdz,N*mean_eps))*rrms
+                    // a = (dz + x*div(-mean_xdz,mean_eps))*rrms
+                    // grad[#00] = scale(dz + scale(x, div(-mean_xdz,mean_eps)),rrms)
+                    // grad[#00] = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
+                    // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
+                }
+                // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
+                // post-order:
+                // dx := x
+                // dx := scale(dx,-mean_xdz/mean_eps)
+                // dx := add(dx, dz)
+                // dx := scale(dx, rrms)
+                float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+                ggml_vec_cpy_f32  (ne00, dx, x);
+                // ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps);
+                ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps);
+                ggml_vec_acc_f32  (ne00, dx, dz);
+                ggml_vec_scale_f32(ne00, dx, rrms);
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_rms_norm_back(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rms_norm_back_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_group_norm
+
+static void ggml_compute_forward_group_norm_f32(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    // TODO: optimize
+
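+    // op_params[0] holds n_groups; op_params[1] holds the bit pattern of eps,
+    // hence the memcpy from dst->op_params + 1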
+    float eps;
+    memcpy(&eps, dst->op_params + 1, sizeof(float));
+
+    int n_channels = src0->ne[2];
+    int n_groups = dst->op_params[0];
+    int n_channels_per_group = (n_channels + n_groups - 1) / n_groups;
+    for (int i = ith; i < n_groups; i += nth) {
+        int start = i * n_channels_per_group;
+        int end = start + n_channels_per_group;
+        if (end > n_channels) {
+            end = n_channels;
+        }
+        int step = end - start;
+
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            ggml_float sum = 0.0;
+            for (int64_t i02 = start; i02 < end; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
+
+                    ggml_float sumr = 0.0;
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        sumr += (ggml_float)x[i00];
+                    }
+                    sum += sumr;
+                }
+            }
+            const float mean = sum / (ne00 * ne01 * step);
+
+            ggml_float sum2 = 0.0;
+            for (int64_t i02 = start; i02 < end; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
+
+                    float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
+
+                    ggml_float sumr = 0.0;
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        float v = x[i00] - mean;
+                        y[i00] = v;
+                        sumr += (ggml_float)(v * v);
+                    }
+                    sum2 += sumr;
+                }
+            }
+            const float variance = sum2 / (ne00 * ne01 * step);
+            const float scale = 1.0f / sqrtf(variance + eps);
+
+            for (int64_t i02 = start; i02 < end; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
+                    ggml_vec_scale_f32(ne00, y, scale);
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_group_norm(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_group_norm_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_mul_mat
+
+static void ggml_compute_forward_mul_mat_one_chunk(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst,
+    const enum ggml_type type,
+    const int64_t num_rows_per_vec_dot,
+    const int64_t ir0_start,
+    const int64_t ir0_end,
+    const int64_t ir1_start,
+    const int64_t ir1_end) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const bool src1_cont = ggml_is_contiguous(src1);
+
+    ggml_vec_dot_t const vec_dot      = type_traits_cpu[type].vec_dot;
+    enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
+
+    // broadcast factors
+    const int64_t r2 = ne12 / ne02;
+    const int64_t r3 = ne13 / ne03;
+
+    //printf("ir0_start = %6lld, ir0_end = %6lld, ir1_start = %6lld, ir1_end = %6lld\n", ir0_start, ir0_end, ir1_start, ir1_end);
+
+    // threads with no work simply return early
+    if (ir0_start >= ir0_end || ir1_start >= ir1_end) {
+        return;
+    }
+
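+    // src1 is read directly when it already has the kernel's vec_dot_type;
+    // otherwise it was converted into params->wdata before the chunks ran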
+    const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
+    const size_t row_size = ggml_row_size(vec_dot_type, ne10);
+
+    assert(ne12 % ne02 == 0);
+    assert(ne13 % ne03 == 0);
+
+    // block-tiling attempt
+    const int64_t blck_0 = 16;
+    const int64_t blck_1 = 16;
+
+    const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11;
+
+    // attempt to reduce false-sharing (does not seem to make a difference)
+    // 16 * 2, accounting for mmla kernels
+    float tmp[32];
+
+    for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
+        for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
+            for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1 += num_rows_per_vec_dot) {
+                const int64_t i13 = (ir1 / (ne12 * ne1));
+                const int64_t i12 = (ir1 - i13 * ne12 * ne1) / ne1;
+                const int64_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1);
+
+                // broadcast src0 into src1
+                const int64_t i03 = i13 / r3;
+                const int64_t i02 = i12 / r2;
+
+                const int64_t i1 = i11;
+                const int64_t i2 = i12;
+                const int64_t i3 = i13;
+
+                const char * src0_row = (const char*)src0->data + (0 + i02 * nb02 + i03 * nb03);
+
+                // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
+                //       if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
+                //       the original src1 data pointer, so we should index using the indices directly
+                // TODO: this is a bit of a hack, we should probably have a better way to handle this
+                const char * src1_col = (const char*)wdata +
+                    (src1_cont || src1->type != vec_dot_type
+                        ? (i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size
+                        : (i11 * nb11 + i12 * nb12 + i13 * nb13));
+                float * dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
+
+                //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) {
+                //    vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
+                //}
+
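+                // when num_rows_per_vec_dot > 1 (mmla kernels) each call produces
+                // several rows at once, so the row strides are passed through;
+                // otherwise the strides are 0 and one row is computed per call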
+                for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
+                    vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
+                }
+
+                for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) {
+                    memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * 16), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float));
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_mul_mat(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    enum ggml_type type = src0->type;
+
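+    // aarch64 buffer types repack the weights at load time (see ggml-cpu-aarch64.c);
+    // the effective (interleaved) type is stashed in src0->extra and overrides src0->type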
+    if (src0->buffer && ggml_backend_cpu_buft_is_aarch64(src0->buffer->buft)) {
+        type = (enum ggml_type)(intptr_t)src0->extra;
+    }
+
+#if defined(__AMX_INT8__) && defined(__AVX512VNNI__)
+    if (src0->buffer && ggml_backend_amx_buft_is_amx(src0->buffer->buft)) {
+        ggml_backend_amx_mul_mat(params, dst);
+        return;
+    }
+#endif
+
+    enum ggml_type           const vec_dot_type         = type_traits_cpu[type].vec_dot_type;
+    ggml_from_float_t        const from_float           = type_traits_cpu[vec_dot_type].from_float;
+    ggml_from_float_to_mat_t const from_float_to_mat    = type_traits_cpu[vec_dot_type].from_float_to_mat;
+    int64_t                  const vec_dot_num_rows     = type_traits_cpu[type].nrows;
+    int64_t                  const matmul_num_cols      = type_traits_cpu[type].ncols;
+    int64_t                  const blck_size_interleave = ggml_get_type_traits(type)->blck_size_interleave;
+    ggml_gemv_t              const gemv                 = type_traits_cpu[type].gemv;
+    ggml_gemm_t              const gemm                 = type_traits_cpu[type].gemm;
+
+    GGML_ASSERT(ne0 == ne01);
+    GGML_ASSERT(ne1 == ne11);
+    GGML_ASSERT(ne2 == ne12);
+    GGML_ASSERT(ne3 == ne13);
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == ggml_type_size(type));
+    GGML_ASSERT(nb10 == ggml_type_size(src1->type));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    // nb01 >= nb00 - src0 is not transposed
+    //   compute by src0 rows
+
+#if GGML_USE_LLAMAFILE
+    // broadcast factors
+    const int64_t r2 = ne12 / ne02;
+    const int64_t r3 = ne13 / ne03;
+
+    const bool src1_cont = ggml_is_contiguous(src1);
+
+    if (src1_cont) {
+        for (int64_t i13 = 0; i13 < ne13; i13++)
+            for (int64_t i12 = 0; i12 < ne12; i12++)
+                if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(type),
+                                     (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
+                                     nb01/ggml_type_size(type),
+                                     (const char *)src1->data + i12*nb12 + i13*nb13,
+                                     nb11/ggml_type_size(src1->type),
+                                     (char *)dst->data + i12*nb2 + i13*nb3,
+                                     nb1/ggml_type_size(dst->type),
+                                     ith, nth,
+                                     type,
+                                     src1->type,
+                                     dst->type))
+                    goto UseGgmlGemm1;
+        return;
+    }
+UseGgmlGemm1:;
+#endif
+
+    if (src1->type != vec_dot_type) {
+        char * wdata = params->wdata;
+
+        const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
+        const size_t nbw2 = nbw1*ne11;
+        const size_t nbw3 = nbw2*ne12;
+
+        assert(params->wsize >= ne13*nbw3);
+        GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+        for (int64_t i13 = 0; i13 < ne13; ++i13) {
+            for (int64_t i12 = 0; i12 < ne12; ++i12) {
+                int64_t i11_processed = 0;
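+                // if an interleaved gemm kernel is available and src1 is 2D,
+                // convert rows in groups of 4 into its interleaved layout first;
+                // any remainder rows use the plain from_float conversion below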
+                if ((ggml_n_dims(src1) == 2) && from_float_to_mat && gemm) {
+                    for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
+                        from_float_to_mat((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
+                                          (void *)               (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
+                                          4, ne10, blck_size_interleave);
+                    }
+                    i11_processed = ne11 - ne11 % 4;
+                }
+                for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
+                    from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
+                           (void *)               (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
+                           ne10);
+                }
+            }
+        }
+    }
+
+    if (ith == 0) {
+        // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
+        atomic_store_explicit(&params->threadpool->current_chunk, nth, memory_order_relaxed);
+    }
+
+    ggml_barrier(params->threadpool);
+
+#if GGML_USE_LLAMAFILE
+    if (src1->type != vec_dot_type) {
+        const void* wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
+        const size_t row_size = ggml_row_size(vec_dot_type, ne10);
+
+        for (int64_t i13 = 0; i13 < ne13; i13++)
+            for (int64_t i12 = 0; i12 < ne12; i12++)
+                if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(type),
+                                     (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
+                                     nb01/ggml_type_size(type),
+                                     (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size,
+                                     row_size/ggml_type_size(vec_dot_type),
+                                     (char *)dst->data + i12*nb2 + i13*nb3,
+                                     nb1/ggml_type_size(dst->type),
+                                     ith, nth,
+                                     type,
+                                     vec_dot_type,
+                                     dst->type))
+                    goto UseGgmlGemm2;
+        return;
+    }
+UseGgmlGemm2:;
+#endif
+
+    // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers)
+    const int64_t nr0 = ne0;
+
+    // This is the size of the rest of the dimensions of the result
+    const int64_t nr1 = ne1 * ne2 * ne3;
+
+    // Now select a reasonable chunk size.
+    int chunk_size = 16;
+
+    // We need to step up the size if it's small
+    if (nr0 == 1 || nr1 == 1) {
+        chunk_size = 64;
+    }
+
+    // distribute the work across the inner or outer loop based on which one is larger
+    // The number of chunks in the 0/1 dim.
+    // CEIL(nr0/chunk_size)
+    int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
+    int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;
+
+    // If the chunking is poor for the number of threads on this setup, scrap the whole plan. Re-chunk it by thread.
+    //   Also, chunking by thread was measured to perform better on NUMA systems. See https://github.com/ggerganov/llama.cpp/pull/6915
+    //   In theory, chunking should be just as useful on NUMA and non-NUMA systems, but testing disagreed with that.
+    if (nchunk0 * nchunk1 < nth * 4 || ggml_is_numa()) {
+        // distribute the thread work across the inner or outer loop based on which one is larger
+        nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
+        nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
+    }
+
+    // The number of elements in each chunk
+    const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
+    const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
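+    // example: with nr0 = 4096, nr1 = 64, nth = 8 and chunk_size = 16 this
+    // yields nchunk0 = 256 and nchunk1 = 4: 1024 chunks of 16x16 rows that
+    // the threads claim via the atomic chunk counter below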
+
+    if ((ggml_n_dims(src0) == 2) && gemv) {
+        const void * src1_wdata      = (src1->type == vec_dot_type) ? src1->data : params->wdata;
+        const size_t src1_col_stride = ggml_is_contiguous(src1) || src1->type != vec_dot_type ? ggml_row_size(vec_dot_type, ne10) : nb11;
+        int64_t src0_start = (ith * ne01) / nth;
+        int64_t src0_end   = ((ith + 1) * ne01) / nth;
+        src0_start = (src0_start % matmul_num_cols) ? src0_start + matmul_num_cols - (src0_start % matmul_num_cols): src0_start;
+        src0_end   = (src0_end   % matmul_num_cols) ? src0_end   + matmul_num_cols - (src0_end   % matmul_num_cols): src0_end;
+        if (src0_start >= src0_end) return;
+
+        // If there are more than three rows in src1, use gemm; otherwise, use gemv.
+        if (gemm && (ne11 > 3)) {
+            gemm(ne00, (float *)((char *) dst->data) + src0_start, ne01, (const char *) src0->data + src0_start * nb01,
+                 (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
+        }
+        for (int iter = gemm ? ne11 - ne11 % 4 : 0; iter < ne11; iter++) {
+            gemv(ne00, (float *)((char *) dst->data + (iter * nb1)) + src0_start, ne01,
+                 (const char *) src0->data + src0_start * nb01, (const char *) src1_wdata + (src1_col_stride * iter), 1,
+                 src0_end - src0_start);
+        }
+        return;
+    }
+
+    // The first chunk comes from our thread_id, the rest will get auto-assigned.
+    int current_chunk = ith;
+
+    while (current_chunk < nchunk0 * nchunk1) {
+        const int64_t ith0 = current_chunk % nchunk0;
+        const int64_t ith1 = current_chunk / nchunk0;
+
+        const int64_t ir0_start = dr0 * ith0;
+        const int64_t ir0_end = MIN(ir0_start + dr0, nr0);
+
+        const int64_t ir1_start = dr1 * ith1;
+        const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
+
+        // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
+        int64_t num_rows_per_vec_dot = vec_dot_num_rows;
+
+        // these checks are needed to avoid crossing dim1 boundaries
+        // can be optimized, but the logic would become more complicated, so keeping it like this for simplicity
+        if ((nr0 % 2 != 0) || (ne11 % 2 != 0) || ((ir0_end - ir0_start) % 2 != 0) || ((ir1_end - ir1_start) % 2 != 0)) {
+            num_rows_per_vec_dot = 1;
+        }
+
+        ggml_compute_forward_mul_mat_one_chunk(params, dst, type, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
+
+        if (nth >= nchunk0 * nchunk1) {
+            break;
+        }
+
+        current_chunk = atomic_fetch_add_explicit(&params->threadpool->current_chunk, 1, memory_order_relaxed);
+    }
+}
+
+// ggml_compute_forward_mul_mat_id
+
+static void ggml_compute_forward_mul_mat_id(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+    const struct ggml_tensor * ids = dst->src[2];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const enum ggml_type type = src0->type;
+
+    const bool src1_cont = ggml_is_contiguous(src1);
+
+    ggml_vec_dot_t    const vec_dot         = type_traits_cpu[type].vec_dot;
+    enum ggml_type    const vec_dot_type    = type_traits_cpu[type].vec_dot_type;
+    ggml_from_float_t const from_float      = type_traits_cpu[vec_dot_type].from_float;
+    int64_t           const matmul_num_cols = type_traits_cpu[type].ncols;
+    ggml_gemv_t       const gemv            = type_traits_cpu[type].gemv;
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == ggml_type_size(type));
+    GGML_ASSERT(nb10 == ggml_type_size(src1->type));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    // row groups
+    const int n_ids = ids->ne[0]; // n_expert_used
+    const int n_as  = ne02;       // n_expert
+
+    char * wdata_src1_end = (src1->type == vec_dot_type) ?
+            (char *) params->wdata :
+            (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));
+
+    struct mmid_row_mapping {
+        int32_t i1;
+        int32_t i2;
+    };
+
+    int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
+    struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11]
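+    // matrix_rows records, for each expert, the (expert slot, token) pairs that
+    // routed to it, so each expert's rows can be processed together below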
+
+    if (src1->type != vec_dot_type) {
+        char * wdata = params->wdata;
+
+        const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
+        const size_t nbw2 = nbw1*ne11;
+        const size_t nbw3 = nbw2*ne12;
+
+        assert(params->wsize >= ne13*nbw3);
+        GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+        for (int64_t i13 = 0; i13 < ne13; ++i13) {
+            for (int64_t i12 = 0; i12 < ne12; ++i12) {
+                for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
+                    from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
+                               (void *)               (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
+                               ne10);
+                }
+            }
+        }
+    }
+
+#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)]
+
+    if (ith == 0) {
+        // initialize matrix_row_counts
+        memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
+
+        // group rows by src0 matrix
+        for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
+            for (int id = 0; id < n_ids; ++id) {
+                const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]);
+
+                assert(i02 >= 0 && i02 < n_as);
+
+                MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) {id, iid1};
+                matrix_row_counts[i02] += 1;
+            }
+        }
+    }
+
+    ggml_barrier(params->threadpool);
+
+    // compute each matrix multiplication in sequence
+    for (int cur_a = 0; cur_a < n_as; ++cur_a) {
+        const int64_t cne1 = matrix_row_counts[cur_a];
+
+        if (cne1 == 0) {
+            continue;
+        }
+
+        const char * src0_cur = (const char *) src0->data + cur_a*nb02;
+
+        const void * wdata    = (src1->type == vec_dot_type) ? src1->data : params->wdata;
+        const size_t row_size = ggml_row_size(vec_dot_type, ne10);
+
+        const int64_t nr0 = ne01; // src0 rows
+        const int64_t nr1 = cne1; // src1 rows
+
+        if (((ggml_n_dims(src0) - 1) == 2) && gemv) {
+            int64_t src0_cur_start = (ith * ne01) / nth;
+            int64_t src0_cur_end   = ((ith + 1) * ne01) / nth;
+            src0_cur_start = (src0_cur_start % matmul_num_cols) ? src0_cur_start + matmul_num_cols - (src0_cur_start % matmul_num_cols): src0_cur_start;
+            src0_cur_end   = (src0_cur_end % matmul_num_cols) ? src0_cur_end + matmul_num_cols - (src0_cur_end % matmul_num_cols): src0_cur_end;
+            if (src0_cur_start >= src0_cur_end) return;
+
+            for (int ir1 = 0; ir1 < nr1; ir1++) {
+                struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
+                const int id       = row_mapping.i1; // selected expert index
+
+                const int64_t  i11 = id % ne11;
+                const int64_t  i12 = row_mapping.i2; // row index in src1
+
+                const int64_t  i1 = id;  // selected expert index
+                const int64_t  i2 = i12; // row
+
+                const char * src1_col = (const char *) wdata +
+                    (src1_cont || src1->type != vec_dot_type
+                    ? (i11        + i12 * ne11) * row_size
+                    : (i11 * nb11 + i12 * nb12));
+
+                gemv(ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01,
+                     (const char *) src0_cur + src0_cur_start * nb01, src1_col, 1, src0_cur_end - src0_cur_start);
+            }
+            continue;
+        }
+
+        // distribute the thread work across the inner or outer loop based on which one is larger
+
+        const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
+        const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
+
+        const int64_t ith0 = ith % nth0;
+        const int64_t ith1 = ith / nth0;
+
+        const int64_t dr0 = (nr0 + nth0 - 1)/nth0;
+        const int64_t dr1 = (nr1 + nth1 - 1)/nth1;
+
+        const int64_t ir010 = dr0*ith0;
+        const int64_t ir011 = MIN(ir010 + dr0, nr0);
+
+        const int64_t ir110 = dr1*ith1;
+        const int64_t ir111 = MIN(ir110 + dr1, nr1);
+
+        // threads with no work simply yield (not sure if it helps)
+        //if (ir010 >= ir011 || ir110 >= ir111) {
+        //    sched_yield();
+        //    continue;
+        //}
+
+        // block-tiling attempt
+        const int64_t blck_0 = 16;
+        const int64_t blck_1 = 16;
+
+        // attempt to reduce false-sharing (does not seem to make a difference)
+        float tmp[16];
+
+        for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
+            for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
+                for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
+                    const int64_t _i12 = ir1; // logical row index for this expert
+
+                    struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12);
+                    const int id       = row_mapping.i1; // selected expert index
+
+                    const int64_t  i11 = id % ne11;
+                    const int64_t  i12 = row_mapping.i2; // row index in src1
+
+                    const int64_t  i1 = id;  // selected expert index
+                    const int64_t  i2 = i12; // row
+
+                    // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
+                    //       if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
+                    //       the original src1 data pointer, so we should index using the indices directly
+                    // TODO: this is a bit of a hack, we should probably have a better way to handle this
+                    const char * src1_col = (const char *) wdata +
+                        (src1_cont || src1->type != vec_dot_type
+                        ? (i11      + i12*ne11)*row_size
+                        : (i11*nb11 + i12*nb12));
+
+                    float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2));
+
+                    //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
+                    //    vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
+                    //}
+
+                    for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
+                        vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1);
+                    }
+
+                    memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
+                }
+            }
+        }
+    }
+
+#undef MMID_MATRIX_ROW
+}
+
+// ggml_compute_forward_out_prod
+
+static void ggml_compute_forward_out_prod_f32(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_ASSERT(ne0  == ne00);
+    GGML_ASSERT(ne1  == ne10);
+    GGML_ASSERT(ne2  == ne02);
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne3  == ne13);
+    GGML_ASSERT(ne03 == ne13);
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    // GGML_ASSERT(nb0 <= nb1);
+    // GGML_ASSERT(nb1 <= nb2);
+    // GGML_ASSERT(nb2 <= nb3);
+
+    // nb01 >= nb00 - src0 is not transposed
+    //   compute by src0 rows
+
+    if (ith == 0) {
+        ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
+    }
+    ggml_barrier(params->threadpool);
+
+    // dst[:,:,:,:] = 0
+    // for i2,i3:
+    //   for i1:
+    //     for i01:
+    //       for i0:
+    //         dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
+
+    // parallelize by last three dimensions
+
+    // total rows in dst
+    const int64_t nr = ne1*ne2*ne3;
+
+    // rows per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
+
+    // block-tiling attempt
+    const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
+    const int64_t blck_1 = 16;
+
+    for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
+        const int64_t bir1 = MIN(bir + blck_1, ir1);
+        for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
+            const int64_t bne01 = MIN(bi01 + blck_0, ne01);
+            for (int64_t ir = bir; ir < bir1; ++ir) {
+                // dst indices
+                const int64_t i3 = ir/(ne2*ne1);
+                const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
+                const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+                const int64_t i02 = i2;
+                const int64_t i03 = i3;
+
+                //const int64_t i10 = i1;
+                const int64_t i12 = i2;
+                const int64_t i13 = i3;
+
+#if GGML_VEC_MAD_UNROLL > 2
+                const int64_t bne01_unroll = bne01 - (bne01 % GGML_VEC_MAD_UNROLL);
+                for (int64_t i01 = bi01; i01 < bne01_unroll; i01 += GGML_VEC_MAD_UNROLL) {
+                    const int64_t i11 = i01;
+
+                    float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
+                    float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
+
+                    ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
+                }
+                for (int64_t i01 = bne01_unroll; i01 < bne01; ++i01) {
+                    const int64_t i11 = i01;
+
+                    float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
+                    float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
+
+                    ggml_vec_mad_f32(ne0, d, s0, *s1);
+                }
+#else
+                for (int64_t i01 = bi01; i01 < bne01; ++i01) {
+                    const int64_t i11 = i01;
+
+                    float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
+                    float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
+
+                    ggml_vec_mad_f32(ne0, d, s0, *s1);
+                }
+#endif
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_out_prod_q_f32(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const enum ggml_type type = src0->type;
+    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
+
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne03 == ne13);
+    GGML_ASSERT(ne2  == ne12);
+    GGML_ASSERT(ne3  == ne13);
+
+    // we don't support permuted src0 dim0
+    GGML_ASSERT(nb00 == ggml_type_size(type));
+
+    // dst dim0 cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    // GGML_ASSERT(nb0 <= nb1);
+    // GGML_ASSERT(nb1 <= nb2);
+    // GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(ne0 == ne00);
+    GGML_ASSERT(ne1 == ne10);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne3 == ne03);
+
+    // nb01 >= nb00 - src0 is not transposed
+    //   compute by src0 rows
+
+    if (ith == 0) {
+        ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
+    }
+    ggml_barrier(params->threadpool);
+
+    // parallelize by last three dimensions
+
+    // total rows in dst
+    const int64_t nr = ne1*ne2*ne3;
+
+    // rows per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
+
+    // dst[:,:,:,:] = 0
+    // for i2,i3:
+    //   for i1:
+    //     for i01:
+    //       for i0:
+    //         dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
+
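+    // each thread dequantizes into its own ne0-wide scratch row; the extra
+    // CACHE_LINE_SIZE_F32 padding keeps threads on separate cache lines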
+    float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        // dst indices
+        const int64_t i3 = ir/(ne2*ne1);
+        const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
+        const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        const int64_t i02 = i2;
+        const int64_t i03 = i3;
+
+        //const int64_t i10 = i1;
+        const int64_t i12 = i2;
+        const int64_t i13 = i3;
+
+        for (int64_t i01 = 0; i01 < ne01; ++i01) {
+            const int64_t i11 = i01;
+
+            float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
+            float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+            float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
+
+            dequantize_row_q(s0, wdata, ne0);
+            ggml_vec_mad_f32(ne0, d, wdata, *s1);
+        }
+    }
+}
+
+static void ggml_compute_forward_out_prod(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+        case GGML_TYPE_TQ1_0:
+        case GGML_TYPE_TQ2_0:
+        case GGML_TYPE_IQ2_XXS:
+        case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ1_M:
+        case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_IQ3_S:
+        case GGML_TYPE_IQ2_S:
+        case GGML_TYPE_Q4_0_4_4:
+        case GGML_TYPE_Q4_0_4_8:
+        case GGML_TYPE_Q4_0_8_8:
+            {
+                ggml_compute_forward_out_prod_q_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                GGML_ABORT("fatal error"); // todo
+                // ggml_compute_forward_out_prod_f16_f32(params, dst);
+            }
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_out_prod_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_scale
+
+static void ggml_compute_forward_scale_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    // scale factor
+    float v;
+    memcpy(&v, dst->op_params, sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    const size_t nb01 = src0->nb[1];
+
+    const size_t nb1 = dst->nb[1];
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        if (dst->data != src0->data) {
+            // src0 is same shape as dst => same indices
+            memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
+        }
+        ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v);
+    }
+}
+
+static void ggml_compute_forward_scale(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_scale_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_set
+
+static void ggml_compute_forward_set_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
+
+    // view src0 and dst with these strides and data offset inbytes during set
+    // nb0 is implicitly element_size because src0 and dst are contiguous
+    size_t nb1     = ((int32_t *) dst->op_params)[0];
+    size_t nb2     = ((int32_t *) dst->op_params)[1];
+    size_t nb3     = ((int32_t *) dst->op_params)[2];
+    size_t offset  = ((int32_t *) dst->op_params)[3];
+    bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
+
+    if (!inplace) {
+        if (params->ith == 0) {
+            // memcpy needs to be synchronized across threads to avoid race conditions.
+            // => do it in INIT phase
+            memcpy(
+                ((char *)  dst->data),
+                ((char *) src0->data),
+                ggml_nbytes(dst));
+        }
+        ggml_barrier(params->threadpool);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src1);
+    const int nc = src1->ne[0];
+
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb)
+
+    // src0 and dst as viewed during set
+    const size_t nb0 = ggml_element_size(src0);
+
+    const int im0 = (ne10 == 0 ? 0 : ne10-1);
+    const int im1 = (ne11 == 0 ? 0 : ne11-1);
+    const int im2 = (ne12 == 0 ? 0 : ne12-1);
+    const int im3 = (ne13 == 0 ? 0 : ne13-1);
+
+    GGML_ASSERT(offset + im0*nb0  + im1*nb1  + im2*nb2  + im3*nb3  <= ggml_nbytes(dst));
+
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 and dst are viewed with shape of src1 and offset
+        // => same indices
+        const int i3 = ir/(ne12*ne11);
+        const int i2 = (ir - i3*ne12*ne11)/ne11;
+        const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
+
+        ggml_vec_cpy_f32(nc,
+                (float *) ((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + offset),
+                (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
+    }
+}
+
+static void ggml_compute_forward_set(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_set_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+        case GGML_TYPE_BF16:
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+        case GGML_TYPE_TQ1_0:
+        case GGML_TYPE_TQ2_0:
+        case GGML_TYPE_IQ2_XXS:
+        case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ1_M:
+        case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_IQ3_S:
+        case GGML_TYPE_IQ2_S:
+        case GGML_TYPE_Q4_0_4_4:
+        case GGML_TYPE_Q4_0_4_8:
+        case GGML_TYPE_Q4_0_8_8:
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_cpy
+
+static void ggml_compute_forward_cpy(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    ggml_compute_forward_dup(params, dst);
+}
+
+// ggml_compute_forward_cont
+
+static void ggml_compute_forward_cont(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    ggml_compute_forward_dup(params, dst);
+}
+
+// ggml_compute_forward_reshape
+
+static void ggml_compute_forward_reshape(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    // NOP
+    UNUSED(params);
+    UNUSED(dst);
+}
+
+// ggml_compute_forward_view
+
+static void ggml_compute_forward_view(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * dst) {
+    // NOP
+    UNUSED(params);
+    UNUSED(dst);
+}
+
+// ggml_compute_forward_permute
+
+static void ggml_compute_forward_permute(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * dst) {
+    // NOP
+    UNUSED(params);
+    UNUSED(dst);
+}
+
+// ggml_compute_forward_transpose
+
+static void ggml_compute_forward_transpose(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * dst) {
+    // NOP
+    UNUSED(params);
+    UNUSED(dst);
+}
+
+// ggml_compute_forward_get_rows
+
+static void ggml_compute_forward_get_rows_q(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int64_t nc = ne00;
+    const int64_t nr = ggml_nelements(src1);
+
+    const enum ggml_type type = src0->type;
+    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
+
+    assert(ne0  == nc);
+    assert(ne02 == ne11);
+    assert(nb00 == ggml_type_size(type));
+    assert(ggml_nrows(dst) == nr);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int64_t i = ir0; i < ir1; ++i) {
+        const int64_t i12 = i/(ne11*ne10);
+        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
+        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
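+        // unflatten the linear row index i into src1 coordinates (i10, i11, i12)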
+        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+        GGML_ASSERT(i01 >= 0 && i01 < ne01);
+
+        dequantize_row_q(
+                (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
+                     (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3), nc);
+    }
+}
+
+static void ggml_compute_forward_get_rows_f16(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int64_t nc = ne00;
+    const int64_t nr = ggml_nelements(src1);
+
+    assert(ne0  == nc);
+    assert(ne02 == ne11);
+    assert(nb00 == sizeof(ggml_fp16_t));
+    assert(ggml_nrows(dst) == nr);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int64_t i = ir0; i < ir1; ++i) {
+        const int64_t i12 = i/(ne11*ne10);
+        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
+        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
+        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+        GGML_ASSERT(i01 >= 0 && i01 < ne01);
+
+        ggml_fp16_to_fp32_row(
+                (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
+                     (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3), nc);
+    }
+}
+
+static void ggml_compute_forward_get_rows_bf16(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int64_t nc = ne00;
+    const int64_t nr = ggml_nelements(src1);
+
+    assert(ne0  == nc);
+    assert(ne02 == ne11);
+    assert(nb00 == sizeof(ggml_bf16_t));
+    assert(ggml_nrows(dst) == nr);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int64_t i = ir0; i < ir1; ++i) {
+        const int64_t i12 = i/(ne11*ne10);
+        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
+        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
+        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+        GGML_ASSERT(i01 >= 0 && i01 < ne01);
+
+        ggml_bf16_to_fp32_row(
+                (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
+                     (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3), nc);
+    }
+}
+
+static void ggml_compute_forward_get_rows_f32(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int64_t nc = ne00;
+    const int64_t nr = ggml_nelements(src1);
+
+    assert(ne0  == nc);
+    assert(ne02 == ne11);
+    assert(nb00 == sizeof(float));
+    assert(ggml_nrows(dst) == nr);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int64_t i = ir0; i < ir1; ++i) {
+        const int64_t i12 = i/(ne11*ne10);
+        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
+        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
+        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+        GGML_ASSERT(i01 >= 0 && i01 < ne01);
+
+        ggml_vec_cpy_f32(nc,
+                (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3),
+                (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
+    }
+}
+
+static void ggml_compute_forward_get_rows(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+        case GGML_TYPE_TQ1_0:
+        case GGML_TYPE_TQ2_0:
+        case GGML_TYPE_IQ2_XXS:
+        case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ1_M:
+        case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_IQ3_S:
+        case GGML_TYPE_IQ2_S:
+        case GGML_TYPE_Q4_0_4_4:
+        case GGML_TYPE_Q4_0_4_8:
+        case GGML_TYPE_Q4_0_8_8:
+            {
+                ggml_compute_forward_get_rows_q(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_get_rows_f16(params, dst);
+            } break;
+        case GGML_TYPE_BF16:
+            {
+                ggml_compute_forward_get_rows_bf16(params, dst);
+            } break;
+        case GGML_TYPE_F32:
+        case GGML_TYPE_I32:
+            {
+                ggml_compute_forward_get_rows_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+
+    //static bool first = true;
+    //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
+    //if (first) {
+    //    first = false;
+    //} else {
+    //    for (int k = 0; k < dst->ne[1]; ++k) {
+    //        for (int j = 0; j < dst->ne[0]/16; ++j) {
+    //            for (int i = 0; i < 16; ++i) {
+    //                printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
+    //            }
+    //            printf("\n");
+    //        }
+    //        printf("\n");
+    //    }
+    //    printf("\n");
+    //    exit(0);
+    //}
+}
+
+// ggml_compute_forward_get_rows_back
+
+static void ggml_compute_forward_get_rows_back_f32_f16(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    GGML_ASSERT(ggml_is_contiguous(dst));
+
+    // ggml_compute_forward_dup_same_cont(params, opt0, dst);
+
+    memset(dst->data, 0, ggml_nbytes(dst));
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nelements(src1);
+
+    GGML_ASSERT( dst->ne[0] == nc);
+    GGML_ASSERT(src0->nb[0] == sizeof(ggml_fp16_t));
+
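+    // backward of get_rows: scatter-accumulate each incoming row gradient into
+    // the destination row it was originally gathered from; rows selected more
+    // than once accumulate their gradients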
+    for (int i = 0; i < nr; ++i) {
+        const int r = ((int32_t *) src1->data)[i];
+
+        for (int j = 0; j < nc; ++j) {
+            ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + i*src0->nb[1]))[j];
+            ((float *) ((char *) dst->data + r*dst->nb[1]))[j] += GGML_FP16_TO_FP32(v);
+        }
+    }
+}
+
+static void ggml_compute_forward_get_rows_back_f32(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    GGML_ASSERT(ggml_is_contiguous(dst));
+
+    // ggml_compute_forward_dup_same_cont(params, opt0, dst);
+
+    memset(dst->data, 0, ggml_nbytes(dst));
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nelements(src1);
+
+    GGML_ASSERT( dst->ne[0] == nc);
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < nr; ++i) {
+        const int r = ((int32_t *) src1->data)[i];
+
+        ggml_vec_add_f32(nc,
+                (float *) ((char *)  dst->data + r*dst->nb[1]),
+                (float *) ((char *)  dst->data + r*dst->nb[1]),
+                (float *) ((char *) src0->data + i*src0->nb[1]));
+    }
+}
+
+static void ggml_compute_forward_get_rows_back(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_get_rows_back_f32_f16(params, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_get_rows_back_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+
+    //static bool first = true;
+    //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
+    //if (first) {
+    //    first = false;
+    //} else {
+    //    for (int k = 0; k < dst->ne[1]; ++k) {
+    //        for (int j = 0; j < dst->ne[0]/16; ++j) {
+    //            for (int i = 0; i < 16; ++i) {
+    //                printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
+    //            }
+    //            printf("\n");
+    //        }
+    //        printf("\n");
+    //    }
+    //    printf("\n");
+    //    exit(0);
+    //}
+}
+
+// ggml_compute_forward_diag
+
+static void ggml_compute_forward_diag_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    // TODO: handle transposed/permuted matrices
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    GGML_ASSERT(ne00 == ne0);
+    GGML_ASSERT(ne00 == ne1);
+    GGML_ASSERT(ne01 == 1);
+    GGML_ASSERT(ne02 == ne2);
+    GGML_ASSERT(ne03 == ne3);
+
+    GGML_ASSERT(nb00 == sizeof(float));
+    GGML_ASSERT(nb0  == sizeof(float));
+
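+    // expand the single input row into a diagonal matrix:
+    // d[i1][i0] = (i0 == i1) ? s[i1] : 0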
+    for (int i3 = 0; i3 < ne3; i3++) {
+        for (int i2 = 0; i2 < ne2; i2++) {
+            for (int i1 = 0; i1 < ne1; i1++) {
+                float * d = (float *)((char *)  dst->data + i3*nb3  + i2*nb2 + i1*nb1);
+                float * s = (float *)((char *) src0->data + i3*nb03 + i2*nb02);
+                for (int i0 = 0; i0 < i1; i0++) {
+                    d[i0] = 0;
+                }
+                d[i1] = s[i1];
+                for (int i0 = i1+1; i0 < ne0; i0++) {
+                    d[i0] = 0;
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_diag(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_diag_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_diag_mask_inf
+
+static void ggml_compute_forward_diag_mask_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst,
+        const float value) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int  n_past  = ((int32_t *) dst->op_params)[0];
+    const bool inplace = src0->data == dst->data;
+
+    GGML_ASSERT(n_past >= 0);
+
+    if (!inplace) {
+        if (ith == 0) {
+            // memcpy needs to be synchronized across threads to avoid race conditions:
+            // only thread 0 performs the copy, and all threads wait at the barrier below
+            GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+            GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
+            memcpy(
+                ((char *)  dst->data),
+                ((char *) src0->data),
+                ggml_nbytes(dst));
+        }
+        ggml_barrier(params->threadpool);
+    }
+
+    // TODO: handle transposed/permuted matrices
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+    const int nr = src0->ne[1];
+    const int nz = n/nr;
+
+    GGML_ASSERT( dst->nb[0] == sizeof(float));
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
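+    // causal mask: in row j, set every column i with i > n_past + j to `value`
+    // (-INFINITY for diag_mask_inf, 0 for diag_mask_zero)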
+    for (int k = 0; k < nz; k++) {
+        for (int j = ith; j < nr; j += nth) {
+            for (int i = n_past; i < nc; i++) {
+                if (i > n_past + j) {
+                    *(float *)((char *) dst->data + k*dst->nb[2] + j*dst->nb[1] + i*dst->nb[0]) = value;
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_diag_mask_inf(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_diag_mask_f32(params, dst, -INFINITY);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+static void ggml_compute_forward_diag_mask_zero(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_diag_mask_f32(params, dst, 0);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_soft_max
+
+static void ggml_compute_forward_soft_max_f32(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    assert(ggml_is_contiguous(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
+
+    // TODO: handle transposed/permuted matrices
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    //const int64_t ne11 = src1 ? src1->ne[1] : 1;
+
+    // TODO: is this supposed to be ceil instead of floor?
+    //       https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
+    const uint32_t n_head      = ne02;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
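+    // ALiBi slopes form two geometric sequences: heads below n_head_log2 take
+    // powers of m0, the remaining heads take odd powers of m1 (see the per-row
+    // slope computation below)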
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
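+    // per-thread scratch row in wdata; the CACHE_LINE_SIZE_F32 padding keeps
+    // each thread's scratch on its own cache line (likely to avoid false sharing)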
+
+    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        // ALiBi
+        const uint32_t h = (i1/ne01)%ne02; // head
+        const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
+
+        float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
+        float * dp = (float *)((char *)  dst->data +  i1*dst->nb[1]);
+
+        // broadcast the mask across rows
+        ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
+        float       * mp_f32 = src1 ? (float       *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
+
+        ggml_vec_cpy_f32  (nc, wp, sp);
+        ggml_vec_scale_f32(nc, wp, scale);
+        if (mp_f32) {
+            if (use_f16) {
+                for (int i = 0; i < nc; ++i) {
+                    wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]);
+                }
+            } else {
+                for (int i = 0; i < nc; ++i) {
+                    wp[i] += slope*mp_f32[i];
+                }
+            }
+        }
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            //printf("p[%d] = %f\n", i, p[i]);
+            assert(!isnan(wp[i]));
+        }
+#endif
+
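+        // numerically stable softmax: subtract the row max before exponentiating,
+        // then normalize by the sum of the exponentials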
+        float max = -INFINITY;
+        ggml_vec_max_f32(nc, &max, wp);
+
+        ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
+        assert(sum > 0.0);
+
+        sum = 1.0/sum;
+        ggml_vec_scale_f32(nc, dp, sum);
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            assert(!isnan(dp[i]));
+            assert(!isinf(dp[i]));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_soft_max(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_soft_max_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+
+// ggml_compute_forward_soft_max_back
+
+static void ggml_compute_forward_soft_max_back_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_are_same_shape(src1, dst));
+
+    // TODO: handle transposed/permuted matrices
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float *dy = (float *)((char *) src0->data + i1*src0->nb[1]);
+        float *y  = (float *)((char *) src1->data + i1*src1->nb[1]);
+        float *dx = (float *)((char *) dst->data  + i1*dst->nb[1]);
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            //printf("p[%d] = %f\n", i, p[i]);
+            assert(!isnan(dy[i]));
+            assert(!isnan(y[i]));
+        }
+#endif
+        // Jii = yi - yi*yi
+        // Jij = -yi*yj
+        // J = diag(y)-y.T*y
+        // dx = J * dy
+        // dxk = sum_i(Jki * dyi)
+        // dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk
+        // dxk = sum_i(-yk*yi * dyi) + yk*yk*dyk + yk*dyk - yk*yk*dyk
+        // dxk = sum_i(-yk*yi * dyi) + yk*dyk
+        // dxk = -yk * sum_i(yi * dyi) + yk*dyk
+        // dxk = -yk * dot(y, dy) + yk*dyk
+        // dxk = yk * (- dot(y, dy) + dyk)
+        // dxk = yk * (dyk - dot(y, dy))
+        //
+        // post-order:
+        // dot_y_dy := dot(y, dy)
+        // dx := dy
+        // dx := dx - dot_y_dy
+        // dx := dx * y
+
+        // linear runtime, no additional memory
+        float dot_y_dy = 0;
+        ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
+        ggml_vec_cpy_f32 (nc, dx, dy);
+        ggml_vec_acc1_f32(nc, dx, -dot_y_dy);
+        ggml_vec_mul_f32 (nc, dx, dx, y);
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            assert(!isnan(dx[i]));
+            assert(!isinf(dx[i]));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_soft_max_back(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_soft_max_back_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_clamp
+
+static void ggml_compute_forward_clamp_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    float min;
+    float max;
+    memcpy(&min, (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+
+    GGML_ASSERT( nb0 == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    for (int j = ith; j < n; j += nth) {
+        float * dst_ptr  = (float *) ((char *)  dst->data + j*nb1);
+        float * src0_ptr = (float *) ((char *) src0->data + j*nb01);
+
+        for (int i = 0; i < nc; i++) {
+            dst_ptr[i] = MAX(MIN(src0_ptr[i], max), min);
+        }
+    }
+}
+
+static void ggml_compute_forward_clamp(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_clamp_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+        case GGML_TYPE_BF16:
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+        case GGML_TYPE_TQ1_0:
+        case GGML_TYPE_TQ2_0:
+        case GGML_TYPE_IQ2_XXS:
+        case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ1_M:
+        case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_IQ3_S:
+        case GGML_TYPE_IQ2_S:
+        case GGML_TYPE_Q8_K:
+        case GGML_TYPE_Q4_0_4_4:
+        case GGML_TYPE_Q4_0_4_8:
+        case GGML_TYPE_Q4_0_8_8:
+        case GGML_TYPE_IQ4_NL_4_4:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_I64:
+        case GGML_TYPE_F64:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_rope
+
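+// clamped linear ramp: returns 1 for i0/2 <= low, 0 for i0/2 >= high, and a
+// linear blend in between; rope_yarn uses it to mix extrapolated and
+// interpolated rotation angles per dimension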
+static float rope_yarn_ramp(const float low, const float high, const int i0) {
+    const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
+    return 1 - MIN(1, MAX(0, y));
+}
+
+// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+static void rope_yarn(
+    float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
+    float * cos_theta, float * sin_theta) {
+    // Get n-d rotational scaling corrected for extrapolation
+    float theta_interp = freq_scale * theta_extrap;
+    float theta = theta_interp;
+    if (ext_factor != 0.0f) {
+        float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
+        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+        // Get n-d magnitude scaling corrected for interpolation
+        mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
+    }
+    *cos_theta = cosf(theta) * mscale;
+    *sin_theta = sinf(theta) * mscale;
+}
+
+static void ggml_rope_cache_init(
+     float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
+     float * cache, float sin_sign, float theta_scale) {
+    // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
+    float theta = theta_base;
+    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+        const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
+        rope_yarn(
+            theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
+        );
+        cache[i0 + 1] *= sin_sign;
+
+        theta *= theta_scale;
+    }
+}
+
+static void ggml_compute_forward_rope_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst,
+        const bool forward) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+    const struct ggml_tensor * src2 = dst->src[2];
+
+    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+
+    //const int n_past     = ((int32_t *) dst->op_params)[0];
+    const int n_dims     = ((int32_t *) dst->op_params)[1];
+    const int mode       = ((int32_t *) dst->op_params)[2];
+    //const int n_ctx      = ((int32_t *) dst->op_params)[3];
+    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
+
+    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
+    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
+    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
+    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
+    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
+    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
+    //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
+
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(dst);
+
+    GGML_ASSERT(n_dims <= ne0);
+    GGML_ASSERT(n_dims % 2 == 0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    // running row counter used to skip rows that belong to other threads
+    int ir = 0;
+
+    const float theta_scale = powf(freq_base, -2.0f/n_dims);
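+    // standard RoPE frequency ladder: theta for pair i is p * freq_base^(-2i/n_dims),
+    // so each successive rotation pair multiplies theta by theta_scale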
+
+    float corr_dims[2];
+    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
+
+    const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
+
+    const float * freq_factors = NULL;
+    if (src2 != NULL) {
+        GGML_ASSERT(src2->type == GGML_TYPE_F32);
+        GGML_ASSERT(src2->ne[0] >= n_dims / 2);
+        freq_factors = (const float *) src2->data;
+    }
+
+    // backward process uses inverse rotation by cos and sin.
+    // cos and sin build a rotation matrix, where the inverse is the transpose.
+    // this essentially just switches the sign of sin.
+    const float sin_sign = forward ? 1.0f : -1.0f;
+
+    const int32_t * pos = (const int32_t *) src1->data;
+
+    for (int64_t i3 = 0; i3 < ne3; i3++) {
+        for (int64_t i2 = 0; i2 < ne2; i2++) {
+            const int64_t p = pos[i2];
+
+            float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
+            ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
+
+            for (int64_t i1 = 0; i1 < ne1; i1++) {
+                if (ir++ < ir0) continue;
+                if (ir   > ir1) break;
+
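+                // normal RoPE rotates adjacent elements (x[i0], x[i0+1]);
+                // NeoX-style RoPE pairs x[ic] with x[ic + n_dims/2] instead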
+                if (!is_neox) {
+                    for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
+                        const float cos_theta = cache[i0 + 0];
+                        const float sin_theta = cache[i0 + 1];
+
+                        const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                              float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+                        const float x0 = src[0];
+                        const float x1 = src[1];
+
+                        dst_data[0] = x0*cos_theta - x1*sin_theta;
+                        dst_data[1] = x0*sin_theta + x1*cos_theta;
+                    }
+                } else {
+                    for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
+                        const int64_t ic = i0/2;
+
+                        const float cos_theta = cache[i0 + 0];
+                        const float sin_theta = cache[i0 + 1];
+
+                        const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+                        float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
+
+                        const float x0 = src[0];
+                        const float x1 = src[n_dims/2];
+
+                        dst_data[0]        = x0*cos_theta - x1*sin_theta;
+                        dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+                    }
+                }
+
+                for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
+                    const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                    float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+                    dst_data[0] = src[0];
+                    dst_data[1] = src[1];
+                }
+            }
+        }
+    }
+}
+
+// TODO: deduplicate f16/f32 code
+static void ggml_compute_forward_rope_f16(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst,
+        const bool forward) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+    const struct ggml_tensor * src2 = dst->src[2];
+
+    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+
+    //const int n_past     = ((int32_t *) dst->op_params)[0];
+    const int n_dims     = ((int32_t *) dst->op_params)[1];
+    const int mode       = ((int32_t *) dst->op_params)[2];
+    //const int n_ctx      = ((int32_t *) dst->op_params)[3];
+    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
+    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
+    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
+    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
+    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
+    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
+    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
+    //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
+
+    GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(dst);
+
+    GGML_ASSERT(n_dims <= ne0);
+    GGML_ASSERT(n_dims % 2 == 0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    // running row counter used to skip rows that belong to other threads
+    int ir = 0;
+
+    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+
+    float corr_dims[2];
+    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
+
+    const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
+
+    const float * freq_factors = NULL;
+    if (src2 != NULL) {
+        GGML_ASSERT(src2->type == GGML_TYPE_F32);
+        GGML_ASSERT(src2->ne[0] >= n_dims / 2);
+        freq_factors = (const float *) src2->data;
+    }
+
+    // backward process uses inverse rotation by cos and sin.
+    // cos and sin build a rotation matrix, where the inverse is the transpose.
+    // this essentially just switches the sign of sin.
+    const float sin_sign = forward ? 1.0f : -1.0f;
+
+    const int32_t * pos = (const int32_t *) src1->data;
+
+    for (int64_t i3 = 0; i3 < ne3; i3++) {
+        for (int64_t i2 = 0; i2 < ne2; i2++) {
+            const int64_t p = pos[i2];
+
+            float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
+            ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
+
+            for (int64_t i1 = 0; i1 < ne1; i1++) {
+                if (ir++ < ir0) continue;
+                if (ir   > ir1) break;
+
+                if (!is_neox) {
+                    for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
+                        const float cos_theta = cache[i0 + 0];
+                        const float sin_theta = cache[i0 + 1];
+
+                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                              ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+                        const float x0 = GGML_FP16_TO_FP32(src[0]);
+                        const float x1 = GGML_FP16_TO_FP32(src[1]);
+
+                        dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                        dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                    }
+                } else {
+                    for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
+                        const int64_t ic = i0/2;
+
+                        const float cos_theta = cache[i0 + 0];
+                        const float sin_theta = cache[i0 + 1];
+
+                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+                        ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
+
+                        const float x0 = GGML_FP16_TO_FP32(src[0]);
+                        const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
+
+                        dst_data[0]        = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                        dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                    }
+                }
+
+                for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
+                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                    ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+                    dst_data[0] = src[0];
+                    dst_data[1] = src[1];
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_rope(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_rope_f16(params, dst, true);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rope_f32(params, dst, true);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_rope_back
+
+static void ggml_compute_forward_rope_back(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_rope_f16(params, dst, false);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rope_f32(params, dst, false);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_conv_transpose_1d
+
+static void ggml_compute_forward_conv_transpose_1d_f16_f32(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nk = ne00*ne01*ne02;
+
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    if (ith == 0) {
+        memset(params->wdata, 0, params->wsize);
+
+        // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
+        {
+            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
+
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
+                    ggml_fp16_t * dst_data = wdata + i01*ne00*ne02;
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        dst_data[i00*ne02 + i02] = src[i00];
+                    }
+                }
+            }
+        }
+
+        // permute source data (src1) from (L x Cin) to (Cin x L)
+        {
+            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
+            ggml_fp16_t * dst_data = wdata;
+
+            for (int64_t i11 = 0; i11 < ne11; i11++) {
+                const float * const src = (float *)((char *) src1->data + i11*nb11);
+                for (int64_t i10 = 0; i10 < ne10; i10++) {
+                    dst_data[i10*ne11 + i11] = GGML_FP32_TO_FP16(src[i10]);
+                }
+            }
+        }
+
+        // need to zero dst since we are accumulating into it
+        memset(dst->data, 0, ggml_nbytes(dst));
+    }
+    ggml_barrier(params->threadpool);
+
+    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+
+    // total rows in dst
+    const int nr = ne1;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    ggml_fp16_t * const wdata     = (ggml_fp16_t *) params->wdata + 0;
+    ggml_fp16_t * const wdata_src = wdata + nk;
+
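+    // each output row i1 corresponds to one output channel: for every input
+    // position i10 and kernel tap i00, accumulate the dot product over input
+    // channels into output position i10*s0 + i00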
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * dst_data = (float *)((char *) dst->data + i1*nb1);
+        ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00;
+        for (int i10 = 0; i10 < ne10; i10++) {
+            const int i1n = i10*ne11;
+            for (int i00 = 0; i00 < ne00; i00++) {
+                float v = 0;
+                ggml_vec_dot_f16(ne02, &v, 0,
+                        (ggml_fp16_t *)    wdata_src + i1n, 0,
+                        (ggml_fp16_t *) wdata_kernel + i00*ne02, 0, 1);
+                dst_data[i10*s0 + i00] += v;
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_conv_transpose_1d_f32(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nk = ne00*ne01*ne02;
+
+    GGML_ASSERT(nb00 == sizeof(float));
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    if (ith == 0) {
+        memset(params->wdata, 0, params->wsize);
+
+        // prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
+        {
+            float * const wdata = (float *) params->wdata + 0;
+
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
+                    float * dst_data = wdata + i01*ne00*ne02;
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        dst_data[i00*ne02 + i02] = src[i00];
+                    }
+                }
+            }
+        }
+
+        // prepare source data (src1)
+        {
+            float * const wdata = (float *) params->wdata + nk;
+            float * dst_data = wdata;
+
+            for (int64_t i11 = 0; i11 < ne11; i11++) {
+                const float * const src = (float *)((char *) src1->data + i11*nb11);
+                for (int64_t i10 = 0; i10 < ne10; i10++) {
+                    dst_data[i10*ne11 + i11] = src[i10];
+                }
+            }
+        }
+
+        // need to zero dst since we are accumulating into it
+        memset(dst->data, 0, ggml_nbytes(dst));
+    }
+    ggml_barrier(params->threadpool);
+
+    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+
+    // total rows in dst
+    const int nr = ne1;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    float * const wdata     = (float *) params->wdata + 0;
+    float * const wdata_src = wdata + nk;
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * dst_data = (float *)((char *) dst->data + i1*nb1);
+        float * wdata_kernel = wdata + i1*ne02*ne00;
+        for (int i10 = 0; i10 < ne10; i10++) {
+            const int i1n = i10*ne11;
+            for (int i00 = 0; i00 < ne00; i00++) {
+                float v = 0;
+                ggml_vec_dot_f32(ne02, &v, 0,
+                        wdata_src + i1n, 0,
+                        wdata_kernel + i00*ne02, 0, 1);
+                dst_data[i10*s0 + i00] += v;
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_conv_transpose_1d(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_conv_transpose_1d_f16_f32(params, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_conv_transpose_1d_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_im2col_f32
+// src0: kernel [OC, IC, KH, KW]
+// src1: image [N, IC, IH, IW]
+// dst:  result [N, OH, OW, IC*KH*KW]
+static void ggml_compute_forward_im2col_f32(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+    const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
+    const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
+    const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
+    const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+    const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t N  = is_2D ? ne13 : ne12;
+    const int64_t IC = is_2D ? ne12 : ne11;
+    const int64_t IH = is_2D ? ne11 : 1;
+    const int64_t IW = ne10;
+
+    const int64_t KH = is_2D ? ne01 : 1;
+    const int64_t KW = ne00;
+
+    const int64_t OH = is_2D ? ne2 : 1;
+    const int64_t OW = ne1;
+
+    int ofs0 = is_2D ? nb13 : nb12;
+    int ofs1 = is_2D ? nb12 : nb11;
+
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
+    {
+        float * const wdata = (float *) dst->data;
+
+        for (int64_t in = 0; in < N; in++) {
+            for (int64_t ioh = 0; ioh < OH; ioh++) { // 1
+                for (int64_t iow = 0; iow < OW; iow++) {
+                    for (int64_t iic = ith; iic < IC; iic += nth) {
+
+                        // micro kernel
+                        float * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
+                        const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
+
+                        for (int64_t ikh = 0; ikh < KH; ikh++) {  // 1
+                            for (int64_t ikw = 0; ikw < KW; ikw++) {
+                                const int64_t iiw = iow*s0 + ikw*d0 - p0;
+                                const int64_t iih = ioh*s1 + ikh*d1 - p1;
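+                                // map output pixel (ioh, iow) plus kernel tap (ikh, ikw)
+                                // back to an input pixel, applying stride, dilation and padding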
+
+                                if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
+                                } else {
+                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = (src_data[iih*IW + iiw]);
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+
+// ggml_compute_forward_im2col_f16
+// src0: kernel [OC, IC, KH, KW]
+// src1: image [N, IC, IH, IW]
+// dst:  result [N, OH, OW, IC*KH*KW]
+static void ggml_compute_forward_im2col_f16(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F16);
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+    const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
+    const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
+    const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
+    const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+    const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t N  = is_2D ? ne13 : ne12;
+    const int64_t IC = is_2D ? ne12 : ne11;
+    const int64_t IH = is_2D ? ne11 : 1;
+    const int64_t IW = ne10;
+
+    const int64_t KH = is_2D ? ne01 : 1;
+    const int64_t KW = ne00;
+
+    const int64_t OH = is_2D ? ne2 : 1;
+    const int64_t OW = ne1;
+
+    int ofs0 = is_2D ? nb13 : nb12;
+    int ofs1 = is_2D ? nb12 : nb11;
+
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
+    {
+        ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
+
+        for (int64_t in = 0; in < N; in++) {
+            for (int64_t ioh = 0; ioh < OH; ioh++) { // 1
+                for (int64_t iow = 0; iow < OW; iow++) {
+                    for (int64_t iic = ith; iic < IC; iic += nth) {
+
+                        // micro kernel
+                        ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
+                        const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
+
+                        for (int64_t ikh = 0; ikh < KH; ikh++) {  // 1
+                            for (int64_t ikw = 0; ikw < KW; ikw++) {
+                                const int64_t iiw = iow*s0 + ikw*d0 - p0;
+                                const int64_t iih = ioh*s1 + ikh*d1 - p1;
+
+                                if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
+                                } else {
+                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_FP32_TO_FP16(src_data[iih*IW + iiw]);
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_im2col(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+    switch (dst->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_im2col_f16(params, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_im2col_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_im2col_back_f32
+
+static void ggml_compute_forward_im2col_back_f32(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+    const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
+    const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
+    const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
+    const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+    const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t N  = is_2D ? ne3 : ne2;
+    const int64_t IC = is_2D ? ne2 : ne1;
+    const int64_t IH = is_2D ? ne1 : 1;
+    const int64_t IW = ne0;
+
+    const int64_t KH = is_2D ? ne01 : 1;
+    const int64_t KW = ne00;
+
+    const int64_t OH = is_2D ? ne12 : 1;
+    const int64_t OW = ne11;
+
+    int ofs0 = is_2D ? nb3 : nb2;
+    int ofs1 = is_2D ? nb2 : nb1;
+
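+    // dst has the shape of the original image: each input pixel accumulates the
+    // gradients of every output position whose receptive field covered it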
+    GGML_ASSERT(nb0  == sizeof(float));
+
+    // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
+    {
+        float * const wdata = (float *) dst->data;
+
+        for (int64_t in = 0; in < N; in++) {
+            for (int64_t iic = ith; iic < IC; iic += nth) {
+                for (int64_t iih = 0; iih < IH; iih++) {
+                    for (int64_t iiw = 0; iiw < IW; iiw++) {
+
+                        // micro kernel
+                        float grad = 0.0f;
+                        for (int64_t ikh = 0; ikh < KH; ikh++) {
+                            for (int64_t ikw = 0; ikw < KW; ikw++) {
+                                // For s0 > 1 some values were skipped over in the forward pass.
+                                // These values have tmpw % s0 != 0 and need to be skipped in the backwards pass as well.
+                                const int64_t tmpw = (iiw + p0 - ikw*d0);
+                                if (tmpw % s0 != 0) {
+                                    continue;
+                                }
+                                const int64_t iow = tmpw / s0;
+
+                                // Equivalent logic as above except for s1.
+                                int64_t ioh;
+                                if (is_2D) {
+                                    const int64_t tmph = iih + p1 - ikh*d1;
+
+                                    if (tmph % s1 != 0) {
+                                        continue;
+                                    }
+
+                                    ioh = tmph / s1;
+                                } else {
+                                    ioh = 0;
+                                }
+
+                                if (iow < 0 || iow >= OW || ioh < 0 || ioh >= OH) {
+                                    continue;
+                                }
+
+                                const float * const src_data = (const float *) src1->data
+                                    + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
+                                grad += src_data[iic*(KH*KW) + ikh*KW + ikw];
+                            }
+                        }
+                        float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
+                        dst_data[iih*IW + iiw] = grad;
+                    }
+                }
+            }
+        }
+    }
+}
+
+// ggml_compute_forward_conv_transpose_2d
+
+static void ggml_compute_forward_conv_transpose_2d(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nk = ne00*ne01*ne02*ne03;
+
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    if (ith == 0) {
+        memset(params->wdata, 0, params->wsize);
+
+        // permute kernel data (src0) from (Kw x Kh x Cout x Cin) to (Cin x Kw x Kh x Cout)
+        {
+            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
+
+            for (int64_t i03 = 0; i03 < ne03; i03++) {
+                for (int64_t i02 = 0; i02 < ne02; i02++) {
+                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i03*nb03 + i02*nb02);
+                    ggml_fp16_t * dst_data = wdata + i02*ne01*ne00*ne03;
+                    for (int64_t i01 = 0; i01 < ne01; i01++) {
+                        for (int64_t i00 = 0; i00 < ne00; i00++) {
+                            dst_data[i01*ne00*ne03 + i00*ne03 + i03] = src[i01 * ne00 + i00];
+                        }
+                    }
+                }
+            }
+        }
+
+        // permute source data (src1) from (Sw x Sh x Cin) to (Cin x Sw x Sh)
+        {
+            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
+            for (int i12 = 0; i12 < ne12; i12++) {
+                for (int i11 = 0; i11 < ne11; i11++) {
+                    const float * const src = (float *)((char *) src1->data + i12*nb12 + i11*nb11);
+                    ggml_fp16_t * dst_data = wdata + i11*ne10*ne12;
+                    for (int i10 = 0; i10 < ne10; i10++) {
+                        dst_data[i10*ne12 + i12] = GGML_FP32_TO_FP16(src[i10]);
+                    }
+                }
+            }
+        }
+
+        memset(dst->data, 0, ggml_nbytes(dst));
+    }
+    ggml_barrier(params->threadpool);
+
+    const int32_t stride = ggml_get_op_params_i32(dst, 0);
+
+    // total patches in dst
+    const int np = ne2;
+
+    // patches per thread
+    const int dp = (np + nth - 1)/nth;
+
+    // patch range for this thread
+    const int ip0 = dp*ith;
+    const int ip1 = MIN(ip0 + dp, np);
+
+    ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
+    ggml_fp16_t * const wdata_src = wdata + nk;
+
+    for (int i2 = ip0; i2 < ip1; i2++) { // Cout
+        float * dst_data = (float *)((char *) dst->data + i2*nb2);
+        ggml_fp16_t * wdata_kernel = wdata + i2*ne01*ne00*ne03;
+        for (int i11 = 0; i11 < ne11; i11++) {
+            for (int i10 = 0; i10 < ne10; i10++) {
+                const int i1n = i11*ne10*ne12 + i10*ne12;
+                for (int i01 = 0; i01 < ne01; i01++) {
+                    for (int i00 = 0; i00 < ne00; i00++) {
+                        float v = 0;
+                        ggml_vec_dot_f16(ne03, &v, 0,
+                                wdata_src + i1n, 0,
+                                wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1);
+                        dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v;
+                    }
+                }
+            }
+        }
+    }
+}
+
+// ggml_compute_forward_pool_1d_sk_p0
+
+static void ggml_compute_forward_pool_1d_sk_p0(
+        const struct ggml_compute_params * params,
+        const enum ggml_op_pool op,
+        const int k,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src = dst->src[0];
+
+    assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16);
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    const char * cdata = (const char *)src->data;
+    const char * const data_end = cdata + ggml_nbytes(src);
+    float * drow = (float *)dst->data;
+
+    const int64_t rs = dst->ne[0];
+
+    while (cdata < data_end) {
+        const void * srow = (const void *)cdata;
+        int j = 0;
+        for (int64_t i = 0; i < rs; ++i) {
+            switch (op) {
+                case GGML_OP_POOL_AVG:   drow[i] = 0;        break;
+                case GGML_OP_POOL_MAX:   drow[i] = -FLT_MAX; break;
+                case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
+            }
+            for (int ki = 0; ki < k; ++ki) {
+                const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
+                switch (op) {
+                    case GGML_OP_POOL_AVG:                         drow[i] += srow_j; break;
+                    case GGML_OP_POOL_MAX:   if (srow_j > drow[i]) drow[i]  = srow_j; break;
+                    case GGML_OP_POOL_COUNT:                       GGML_ABORT("fatal error");
+                }
+                ++j;
+            }
+            switch (op) {
+                case GGML_OP_POOL_AVG:         drow[i] /= k; break;
+                case GGML_OP_POOL_MAX:                       break;
+                case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
+            }
+        }
+
+        cdata += src->nb[1];
+        drow  += rs;
+    }
+}
+
+// ggml_compute_forward_pool_1d
+
+static void ggml_compute_forward_pool_1d(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const int32_t * opts = (const int32_t *)dst->op_params;
+    enum ggml_op_pool op = opts[0];
+    const int k0 = opts[1];
+    const int s0 = opts[2];
+    const int p0 = opts[3];
+    GGML_ASSERT(p0 == 0); // padding not supported
+    GGML_ASSERT(k0 == s0); // only s = k supported
+
+    ggml_compute_forward_pool_1d_sk_p0(params, op, k0, dst);
+}
+
+// ggml_compute_forward_pool_2d
+
+static void ggml_compute_forward_pool_2d(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src = dst->src[0];
+
+    assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16);
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    const int32_t * opts = (const int32_t *)dst->op_params;
+    enum ggml_op_pool op = opts[0];
+    const int k0 = opts[1];
+    const int k1 = opts[2];
+    const int s0 = opts[3];
+    const int s1 = opts[4];
+    const int p0 = opts[5];
+    const int p1 = opts[6];
+    const char * cdata = (const char*)src->data;
+    const char * const data_end = cdata + ggml_nbytes(src);
+
+    const int64_t px = dst->ne[0];
+    const int64_t py = dst->ne[1];
+    const int64_t pa = px * py;
+
+    float * dplane = (float *)dst->data;
+
+    const int ka = k0 * k1;
+    const int offset0 = -p0;
+    const int offset1 = -p1;
+
+    while (cdata < data_end) {
+        for (int oy = 0; oy < py; ++oy) {
+            float * const drow = dplane + oy * px;
+            for (int ox = 0; ox < px; ++ox) {
+                float * const out =  drow + ox;
+                switch (op) {
+                    case GGML_OP_POOL_AVG:     *out = 0;        break;
+                    case GGML_OP_POOL_MAX:     *out = -FLT_MAX; break;
+                    case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
+                }
+
+                const int ix = offset0 + ox * s0;
+                const int iy = offset1 + oy * s1;
+
+                for (int ky = 0; ky < k1; ++ky) {
+                    if (iy + ky < 0 || iy + ky >= src->ne[1]) continue;
+                    const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky));
+                    for (int kx = 0; kx < k0; ++kx) {
+                        int j = ix + kx;
+                        if (j < 0 || j >= src->ne[0]) continue;
+                        const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
+                        switch (op) {
+                            case GGML_OP_POOL_AVG:                     *out += srow_j; break;
+                            case GGML_OP_POOL_MAX: if (srow_j > *out)  *out  = srow_j; break;
+                            case GGML_OP_POOL_COUNT:               GGML_ABORT("fatal error");
+                        }
+                    }
+                }
+                switch (op) {
+                    case GGML_OP_POOL_AVG:           *out /= ka; break;
+                    case GGML_OP_POOL_MAX:                       break;
+                    case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
+                }
+            }
+        }
+
+        cdata  += src->nb[2];
+        dplane += pa;
+    }
+}
+
+// ggml_compute_forward_pool_2d_back
+
+static void ggml_compute_forward_pool_2d_back(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src  = dst->src[0];
+    const struct ggml_tensor * dstf = dst->src[1]; // forward tensor of dst
+
+    assert(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    const int32_t * opts = (const int32_t *)dst->op_params;
+    enum ggml_op_pool op = opts[0];
+    const int k0 = opts[1];
+    const int k1 = opts[2];
+    const int s0 = opts[3];
+    const int s1 = opts[4];
+    const int p0 = opts[5];
+    const int p1 = opts[6];
+
+    char       * cdata  = (char       *) dst->data;
+    const char * cdataf = (const char *) dstf->data;
+    const char * const data_end = cdata + ggml_nbytes(dst);
+
+    GGML_ASSERT(params->ith == 0);
+    memset(cdata, 0, ggml_nbytes(dst));
+
+    const int64_t px = src->ne[0];
+    const int64_t py = src->ne[1];
+    const int64_t pa = px * py;
+
+    const float * splane = (const float *) src->data;
+
+    const int ka = k0 * k1;
+    const int offset0 = -p0;
+    const int offset1 = -p1;
+
+    while (cdata < data_end) {
+        for (int oy = 0; oy < py; ++oy) {
+            const float * const srow = splane + oy * px;
+            for (int ox = 0; ox < px; ++ox) {
+                const float grad0 = srow[ox];
+
+                const int ix = offset0 + ox * s0;
+                const int iy = offset1 + oy * s1;
+
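+                // max pooling routes the gradient only to the window position that held the maximum;
+                // average pooling spreads it uniformly over the k0*k1 window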
+                if (op == GGML_OP_POOL_MAX) {
+                    float maxval = -FLT_MAX;
+                    int kxmax = -1;
+                    int kymax = -1;
+
+                    for (int ky = 0; ky < k1; ++ky) {
+                        if (iy + ky < 0 || iy + ky >= dst->ne[1]) {
+                            continue;
+                        }
+                        const void * drowf = (const void *)(cdataf + dst->nb[1] * (iy + ky));
+                        for (int kx = 0; kx < k0; ++kx) {
+                            int j = ix + kx;
+                            if (j < 0 || j >= dst->ne[0]) {
+                                continue;
+                            }
+
+                            const float val = dst->type == GGML_TYPE_F32 ?
+                                ((const float *) drowf)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t *) drowf)[j]);
+                            if (val <= maxval) {
+                                continue;
+                            }
+
+                            maxval = val;
+                            kxmax = kx;
+                            kymax = ky;
+                        }
+                    }
+
+                    if (kxmax == -1 || kymax == -1) {
+                        continue;
+                    }
+
+                    void * drow = (void *)(cdata + dst->nb[1] * (iy + kymax));
+                    const int j = ix + kxmax;
+                    if (dst->type == GGML_TYPE_F32) {
+                        ((float *) drow)[j] += grad0;
+                    } else {
+                        ((ggml_fp16_t *) drow)[j] = GGML_FP32_TO_FP16(grad0 + GGML_FP16_TO_FP32(((const ggml_fp16_t *) drow)[j]));
+                    }
+                } else if (op == GGML_OP_POOL_AVG) {
+                    const float grad = grad0 / ka;
+
+                    for (int ky = 0; ky < k1; ++ky) {
+                        if (iy + ky < 0 || iy + ky >= dst->ne[1]) {
+                            continue;
+                        }
+                        void * drow = (void *)(cdata + dst->nb[1] * (iy + ky));
+                        for (int kx = 0; kx < k0; ++kx) {
+                            int j = ix + kx;
+                            if (j < 0 || j >= dst->ne[0]) {
+                                continue;
+                            }
+
+                            if (dst->type == GGML_TYPE_F32) {
+                                ((float *) drow)[j] += grad;
+                            } else {
+                                ((ggml_fp16_t *) drow)[j] += GGML_FP32_TO_FP16(grad);
+                            }
+                        }
+                    }
+                } else {
+                    GGML_ASSERT(false);
+                }
+            }
+        }
+
+        cdata  += dst->nb[2];
+        cdataf += dst->nb[2];
+        splane += pa;
+    }
+}
+
+// ggml_compute_forward_upscale
+
+static void ggml_compute_forward_upscale_f32(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    const float sf0 = (float)ne0/src0->ne[0];
+    const float sf1 = (float)ne1/src0->ne[1];
+    const float sf2 = (float)ne2/src0->ne[2];
+    const float sf3 = (float)ne3/src0->ne[3];
+
+    // TODO: optimize
+
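+    // nearest-neighbor upscaling: output index i samples source index floor(i / sf)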
+    for (int64_t i3 = 0; i3 < ne3; i3++) {
+        const int64_t i03 = i3 / sf3;
+        for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
+            const int64_t i02 = i2 / sf2;
+            for (int64_t i1 = 0; i1 < ne1; i1++) {
+                const int64_t i01 = i1 / sf1;
+                for (int64_t i0 = 0; i0 < ne0; i0++) {
+                    const int64_t i00 = i0 / sf0;
+
+                    const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                          float * y = (float *)((char *)  dst->data +  i0*nb0  +  i1*nb1  +  i2*nb2  +  i3*nb3);
+
+                    *y = *x;
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_upscale(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_upscale_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+
+// ggml_compute_forward_pad
+
+static void ggml_compute_forward_pad_f32(
+    const struct ggml_compute_params * params,
+          struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT( dst->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    float * dst_ptr = (float *) dst->data;
+
+    // TODO: optimize
+
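+    // copy src0 into the leading corner of dst; positions outside src0's extent are zero-filled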
+    for (int64_t i2 = 0; i2 < ne2; ++i2) {
+        for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
+            for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                for (int64_t i3 = 0; i3 < ne3; ++i3) {
+                    const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
+
+                    const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+
+                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+                        dst_ptr[dst_idx] = *src_ptr;
+                    } else {
+                        dst_ptr[dst_idx] = 0;
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_pad(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_pad_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+static void ggml_compute_forward_unpad_f32(
+    const struct ggml_compute_params *params,
+    struct ggml_tensor *dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT( dst->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    float * dst_ptr = (float *) dst->data;
+
+    // TODO: optimize
+
+    for (int64_t i2 = 0; i2 < ne2; ++i2) {
+        for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
+            for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                for (int64_t i3 = 0; i3 < ne3; ++i3) {
+                    const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
+
+                    const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+
+                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+                        dst_ptr[dst_idx] = *src_ptr;
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_unpad(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_unpad_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_arange
+
+static void ggml_compute_forward_arange_f32(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    GGML_ASSERT(dst->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const float start = ggml_get_op_params_f32(dst, 0);
+    const float stop  = ggml_get_op_params_f32(dst, 1);
+    const float step  = ggml_get_op_params_f32(dst, 2);
+
+    const int64_t steps = (int64_t) ceilf((stop - start) / step);
+
+    GGML_ASSERT(ggml_nelements(dst) == steps);
+
+    for (int64_t i = ith; i < steps; i+= nth) {
+        float value = start + step * i;
+        ((float *)dst->data)[i] = value;
+    }
+}
+
+static void ggml_compute_forward_arange(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+    switch (dst->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_arange_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+static void ggml_compute_forward_timestep_embedding_f32(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    const int dim = ggml_get_op_params_i32(dst, 0);
+    const int max_period = ggml_get_op_params_i32(dst, 1);
+
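+    // sinusoidal embedding: embed[j] = cos(t * freq_j), embed[j + half] = sin(t * freq_j),
+    // with freq_j = max_period^(-j/half)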
+    int half = dim / 2;
+
+    for (int64_t i = 0; i < ne00; i++) {
+        float * embed_data = (float *)((char *)  dst->data +  i*nb1);
+        for (int64_t j = ith; j < half; j += nth) {
+            float timestep = ((float *)src0->data)[i];
+            float freq = (float)expf(-logf(max_period) * j / half);
+            float arg = timestep * freq;
+            embed_data[j] = cosf(arg);
+            embed_data[j + half] = sinf(arg);
+        }
+        if (dim % 2 != 0 && ith == 0) {
+            // the loop above fills indices [0, 2*half); for odd dim, zero the remaining
+            // element at index 2*half (writing embed_data[dim] can land out of bounds)
+            embed_data[2 * half] = 0.f;
+        }
+    }
+}
+
+static void ggml_compute_forward_timestep_embedding(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_timestep_embedding_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_argsort
+
+static void ggml_compute_forward_argsort_f32(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    GGML_ASSERT(nb0 == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t nr = ggml_nrows(src0);
+
+    enum ggml_sort_order order = (enum ggml_sort_order) ggml_get_op_params_i32(dst, 0);
+
+    for (int64_t i = ith; i < nr; i += nth) {
+        int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
+        const float * src_data = (float *)((char *) src0->data + i*nb01);
+
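+        // start from the identity permutation, then order the indices by the row's values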
+        for (int64_t j = 0; j < ne0; j++) {
+            dst_data[j] = j;
+        }
+
+        // qsort's comparator cannot capture src_data without global state, so use a simple O(n^2) exchange sort instead
+        for (int64_t j = 0; j < ne0; j++) {
+            for (int64_t k = j + 1; k < ne0; k++) {
+                if ((order == GGML_SORT_ORDER_ASC  && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
+                    (order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
+                    int32_t tmp = dst_data[j];
+                    dst_data[j] = dst_data[k];
+                    dst_data[k] = tmp;
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_argsort(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_argsort_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_flash_attn_ext
+
+static void ggml_compute_forward_flash_attn_ext_f16(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * q,
+        const struct ggml_tensor * k,
+        const struct ggml_tensor * v,
+        const struct ggml_tensor * mask,
+        struct ggml_tensor * dst) {
+
+    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
+    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
+    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t D = neq0;
+    const int64_t N = neq1;
+
+    GGML_ASSERT(ne0 == D);
+    GGML_ASSERT(ne2 == N);
+
+    // input tensor rows must be contiguous
+    GGML_ASSERT(nbq0 == ggml_type_size(q->type));
+    GGML_ASSERT(nbk0 == ggml_type_size(k->type));
+    GGML_ASSERT(nbv0 == ggml_type_size(v->type));
+
+    GGML_ASSERT(neq0 == D);
+    GGML_ASSERT(nek0 == D);
+    GGML_ASSERT(nev0 == D);
+
+    GGML_ASSERT(neq1 == N);
+    GGML_ASSERT(nev0 == D);
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    // broadcast factors
+    const int64_t rk2 = neq2/nek2;
+    const int64_t rk3 = neq3/nek3;
+
+    const int64_t rv2 = neq2/nev2;
+    const int64_t rv3 = neq3/nev3;
+
+    // parallelize by q rows; each row runs an online-softmax pass over all of K/V
+
+    // total rows in q
+    const int nr = neq1*neq2*neq3;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    float scale         = 1.0f;
+    float max_bias      = 0.0f;
+    float logit_softcap = 0.0f;
+
+    memcpy(&scale,         (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias,      (float *) dst->op_params + 1, sizeof(float));
+    memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
+
+    if (logit_softcap != 0) {
+        scale /= logit_softcap;
+    }
+
+    const uint32_t n_head      = neq2;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
+
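+    // ALiBi slope bases: head h gets slope m0^(h+1) if h < n_head_log2, else m1^(2*(h - n_head_log2) + 1)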
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
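+    // Q rows are converted to the type expected by K's vec_dot kernel, so quantized
+    // K caches can be used directly without dequantizing K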
+    enum ggml_type    const k_vec_dot_type = type_traits_cpu[k->type].vec_dot_type;
+    ggml_from_float_t const q_to_vec_dot   = type_traits_cpu[k_vec_dot_type].from_float;
+    ggml_vec_dot_t    const kq_vec_dot     = type_traits_cpu[k->type].vec_dot;
+    ggml_to_float_t   const v_to_float     = ggml_get_type_traits(v->type)->to_float;
+
+    GGML_ASSERT(q_to_vec_dot && "fattn: unsupported K-type");
+    GGML_ASSERT(v_to_float   && "fattn: unsupported V-type");
+
+    // loop over n_batch and n_head
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // q indices
+        const int iq3 = ir/(neq2*neq1);
+        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
+        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
+
+        const uint32_t h = iq2; // head index
+        const float slope = (max_bias > 0.0f) ? (h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1)) : 1.0f;
+
+        float S = 0.0f;      // sum
+        float M = -INFINITY; // maximum KQ value
+
+        float       * VKQ32 = (float       *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
+        float       * V32   =                 (VKQ32 + 1*D); // (temporary) FP32 V buffer
+        ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
+        ggml_fp16_t * Q_q   = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16
+
+        if (v->type == GGML_TYPE_F16) {
+            memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
+        } else {
+            memset(VKQ32, 0, D*sizeof(float));
+        }
+
+        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
+
+        // k indices
+        const int ik3 = iq3 / rk3;
+        const int ik2 = iq2 / rk2;
+
+        // v indices
+        const int iv3 = iq3 / rv3;
+        const int iv2 = iq2 / rv2;
+
+        const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
+        q_to_vec_dot(pq, Q_q, D);
+
+        // online softmax / attention
+        // loop over n_kv and n_head_kv
+        // ref: https://arxiv.org/pdf/2112.05682.pdf
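+        // running maximum M and running sum S: when a key produces a new maximum,
+        // the accumulator and the sum are rescaled by exp(M_old - M_new), so that
+        // VKQ/S equals softmax(QK^T)*V once the loop finishes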
+        for (int64_t ic = 0; ic < nek1; ++ic) {
+            const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
+            if (mv == -INFINITY) {
+                continue;
+            }
+
+            float s; // KQ value
+
+            const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
+            kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
+
+            s = s*scale; // scale KQ value
+
+            if (logit_softcap != 0.0f) {
+                s = logit_softcap*tanhf(s);
+            }
+
+            s += mv; // apply mask
+
+            const float Mold = M;
+
+            float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
+            float vs = 1.0f; // post-softmax KQ value, expf(s - M)
+
+            const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
+
+            if (v->type == GGML_TYPE_F16) {
+                if (s > M) {
+                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
+                    M = s;
+                    ms = expf(Mold - M);
+
+                    // V = V*expf(Mold - M)
+                    ggml_vec_scale_f16(D, VKQ16, ms);
+                } else {
+                    // no new maximum, ms == 1.0f, vs != 1.0f
+                    vs = expf(s - M);
+                }
+
+                // V += v*expf(s - M)
+                ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
+            } else {
+                if (s > M) {
+                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
+                    M = s;
+                    ms = expf(Mold - M);
+
+                    // V = V*expf(Mold - M)
+                    ggml_vec_scale_f32(D, VKQ32, ms);
+                } else {
+                    // no new maximum, ms == 1.0f, vs != 1.0f
+                    vs = expf(s - M);
+                }
+
+                v_to_float(v_data, V32, D);
+
+                // V += v*expf(s - M)
+                ggml_vec_mad_f32(D, VKQ32, V32, vs);
+            }
+
+            S = S*ms + vs; // scale and increment sum with partial sum
+        }
+
+        if (v->type == GGML_TYPE_F16) {
+            for (int64_t d = 0; d < D; ++d) {
+                VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
+            }
+        }
+
+        // V /= S
+        const float S_inv = 1.0f/S;
+        ggml_vec_scale_f32(D, VKQ32, S_inv);
+
+        // dst indices
+        const int i1 = iq1;
+        const int i2 = iq2;
+        const int i3 = iq3;
+
+        // original
+        //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
+
+        // permute(0, 2, 1, 3)
+        memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
+    }
+}
+
+static void ggml_compute_forward_flash_attn_ext(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * q,
+        const struct ggml_tensor * k,
+        const struct ggml_tensor * v,
+        const struct ggml_tensor * mask,
+        struct ggml_tensor * dst) {
+    switch (dst->op_params[3]) {
+        case GGML_PREC_DEFAULT:
+        case GGML_PREC_F32:
+            {
+                // uses F32 accumulators
+                ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_flash_attn_back
+
+static void ggml_compute_forward_flash_attn_back_f32(
+        const struct ggml_compute_params * params,
+        const bool masked,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * q = dst->src[0];
+    const struct ggml_tensor * k = dst->src[1];
+    const struct ggml_tensor * v = dst->src[2];
+    const struct ggml_tensor * d = dst->src[3];
+
+    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
+    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
+    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
+    GGML_TENSOR_LOCALS(int64_t, ned, d,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbd, d,   nb)
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t D = neq0;
+    const int64_t N = neq1;
+    const int64_t P = nek1 - N;
+    const int64_t M = P + N;
+
+    const int Mup  = ggml_up(M, GGML_SOFT_MAX_UNROLL);
+    const int mxDM = MAX(D, Mup);
+
+    // GGML_ASSERT(ne0 == D);
+    // GGML_ASSERT(ne1 == N);
+    GGML_ASSERT(P >= 0);
+
+    GGML_ASSERT(nbq0 == sizeof(float));
+    GGML_ASSERT(nbk0 == sizeof(float));
+    GGML_ASSERT(nbv0 == sizeof(float));
+
+    GGML_ASSERT(neq0 == D);
+    GGML_ASSERT(nek0 == D);
+    GGML_ASSERT(nev1 == D);
+    GGML_ASSERT(ned0 == D);
+
+    GGML_ASSERT(neq1 == N);
+    GGML_ASSERT(nek1 == N + P);
+    GGML_ASSERT(nev1 == D);
+    GGML_ASSERT(ned1 == N);
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    if (ith == 0) {
+        memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3);
+    }
+    ggml_barrier(params->threadpool);
+
+    const int64_t elem_q = ggml_nelements(q);
+    const int64_t elem_k = ggml_nelements(k);
+
+    enum ggml_type result_type = dst->type;
+    GGML_ASSERT(ggml_blck_size(result_type) == 1);
+    const size_t tsize = ggml_type_size(result_type);
+
+    const size_t offs_q = 0;
+    const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
+    const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
+
+    void * grad_q = (char *) dst->data;
+    void * grad_k = (char *) dst->data + offs_k;
+    void * grad_v = (char *) dst->data + offs_v;
+
+    const size_t nbgq1 = nb0*neq0;
+    const size_t nbgq2 = nb0*neq0*neq1;
+    const size_t nbgq3 = nb0*neq0*neq1*neq2;
+
+    const size_t nbgk1 = nb0*nek0;
+    const size_t nbgk2 = nb0*nek0*nek1;
+    const size_t nbgk3 = nb0*nek0*nek1*neq2;
+
+    const size_t nbgv1 = nb0*nev0;
+    const size_t nbgv2 = nb0*nev0*nev1;
+    const size_t nbgv3 = nb0*nev0*nev1*neq2;
+
+    // parallelize by k rows using ggml_vec_dot_f32
+
+    // total rows in k
+    const int nr = nek2*nek3;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    const float scale = 1.0f/sqrtf(D);
+
+    //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
+
+    // how often k2 (and v2) is repeated in q2
+    int nrep = neq2/nek2;
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // k indices
+        const int ik3 = ir/(nek2);
+        const int ik2 = ir - ik3*nek2;
+
+        const int iq3 = ik3;
+        const int id3 = ik3;
+        const int iv3 = ik3;
+        const int iv2 = ik2;
+
+        for (int irep = 0; irep < nrep; ++irep) {
+            const int iq2 = ik2 + irep*nek2;
+            const int id2 = iq2;
+
+            // (ik2 + irep*nek2) % nek2 == ik2
+            for (int iq1 = 0; iq1 < neq1; ++iq1) {
+                const int id1 = iq1;
+
+                // note: it is unclear whether the CACHE_LINE_SIZE_F32 padding is needed in both
+                // per-thread buffer strides, or only once, as the offset between S and SM
+                float * S  = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
+                float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
+
+                for (int i = M; i < Mup; ++i) {
+                    S[i] = -INFINITY;
+                }
+
+                const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
+                for (int64_t ic = 0; ic < masked_begin; ++ic) {
+                    // k indices
+                    const int ik1 = ic;
+
+                    // S indices
+                    const int i1 = ik1;
+
+                    ggml_vec_dot_f32(neq0,
+                            S + i1, 0,
+                            (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
+                            (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
+                }
+
+                // scale
+                ggml_vec_scale_f32(masked_begin, S, scale);
+
+                for (int64_t i = masked_begin; i < M; i++) {
+                    S[i] = -INFINITY;
+                }
+
+                // softmax
+                // exclude known -INF S[..] values from max and loop
+                // dont forget to set their SM values to zero
+                {
+                    float max = -INFINITY;
+                    ggml_vec_max_f32(masked_begin, &max, S);
+
+                    ggml_float sum = 0.0;
+                    {
+#ifdef GGML_SOFT_MAX_ACCELERATE
+                        max = -max;
+                        vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
+                        vvexpf(SM, SM, &Mup);
+                        ggml_vec_sum_f32(Mup, &sum, SM);
+#else
+                        sum = ggml_vec_soft_max_f32(Mup, SM, S, max);
+#endif
+                    }
+
+                    assert(sum > 0.0);
+
+                    sum = 1.0/sum;
+                    ggml_vec_scale_f32(masked_begin, SM, sum);
+
+                }
+
+                // step-by-step explanation
+                {
+                    // forward-process                    shape      grads from backward process
+                    // parallel_for ik2,ik3:
+                    //  for irep:
+                    //   iq2 = ik2 + irep*nek2
+                    //   k[:D,:M,:,:]                     [D,M,:,:]  grad[k][:D,:M,ik2,ik3]  += grad[kcur]
+                    //   q[:D,:N,:,:]                     [D,N,:,:]  grad[q][:D,iq1,iq2,iq3] += grad[qcur]
+                    //   v[:M,:D,:,:]                     [M,D,:,:]  grad[v][:M,:D,iv2,iv3]  += grad[vcur]
+                    //   for iq1:
+                    //    kcur   = k[:D,:M,ik2,ik3]       [D,M,1,1]  grad[kcur] = grad[S1].T @ qcur
+                    //    qcur   = q[:D,iq1,iq2,iq3]      [D,1,1,1]  grad[qcur] = grad[S1]   @ kcur
+                    //    vcur   = v[:M,:D,iv2,iv3]       [M,D,1,1]  grad[vcur] = grad[S5].T @ S4
+                    //    S0     = -Inf                   [D,1,1,1]
+                    //   ~S1[i]  = dot(kcur[:D,i], qcur)
+                    //    S1     = qcur @ kcur.T          [M,1,1,1]  grad[S1]   = grad[S2] * scale
+                    //    S2     = S1 * scale             [M,1,1,1]  grad[S2]   = diag_mask_zero(grad[S3], P)
+                    //    S3     = diag_mask_inf(S2, P)   [M,1,1,1]  grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
+                    //    S4     = softmax(S3)            [M,1,1,1]  grad[S4]   = grad[S5] @ vcur
+                    //   ~S5[i]  = dot(vcur[:,i], S4)
+                    //    S5     = S4 @ vcur.T            [D,1,1,1]  grad[S5]   = d[:D,id1,id2,id3]
+                    //   ~dst[i,iq1,iq2,iq3]  = S5[i]              ^
+                    //    dst[:D,iq1,iq2,iq3] = S5                 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,id1,id2,id3]
+                    // dst                               backward-/ grad[dst]                 = d
+                    //
+                    // output gradients with their dependencies:
+                    //
+                    // grad[kcur] = grad[S1].T @ qcur
+                    // grad[S1]   = diag_mask_zero(grad[S3], P) * scale
+                    // grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
+                    // grad[S4]   = grad[S5] @ vcur
+                    // grad[S4]   = d[:D,id1,id2,id3] @ vcur
+                    // grad[qcur] = grad[S1]   @ kcur
+                    // grad[vcur] = grad[S5].T @ S4
+                    // grad[vcur] = d[:D,id1,id2,id3].T @ S4
+                    //
+                    // in post-order:
+                    //
+                    // S1         = qcur @ kcur.T
+                    // S2         = S1 * scale
+                    // S3         = diag_mask_inf(S2, P)
+                    // S4         = softmax(S3)
+                    // grad[S4]   = d[:D,id1,id2,id3] @ vcur
+                    // grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
+                    // grad[S1]   = diag_mask_zero(grad[S3], P) * scale
+                    // grad[qcur] = grad[S1]   @ kcur
+                    // grad[kcur] = grad[S1].T @ qcur
+                    // grad[vcur] = d[:D,id1,id2,id3].T @ S4
+                    //
+                    // using less variables (SM=S4):
+                    //
+                    // S             = diag_mask_inf(qcur @ kcur.T * scale, P)
+                    // SM            = softmax(S)
+                    // S             = d[:D,iq1,iq2,iq3] @ vcur
+                    // dot_SM_gradSM = dot(SM, S)
+                    // S             = SM * (S - dot(SM, S))
+                    // S             = diag_mask_zero(S, P) * scale
+                    //
+                    // grad[q][:D,iq1,iq2,iq3] += S   @ kcur
+                    // grad[k][:D,:M,ik2,ik3]  += S.T @ qcur
+                    // grad[v][:M,:D,iv2,iv3]  += d[:D,id1,id2,id3].T @ SM
+                }
+
+                // S = gradSM = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
+                // S = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
+                // for ic:
+                //   S[:M] += vcur[:M,ic,iv2,iv3] * d[ic,id1,id2,id3]
+                // exclude known future zero S[..] values from operation
+                ggml_vec_set_f32(masked_begin, S, 0);
+                for (int64_t ic = 0; ic < D; ++ic) {
+                    ggml_vec_mad_f32(masked_begin,
+                            S,
+                             (float *) ((char *) v->data + (          ic*nbv1  + iv2*nbv2 + iv3*nbv3)),
+                            *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
+                }
+
+                // S = SM * (S - dot(SM, S))
+                float dot_SM_gradSM = 0;
+                ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, 0, SM, 0, S, 0, 1);
+                ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
+                ggml_vec_mul_f32 (masked_begin, S, S, SM);
+
+                // S = diag_mask_zero(S, P) * scale
+                // already done by above ggml_vec_set_f32
+
+                // exclude known zero S[..] values from operation
+                ggml_vec_scale_f32(masked_begin, S, scale);
+
+                // S    shape [M,1]
+                // SM   shape [M,1]
+                // kcur shape [D,M]
+                // qcur shape [D,1]
+                // vcur shape [M,D]
+
+                // grad[q][:D,iq1,iq2,iq3] += S @ kcur
+                // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
+                // for ic:
+                //  grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic,ik2,ik3]
+                // exclude known zero S[..] values from loop
+                for (int64_t ic = 0; ic < masked_begin; ++ic) {
+                    ggml_vec_mad_f32(D,
+                            (float *) ((char *) grad_q  + (iq1*nbgq1 + iq2*nbgq2  + iq3*nbgq3)),
+                            (float *) ((char *) k->data + (ic*nbk1   + ik2*nbk2   + ik3*nbk3)),
+                            S[ic]);
+                }
+
+                // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
+                // for ic:
+                //  grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
+                //  grad[k][:D,ic,iq2,iq3] += S[ic]     * qcur[:D,0]
+                // exclude known zero S[..] values from loop
+                for (int64_t ic = 0; ic < masked_begin; ++ic) {
+                    ggml_vec_mad_f32(D,
+                            (float *) ((char *) grad_k  + (ic*nbgk1  + ik2*nbgk2  + ik3*nbgk3)),
+                            (float *) ((char *) q->data + (iq1*nbq1  + iq2*nbq2   + iq3*nbq3)),
+                            S[ic]);
+                }
+
+                // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T       @ SM
+                // for ic:
+                //  grad[v][:M,ic,iv2,iv3] += d[:D,id1,id2,id3].T[0,ic] * SM[:M]
+                //  grad[v][:M,ic,iv2,iv3] += d[ic,id1,id2,id3]         * SM[:M]
+                // exclude known zero SM[..] values from mad
+                for (int64_t ic = 0; ic < D; ++ic) {
+                    ggml_vec_mad_f32(masked_begin,
+                            (float *) ((char *) grad_v   + (          ic*nbgv1 + iv2*nbgv2 + iv3*nbgv3)),
+                            SM,
+                            *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2  + id3*nbd3)));
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_flash_attn_back(
+        const struct ggml_compute_params * params,
+        const bool masked,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * q = dst->src[0];
+
+    switch (q->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_flash_attn_back_f32(params, masked, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_ssm_conv
+
+static void ggml_compute_forward_ssm_conv_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    const struct ggml_tensor * src0 = dst->src[0]; // conv_x
+    const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc  = src1->ne[0]; // d_conv
+    const int ncs = src0->ne[0]; // d_conv - 1 + n_t
+    const int nr  = src0->ne[1]; // d_inner
+    const int n_t =  dst->ne[1]; // tokens per sequence
+    const int n_s =  dst->ne[2]; // number of sequences in the batch
+
+    GGML_ASSERT( dst->ne[0] == nr);
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
+    GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+    const int ir  = ir1 - ir0;
+
+    for (int i3 = 0; i3 < n_s; ++i3) {
+        for (int i2 = 0; i2 < n_t; ++i2) {
+            // {d_conv - 1 + n_t, d_inner, n_seqs}
+            // sliding window
+            const float * s = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s}
+            const float * c = (const float *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner}
+            float * x = (float *) ((char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s}
+
+            // TODO: transpose the output for smaller strides for big batches?
+            // d_inner
+            for (int i1 = 0; i1 < ir; ++i1) {
+                // rowwise dot product
+                // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
+                float sumf = 0.0f;
+
+                // d_conv
+                for (int i0 = 0; i0 < nc; ++i0) {
+                    sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
+                }
+                x[i1] = sumf;
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_ssm_conv(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    switch (dst->src[0]->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_ssm_conv_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_ssm_scan
+
+static void ggml_compute_forward_ssm_scan_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    const struct ggml_tensor * src0 = dst->src[0]; // s
+    const struct ggml_tensor * src1 = dst->src[1]; // x
+    const struct ggml_tensor * src2 = dst->src[2]; // dt
+    const struct ggml_tensor * src3 = dst->src[3]; // A
+    const struct ggml_tensor * src4 = dst->src[4]; // B
+    const struct ggml_tensor * src5 = dst->src[5]; // C
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t nc  = src0->ne[0]; // d_state
+    const int64_t nr  = src0->ne[1]; // d_inner
+    const int64_t n_t = src1->ne[1]; // number of tokens per sequence
+    const int64_t n_s = src0->ne[2]; // number of sequences in the batch
+
+    GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
+    GGML_ASSERT(src2->nb[0] == sizeof(float));
+    GGML_ASSERT(src3->nb[0] == sizeof(float));
+    GGML_ASSERT(src4->nb[0] == sizeof(float));
+    GGML_ASSERT(src5->nb[0] == sizeof(float));
+    // required for the dot product between s and C
+    GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
+    // required for per-sequence offsets for states
+    GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
+    // required to get correct offset for state destination (i.e. src1->nb[3])
+    GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+    const int ir  = ir1 - ir0;
+
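+    // selective state-space (Mamba) recurrence, with dt passed through softplus:
+    //   h_t = exp(dt * A) * h_{t-1} + (dt * x_t) * B_t
+    //   y_t = dot(C_t, h_t)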
+    for (int i3 = 0; i3 < n_s; ++i3) {
+        for (int i2 = 0; i2 < n_t; ++i2) {
+            const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
+            const float * x  = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
+            const float * A  = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
+            const float * B  = (const float *) ((const char *) src4->data +  i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
+            const float * C  = (const float *) ((const char *) src5->data +  i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
+                  float * y  = (      float *) ((      char *) dst->data  + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+                  float * s  = (      float *) ((      char *) dst->data  + ir0*(src0->nb[1]) + i3*(src0->nb[2]) +     src1->nb[3]);  // {d_state, d_inner, n_s}
+
+            // use the output as the source for the next token-wise iterations
+            if (i2 > 0) { s0 = s; }
+
+            // d_inner
+            for (int i1 = 0; i1 < ir; ++i1) {
+                // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
+                float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
+                float x_dt = x[i1] * dt_soft_plus;
+                float sumf = 0.0f;
+                // d_state
+                for (int i0 = 0; i0 < nc; ++i0) {
+                    int i = i0 + i1*nc;
+                    // state = prev_state * dA + dB * x
+                    float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+                    // y = rowwise_dotprod(state, C)
+                    sumf += state * C[i0];
+                    s[i] = state;
+                }
+                y[i1] = sumf;
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_ssm_scan(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    switch (dst->src[0]->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_ssm_scan_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_win_part
+
+static void ggml_compute_forward_win_part_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    UNUSED(params);
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
+
+    const int32_t nep0 = ((const int32_t *)(dst->op_params))[0];
+    const int32_t nep1 = ((const int32_t *)(dst->op_params))[1];
+    const int32_t w    = ((const int32_t *)(dst->op_params))[2];
+
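+    // partition src0's spatial grid into nep0 x nep1 non-overlapping w x w windows,
+    // zero-filling positions that fall past the input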
+    assert(ne00 == ne0);
+    assert(ne3  == nep0*nep1);
+
+    // TODO: optimize / multi-thread
+    for (int py = 0; py < nep1; ++py) {
+        for (int px = 0; px < nep0; ++px) {
+            const int64_t i3 = py*nep0 + px;
+            for (int64_t i2 = 0; i2 < ne2; ++i2) {
+                for (int64_t i1 = 0; i1 < ne1; ++i1) {
+                    for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                        const int64_t i02 = py*w + i2;
+                        const int64_t i01 = px*w + i1;
+                        const int64_t i00 = i0;
+
+                        const int64_t i = i3*ne2*ne1*ne0 + i2*ne1*ne0    + i1*ne0   + i0;
+                        const int64_t j =                  i02*ne01*ne00 + i01*ne00 + i00;
+
+                        if (py*w + i2 >= ne02 || px*w + i1 >= ne01) {
+                            ((float *) dst->data)[i] = 0.0f;
+                        } else {
+                            ((float *) dst->data)[i] = ((float *) src0->data)[j];
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_win_part(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_win_part_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_win_unpart
+
+static void ggml_compute_forward_win_unpart_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    UNUSED(params);
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
+
+    const int32_t w = ((const int32_t *)(dst->op_params))[0];
+
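+    // inverse of win_part: scatter the w x w windows back into a contiguous tensor, dropping the padding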
+    // padding
+    const int px = (w - ne1%w)%w;
+    //const int py = (w - ne2%w)%w;
+
+    const int npx = (px + ne1)/w;
+    //const int npy = (py + ne2)/w;
+
+    assert(ne0 == ne00);
+
+    // TODO: optimize / multi-thread
+    for (int64_t i2 = 0; i2 < ne2; ++i2) {
+        for (int64_t i1 = 0; i1 < ne1; ++i1) {
+            for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                const int ip2 = i2/w;
+                const int ip1 = i1/w;
+
+                const int64_t i02 = i2%w;
+                const int64_t i01 = i1%w;
+                const int64_t i00 = i0;
+
+                const int64_t i = (ip2*npx + ip1)*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00 + i00;
+                const int64_t j =                                  i2*ne1*ne0    + i1*ne0   + i0;
+
+                ((float *) dst->data)[j] = ((float *) src0->data)[i];
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_win_unpart(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_win_unpart_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_unary
+
+static void ggml_compute_forward_unary(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const enum ggml_unary_op op = ggml_get_unary_op(dst);
+
+    switch (op) {
+        case GGML_UNARY_OP_ABS:
+            {
+                ggml_compute_forward_abs(params, dst);
+            } break;
+        case GGML_UNARY_OP_SGN:
+            {
+                ggml_compute_forward_sgn(params, dst);
+            } break;
+        case GGML_UNARY_OP_NEG:
+            {
+                ggml_compute_forward_neg(params, dst);
+            } break;
+        case GGML_UNARY_OP_STEP:
+            {
+                ggml_compute_forward_step(params, dst);
+            } break;
+        case GGML_UNARY_OP_TANH:
+            {
+                ggml_compute_forward_tanh(params, dst);
+            } break;
+        case GGML_UNARY_OP_ELU:
+            {
+                ggml_compute_forward_elu(params, dst);
+            } break;
+        case GGML_UNARY_OP_RELU:
+            {
+                ggml_compute_forward_relu(params, dst);
+            } break;
+        case GGML_UNARY_OP_SIGMOID:
+            {
+                ggml_compute_forward_sigmoid(params, dst);
+            } break;
+        case GGML_UNARY_OP_GELU:
+            {
+                ggml_compute_forward_gelu(params, dst);
+            } break;
+        case GGML_UNARY_OP_GELU_QUICK:
+            {
+                ggml_compute_forward_gelu_quick(params, dst);
+            } break;
+        case GGML_UNARY_OP_SILU:
+            {
+                ggml_compute_forward_silu(params, dst);
+            } break;
+        case GGML_UNARY_OP_HARDSWISH:
+            {
+                ggml_compute_forward_hardswish(params, dst);
+            } break;
+        case GGML_UNARY_OP_HARDSIGMOID:
+            {
+                ggml_compute_forward_hardsigmoid(params, dst);
+            } break;
+        case GGML_UNARY_OP_EXP:
+            {
+                ggml_compute_forward_exp(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_get_rel_pos
+
+static void ggml_compute_forward_get_rel_pos_f16(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    UNUSED(params);
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    const int64_t w = ne1;
+
+    ggml_fp16_t * src0_data = (ggml_fp16_t *) src0->data;
+    ggml_fp16_t * dst_data  = (ggml_fp16_t *) dst->data;
+
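+    // each dst row (i2, i1) copies row (w - 1 - i1 + i2) of src0: the embedding
+    // for the relative offset between positions i2 and i1, shifted so the
+    // index is non-negative (see the SAM reference above)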
+    for (int64_t i2 = 0; i2 < ne2; ++i2) {
+        for (int64_t i1 = 0; i1 < ne1; ++i1) {
+            const int64_t pos = (w - i1 - 1) + i2;
+            for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0];
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_get_rel_pos(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+        case GGML_TYPE_BF16:
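+            // BF16 shares the F16 path: the kernel only copies 16-bit values,
+            // it never converts them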
+            {
+                ggml_compute_forward_get_rel_pos_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_add_rel_pos
+
+static void ggml_compute_forward_add_rel_pos_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+    const struct ggml_tensor * src2 = dst->src[2];
+
+    const bool inplace = (bool) ((int32_t *) dst->op_params)[0];
+    if (!inplace) {
+        if (params->ith == 0) {
+            memcpy((char *) dst->data, (char *) src0->data, ggml_nbytes(dst));
+        }
+        ggml_barrier(params->threadpool);
+    }
+    // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L357-L359
+
+    float * src1_data = (float *) src1->data;
+    float * src2_data = (float *) src2->data;
+    float * dst_data  = (float *) dst->data;
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    // total patches in dst
+    const int np = ne13;
+
+    // patches per thread
+    const int dp = (np + nth - 1)/nth;
+
+    // patch range for this thread
+    const int ip0 = dp*ith;
+    const int ip1 = MIN(ip0 + dp, np);
+
+    for (int64_t i13 = ip0; i13 < ip1; ++i13) {
+        for (int64_t i12 = 0; i12 < ne12; ++i12) {
+            for (int64_t i11 = 0; i11 < ne11; ++i11) {
+                const int64_t jp1 = i13*ne12*ne11*ne10 + i12*ne11*ne10 + i11*ne10;
+                for (int64_t i10 = 0; i10 < ne10; ++i10) {
+                    const int64_t jp0  = jp1 + i10;
+                    const float src1_e = src1_data[jp0];
+                    const float src2_e = src2_data[jp0];
+
+                    const int64_t jdh = jp0 * ne10;
+                    const int64_t jdw = jdh - (ne10 - 1) * i10;
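+                    // jdh + j walks a row of the ne10 x ne10 block while
+                    // jdw + j*ne10 walks a column; src2 is broadcast along the
+                    // row and src1 down the column (presumably the height and
+                    // width terms, per the SAM reference above)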
+
+                    for (int64_t j = 0; j < ne10; ++j) {
+                        dst_data[jdh + j     ] += src2_e;
+                        dst_data[jdw + j*ne10] += src1_e;
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_add_rel_pos(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_add_rel_pos_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_rwkv_wkv6
+
+static void ggml_compute_forward_rwkv_wkv6_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    const int64_t T = dst->src[1]->ne[3];
+    const int64_t C = dst->ne[0];
+    const int64_t HEADS = dst->src[1]->ne[2];
+    const int64_t n_seqs = dst->src[5]->ne[1];
+    const int64_t head_size = C / HEADS;
+
+    float * dst_data = (float *) dst->data;
+    float * state = ((float *) dst->data) + C * T;
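+    // dst packs the T x C token outputs first, followed by one
+    // head_size x C recurrent state block per sequence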
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    if (ith >= HEADS) {
+        return;
+    }
+
+    const int h_start = (HEADS * ith) / nth;
+    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
+                (HEADS * (ith + 1)) / nth : HEADS;
+
+    float * k =          (float *) dst->src[0]->data;
+    float * v =          (float *) dst->src[1]->data;
+    float * r =          (float *) dst->src[2]->data;
+    float * time_faaaa = (float *) dst->src[3]->data;
+    float * time_decay = (float *) dst->src[4]->data;
+
+    size_t t_stride = HEADS * head_size; // same as C
+
+    size_t h_stride = C / HEADS;
+    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
+    size_t h_stride_2d = head_size * head_size;
+
+    if (ith == 0) {
+        memset(dst_data, 0, T * C * sizeof(float));
+    }
+    ggml_barrier(params->threadpool);
+
+    #if defined(__AVX__) && !defined(__AVX512F__)
+        #define GGML_F32X GGML_F32x8
+        #define GGML_F32X_SET1 GGML_F32x8_SET1
+        #define GGML_F32X_LOAD GGML_F32x8_LOAD
+        #define GGML_F32X_STORE GGML_F32x8_STORE
+        #define GGML_F32X_MUL GGML_F32x8_MUL
+        #define GGML_F32X_FMA GGML_F32x8_FMA
+        #define WKV_VECTOR_SIZE 8
+    #elif defined(__AVX512F__)
+        #define GGML_F32X GGML_F32x16
+        #define GGML_F32X_SET1 GGML_F32x16_SET1
+        #define GGML_F32X_LOAD GGML_F32x16_LOAD
+        #define GGML_F32X_STORE GGML_F32x16_STORE
+        #define GGML_F32X_MUL GGML_F32x16_MUL
+        #define GGML_F32X_FMA GGML_F32x16_FMA
+        #define WKV_VECTOR_SIZE 16
+    #elif defined(__ARM_NEON) && defined(__aarch64__)
+        #define GGML_F32X GGML_F32x4
+        #define GGML_F32X_SET1 GGML_F32x4_SET1
+        #define GGML_F32X_LOAD GGML_F32x4_LOAD
+        #define GGML_F32X_STORE GGML_F32x4_STORE
+        #define GGML_F32X_MUL GGML_F32x4_MUL
+        #define GGML_F32X_FMA GGML_F32x4_FMA
+        #define WKV_VECTOR_SIZE 4
+    #endif
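+    // note: GGML_F32X_FMA(a, b, c) evaluates a + b*c; the updates below rely on
+    // this, e.g. "dst += temp * r" becomes GGML_F32X_FMA(dst_vec, temp_vec, r_vec)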
+
+    #ifdef WKV_VECTOR_SIZE
+        const int64_t vec_count = head_size / WKV_VECTOR_SIZE;
+
+        for (int64_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
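+            // the first token of each sequence reads its initial state from
+            // src[5]; subsequent tokens reuse state_cur, which is updated in place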
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_i_offset = h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float k_val = k[t_h_i_offset];
+                    float r_val = r[t_h_i_offset];
+                    float time_faaaa_val = time_faaaa[h_i_offset];
+                    float time_decay_val = time_decay[t_h_i_offset];
+
+                    // Broadcast scalar values to vectors
+                    GGML_F32X k_vec = GGML_F32X_SET1(k_val);
+                    GGML_F32X r_vec = GGML_F32X_SET1(r_val);
+                    GGML_F32X time_faaaa_vec = GGML_F32X_SET1(time_faaaa_val);
+                    GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val);
+
+                    for (int64_t j = 0; j < vec_count; j++) {
+                        size_t base_j = j * WKV_VECTOR_SIZE;
+                        size_t t_h_j_offset = t_h_offset + base_j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
+
+                        // Load WKV_VECTOR_SIZE elements at once
+                        GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
+                        GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
+                        GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
+
+                        // Compute kv = v * k
+                        GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
+
+                        // Compute temp = kv * time_faaaa + prev_state
+                        GGML_F32X temp_vec = GGML_F32X_FMA(prev_state_vec, kv_vec, time_faaaa_vec);
+
+                        // Update dst: dst += temp * r
+                        dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, r_vec);
+                        GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
+
+                        // Update state: state = prev_state * time_decay + kv
+                        GGML_F32X new_state_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, time_decay_vec);
+                        GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], new_state_vec);
+                    }
+
+                    // scalar tail; never taken when head_size is a multiple of WKV_VECTOR_SIZE
+                    for (int64_t j = vec_count * WKV_VECTOR_SIZE; j < head_size; j++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = kv_val * time_faaaa_val + prev_state_val;
+                        dst_data[t_h_j_offset] += temp_val * r_val;
+                        state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
+                    }
+                }
+            }
+        }
+
+    #else
+        // basically fused operations:
+        // dst = r @ (time_faaaa * (k @ v) + state),
+        // state = time_decay * state + (k @ v),
+        // recursive through each token
+        for (int64_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_i_offset = h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float k_val = k[t_h_i_offset];
+                    float r_val = r[t_h_i_offset];
+                    float time_faaaa_val = time_faaaa[h_i_offset];
+                    // RWKV v6: different time_decay for each token.
+                    float time_decay_val = time_decay[t_h_i_offset];
+
+                    for (int64_t j = 0; j < head_size; j++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = kv_val * time_faaaa_val + prev_state_val;
+                        dst_data[t_h_j_offset] += temp_val * r_val;
+                        state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
+                    }
+                }
+            }
+        }
+    #endif
+}
+
+static void ggml_compute_forward_rwkv_wkv6(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rwkv_wkv6_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_map_unary
+
+static void ggml_compute_forward_map_unary_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst,
+        const ggml_unary_op_f32_t fun) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    for (int i = 0; i < n; i++) {
+        fun(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_map_unary(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst,
+        const ggml_unary_op_f32_t fun) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_map_unary_f32(params, dst, fun);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_map_binary
+
+static void ggml_compute_forward_map_binary_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst,
+        const ggml_binary_op_f32_t fun) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(src1));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    for (int i = 0; i < n; i++) {
+        fun(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])),
+                (float *) ((char *) src1->data + i*(src1->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_map_binary(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst,
+        const ggml_binary_op_f32_t fun) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_map_binary_f32(params, dst, fun);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_map_custom1
+
+static void ggml_compute_forward_map_custom1_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst,
+        const ggml_custom1_op_f32_t fun) {
+
+    const struct ggml_tensor * a = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    fun(dst, a);
+}
+
+// ggml_compute_forward_map_custom2
+
+static void ggml_compute_forward_map_custom2_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst,
+        const ggml_custom2_op_f32_t fun) {
+
+    const struct ggml_tensor * a = dst->src[0];
+    const struct ggml_tensor * b = dst->src[1];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    fun(dst, a, b);
+}
+
+// ggml_compute_forward_map_custom3
+
+static void ggml_compute_forward_map_custom3_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst,
+        const ggml_custom3_op_f32_t fun) {
+
+    const struct ggml_tensor * a = dst->src[0];
+    const struct ggml_tensor * b = dst->src[1];
+    const struct ggml_tensor * c = dst->src[2];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    fun(dst, a, b, c);
+}
+
+// ggml_compute_forward_map_custom1
+
+static void ggml_compute_forward_map_custom1(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * a = dst->src[0];
+
+    struct ggml_map_custom1_op_params p;
+    memcpy(&p, dst->op_params, sizeof(p));
+
+    p.fun(dst, a, params->ith, params->nth, p.userdata);
+}
+
+// ggml_compute_forward_map_custom2
+
+static void ggml_compute_forward_map_custom2(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * a = dst->src[0];
+    const struct ggml_tensor * b = dst->src[1];
+
+    struct ggml_map_custom2_op_params p;
+    memcpy(&p, dst->op_params, sizeof(p));
+
+    p.fun(dst, a, b, params->ith, params->nth, p.userdata);
+}
+
+// ggml_compute_forward_map_custom3
+
+static void ggml_compute_forward_map_custom3(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * a = dst->src[0];
+    const struct ggml_tensor * b = dst->src[1];
+    const struct ggml_tensor * c = dst->src[2];
+
+    struct ggml_map_custom3_op_params p;
+    memcpy(&p, dst->op_params, sizeof(p));
+
+    p.fun(dst, a, b, c, params->ith, params->nth, p.userdata);
+}
+
+// ggml_compute_forward_cross_entropy_loss
+
+static void ggml_compute_forward_cross_entropy_loss_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
+    GGML_ASSERT(ggml_are_same_shape(src0, src1));
+    GGML_ASSERT(ggml_is_scalar(dst));
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    // TODO: handle transposed/permuted matrices
+    const int64_t nc = src0->ne[0];
+    const int64_t nr = ggml_nrows(src0);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    float * sums =  (float *) params->wdata;
+    float * st   = ((float *) params->wdata) + nth + ith*nc;
+    float sum_thread = 0.0f;
+
+    GGML_ASSERT(params->wsize >= sizeof(float) * (nth + nth * nc));
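+    // wdata layout: nth per-thread partial sums, followed by one nc-float
+    // scratch row per thread (st points at this thread's row)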
+
+    // rows per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
+
+    for (int64_t i1 = ir0; i1 < ir1; ++i1) {
+        const float * s0 = (const float *)((const char *) src0->data + i1*src0->nb[1]);
+        const float * s1 = (const float *)((const char *) src1->data + i1*src1->nb[1]);
+
+#ifndef NDEBUG
+        for (int64_t i = 0; i < nc; ++i) {
+            //printf("p[%d] = %f\n", i, p[i]);
+            assert(!isnan(s0[i]));
+            assert(!isnan(s1[i]));
+        }
+#endif
+
+        float max = -INFINITY;
+        ggml_vec_max_f32(nc, &max, s0);
+        const ggml_float sum_softmax = ggml_vec_log_soft_max_f32(nc, st, s0, max);
+        assert(sum_softmax >= 0.0);
+
+        ggml_vec_add1_f32(nc, st, st, -sum_softmax);
+        ggml_vec_mul_f32(nc, st, st, s1);
+
+        float sum_st = 0.0f;
+        ggml_vec_sum_f32(nc, &sum_st, st);
+        sum_thread += sum_st;
+
+#ifndef NDEBUG
+        for (int64_t i = 0; i < nc; ++i) {
+            assert(!isnan(st[i]));
+            assert(!isinf(st[i]));
+        }
+#endif
+    }
+    sums[ith] = sum_thread;
+    ggml_barrier(params->threadpool);
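+    // thread 0 reduces the per-thread partial sums into the final scalar:
+    // the mean negative log-likelihood over all nr rows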
+
+    if (ith == 0) {
+        float * dp = (float *) dst->data;
+        ggml_vec_sum_f32(nth, dp, sums);
+        dp[0] *= -1.0f / (float) nr;
+    }
+}
+
+static void ggml_compute_forward_cross_entropy_loss(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_cross_entropy_loss_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_cross_entropy_loss_back
+
+static void ggml_compute_forward_cross_entropy_loss_back_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+    const struct ggml_tensor * opt0 = dst->src[2];
+
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(opt0));
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    const int64_t ith = params->ith;
+    const int64_t nth = params->nth;
+
+    // TODO: handle transposed/permuted matrices
+    const int64_t nc = src0->ne[0];
+    const int64_t nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
+
+    const float d_by_nr = ((const float *) opt0->data)[0] / (float) nr;
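+    // opt0 holds the scalar upstream gradient of the loss; folding in 1/nr
+    // here undoes the mean taken in the forward pass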
+
+    for (int64_t i1 = ir0; i1 < ir1; i1++) {
+        float * ds0 = (float *)((char *) dst->data  + i1*dst->nb[1]);
+        float * s0  = (float *)((char *) src0->data + i1*src0->nb[1]);
+        float * s1  = (float *)((char *) src1->data + i1*src1->nb[1]);
+
+#ifndef NDEBUG
+        for (int64_t i = 0; i < nc; ++i) {
+            //printf("p[%d] = %f\n", i, p[i]);
+            assert(!isnan(s0[i]));
+            assert(!isnan(s1[i]));
+        }
+#endif
+
+        // soft_max
+        float max = -INFINITY;
+        ggml_vec_max_f32(nc, &max, s0);
+        ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
+        assert(sum > 0.0);
+        ggml_vec_scale_f32(nc, ds0, 1.0/sum);
+
+        // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
+        ggml_vec_sub_f32(nc, ds0, ds0, s1);
+        ggml_vec_scale_f32(nc, ds0, d_by_nr);
+
+#ifndef NDEBUG
+        for (int64_t i = 0; i < nc; ++i) {
+            assert(!isnan(ds0[i]));
+            assert(!isinf(ds0[i]));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_cross_entropy_loss_back(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_cross_entropy_loss_back_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+static void ggml_compute_forward_opt_step_adamw_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0         = dst->src[0];
+    const struct ggml_tensor * src0_grad    = dst->src[1];
+    const struct ggml_tensor * src0_grad_m  = dst->src[2];
+    const struct ggml_tensor * src0_grad_v  = dst->src[3];
+    const struct ggml_tensor * adamw_params = dst->src[4];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m));
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v));
+    GGML_ASSERT(ggml_nelements(adamw_params) == 7);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr  = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);
+    const float alpha  = adamw_params_ptr[0];
+    const float beta1  = adamw_params_ptr[1];
+    const float beta2  = adamw_params_ptr[2];
+    const float eps    = adamw_params_ptr[3];
+    const float wd     = adamw_params_ptr[4];
+    const float beta1h = adamw_params_ptr[5];
+    const float beta2h = adamw_params_ptr[6];
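+    // beta1h and beta2h are precomputed by the caller, presumably the
+    // bias-correction factors alpha/(1 - beta1^t) and 1/(1 - beta2^t)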
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir/(ne02*ne01);
+        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        const size_t offset = i03*nb03 + i02*nb02 + i01*nb01;
+
+        float       * w = (float       *) ((char       *) src0->data        + offset); // weight
+        const float * g = (const float *) ((const char *) src0_grad->data   + offset); // grad
+        float       * m = (float       *) ((char       *) src0_grad_m->data + offset);
+        float       * v = (float       *) ((char       *) src0_grad_v->data + offset);
+
+        for (int i00 = 0; i00 < ne00; ++i00) {
+            m[i00] = m[i00]*beta1 +        g[i00]*(1.0f - beta1);
+            v[i00] = v[i00]*beta2 + g[i00]*g[i00]*(1.0f - beta2);
+
+            const float mh =       m[i00]*beta1h;
+            const float vh = sqrtf(v[i00]*beta2h) + eps;
+
+            // The weight decay is applied independently of the Adam momenta m and v.
+            // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
+            // See: https://arxiv.org/pdf/1711.05101v3.pdf
+            w[i00] = w[i00]*(1.0f - alpha*wd) - alpha*mh/vh;
+        }
+    }
+}
+
+static void ggml_compute_forward_opt_step_adamw(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_opt_step_adamw_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+/////////////////////////////////
+
+static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
+    GGML_ASSERT(params);
+
+    if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) {
+        return;
+    }
+
+    switch (tensor->op) {
+        case GGML_OP_DUP:
+            {
+                ggml_compute_forward_dup(params, tensor);
+            } break;
+        case GGML_OP_ADD:
+            {
+                ggml_compute_forward_add(params, tensor);
+            } break;
+        case GGML_OP_ADD1:
+            {
+                ggml_compute_forward_add1(params, tensor);
+            } break;
+        case GGML_OP_ACC:
+            {
+                ggml_compute_forward_acc(params, tensor);
+            } break;
+        case GGML_OP_SUB:
+            {
+                ggml_compute_forward_sub(params, tensor);
+            } break;
+        case GGML_OP_MUL:
+            {
+                ggml_compute_forward_mul(params, tensor);
+            } break;
+        case GGML_OP_DIV:
+            {
+                ggml_compute_forward_div(params, tensor);
+            } break;
+        case GGML_OP_SQR:
+            {
+                ggml_compute_forward_sqr(params, tensor);
+            } break;
+        case GGML_OP_SQRT:
+            {
+                ggml_compute_forward_sqrt(params, tensor);
+            } break;
+        case GGML_OP_LOG:
+            {
+                ggml_compute_forward_log(params, tensor);
+            } break;
+        case GGML_OP_SIN:
+            {
+                ggml_compute_forward_sin(params, tensor);
+            } break;
+        case GGML_OP_COS:
+            {
+                ggml_compute_forward_cos(params, tensor);
+            } break;
+        case GGML_OP_SUM:
+            {
+                ggml_compute_forward_sum(params, tensor);
+            } break;
+        case GGML_OP_SUM_ROWS:
+            {
+                ggml_compute_forward_sum_rows(params, tensor);
+            } break;
+        case GGML_OP_MEAN:
+            {
+                ggml_compute_forward_mean(params, tensor);
+            } break;
+        case GGML_OP_ARGMAX:
+            {
+                ggml_compute_forward_argmax(params, tensor);
+            } break;
+        case GGML_OP_COUNT_EQUAL:
+            {
+                ggml_compute_forward_count_equal(params, tensor);
+            } break;
+        case GGML_OP_REPEAT:
+            {
+                ggml_compute_forward_repeat(params, tensor);
+            } break;
+        case GGML_OP_REPEAT_BACK:
+            {
+                ggml_compute_forward_repeat_back(params, tensor);
+            } break;
+        case GGML_OP_CONCAT:
+            {
+                ggml_compute_forward_concat(params, tensor);
+            } break;
+        case GGML_OP_SILU_BACK:
+            {
+                ggml_compute_forward_silu_back(params, tensor);
+            } break;
+        case GGML_OP_NORM:
+            {
+                ggml_compute_forward_norm(params, tensor);
+            } break;
+        case GGML_OP_RMS_NORM:
+            {
+                ggml_compute_forward_rms_norm(params, tensor);
+            } break;
+        case GGML_OP_RMS_NORM_BACK:
+            {
+                ggml_compute_forward_rms_norm_back(params, tensor);
+            } break;
+        case GGML_OP_GROUP_NORM:
+            {
+                ggml_compute_forward_group_norm(params, tensor);
+            } break;
+        case GGML_OP_MUL_MAT:
+            {
+                ggml_compute_forward_mul_mat(params, tensor);
+            } break;
+        case GGML_OP_MUL_MAT_ID:
+            {
+                ggml_compute_forward_mul_mat_id(params, tensor);
+            } break;
+        case GGML_OP_OUT_PROD:
+            {
+                ggml_compute_forward_out_prod(params, tensor);
+            } break;
+        case GGML_OP_SCALE:
+            {
+                ggml_compute_forward_scale(params, tensor);
+            } break;
+        case GGML_OP_SET:
+            {
+                ggml_compute_forward_set(params, tensor);
+            } break;
+        case GGML_OP_CPY:
+            {
+                ggml_compute_forward_cpy(params, tensor);
+            } break;
+        case GGML_OP_CONT:
+            {
+                ggml_compute_forward_cont(params, tensor);
+            } break;
+        case GGML_OP_RESHAPE:
+            {
+                ggml_compute_forward_reshape(params, tensor);
+            } break;
+        case GGML_OP_VIEW:
+            {
+                ggml_compute_forward_view(params, tensor);
+            } break;
+        case GGML_OP_PERMUTE:
+            {
+                ggml_compute_forward_permute(params, tensor);
+            } break;
+        case GGML_OP_TRANSPOSE:
+            {
+                ggml_compute_forward_transpose(params, tensor);
+            } break;
+        case GGML_OP_GET_ROWS:
+            {
+                ggml_compute_forward_get_rows(params, tensor);
+            } break;
+        case GGML_OP_GET_ROWS_BACK:
+            {
+                ggml_compute_forward_get_rows_back(params, tensor);
+            } break;
+        case GGML_OP_DIAG:
+            {
+                ggml_compute_forward_diag(params, tensor);
+            } break;
+        case GGML_OP_DIAG_MASK_INF:
+            {
+                ggml_compute_forward_diag_mask_inf(params, tensor);
+            } break;
+        case GGML_OP_DIAG_MASK_ZERO:
+            {
+                ggml_compute_forward_diag_mask_zero(params, tensor);
+            } break;
+        case GGML_OP_SOFT_MAX:
+            {
+                ggml_compute_forward_soft_max(params, tensor);
+            } break;
+        case GGML_OP_SOFT_MAX_BACK:
+            {
+                ggml_compute_forward_soft_max_back(params, tensor);
+            } break;
+        case GGML_OP_ROPE:
+            {
+                ggml_compute_forward_rope(params, tensor);
+            } break;
+        case GGML_OP_ROPE_BACK:
+            {
+                ggml_compute_forward_rope_back(params, tensor);
+            } break;
+        case GGML_OP_CLAMP:
+            {
+                ggml_compute_forward_clamp(params, tensor);
+            } break;
+        case GGML_OP_CONV_TRANSPOSE_1D:
+            {
+                ggml_compute_forward_conv_transpose_1d(params, tensor);
+            } break;
+        case GGML_OP_IM2COL:
+            {
+                ggml_compute_forward_im2col(params, tensor);
+            } break;
+        case GGML_OP_IM2COL_BACK:
+            {
+                ggml_compute_forward_im2col_back_f32(params, tensor);
+            } break;
+        case GGML_OP_CONV_TRANSPOSE_2D:
+            {
+                ggml_compute_forward_conv_transpose_2d(params, tensor);
+            } break;
+        case GGML_OP_POOL_1D:
+            {
+                ggml_compute_forward_pool_1d(params, tensor);
+            } break;
+        case GGML_OP_POOL_2D:
+            {
+                ggml_compute_forward_pool_2d(params, tensor);
+            } break;
+        case GGML_OP_POOL_2D_BACK:
+            {
+                ggml_compute_forward_pool_2d_back(params, tensor);
+            } break;
+        case GGML_OP_UPSCALE:
+            {
+                ggml_compute_forward_upscale(params, tensor);
+            } break;
+        case GGML_OP_PAD:
+            {
+                ggml_compute_forward_pad(params, tensor);
+            } break;
+        case GGML_OP_UNPAD:
+            {
+                ggml_compute_forward_unpad(params, tensor);
+            } break;
+        case GGML_OP_ARANGE:
+            {
+                ggml_compute_forward_arange(params, tensor);
+            } break;
+        case GGML_OP_TIMESTEP_EMBEDDING:
+            {
+                ggml_compute_forward_timestep_embedding(params, tensor);
+            } break;
+        case GGML_OP_ARGSORT:
+            {
+                ggml_compute_forward_argsort(params, tensor);
+            } break;
+        case GGML_OP_LEAKY_RELU:
+            {
+                ggml_compute_forward_leaky_relu(params, tensor);
+            } break;
+        case GGML_OP_FLASH_ATTN_EXT:
+            {
+                ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
+            } break;
+        case GGML_OP_FLASH_ATTN_BACK:
+            {
+                int32_t t = ggml_get_op_params_i32(tensor, 0);
+                GGML_ASSERT(t == 0 || t == 1);
+                bool masked = t != 0;
+                ggml_compute_forward_flash_attn_back(params, masked, tensor);
+            } break;
+        case GGML_OP_SSM_CONV:
+            {
+                ggml_compute_forward_ssm_conv(params, tensor);
+            } break;
+        case GGML_OP_SSM_SCAN:
+            {
+                ggml_compute_forward_ssm_scan(params, tensor);
+            } break;
+        case GGML_OP_WIN_PART:
+            {
+                ggml_compute_forward_win_part(params, tensor);
+            } break;
+        case GGML_OP_WIN_UNPART:
+            {
+                ggml_compute_forward_win_unpart(params, tensor);
+            } break;
+        case GGML_OP_UNARY:
+            {
+                ggml_compute_forward_unary(params, tensor);
+            } break;
+        case GGML_OP_GET_REL_POS:
+            {
+                ggml_compute_forward_get_rel_pos(params, tensor);
+            } break;
+        case GGML_OP_ADD_REL_POS:
+            {
+                ggml_compute_forward_add_rel_pos(params, tensor);
+            } break;
+        case GGML_OP_RWKV_WKV6:
+            {
+                ggml_compute_forward_rwkv_wkv6(params, tensor);
+            } break;
+        case GGML_OP_MAP_UNARY:
+            {
+                ggml_unary_op_f32_t fun;
+                memcpy(&fun, tensor->op_params, sizeof(fun));
+                ggml_compute_forward_map_unary(params, tensor, fun);
+            }
+            break;
+        case GGML_OP_MAP_BINARY:
+            {
+                ggml_binary_op_f32_t fun;
+                memcpy(&fun, tensor->op_params, sizeof(fun));
+                ggml_compute_forward_map_binary(params, tensor, fun);
+            }
+            break;
+        case GGML_OP_MAP_CUSTOM1_F32:
+            {
+                ggml_custom1_op_f32_t fun;
+                memcpy(&fun, tensor->op_params, sizeof(fun));
+                ggml_compute_forward_map_custom1_f32(params, tensor, fun);
+            }
+            break;
+        case GGML_OP_MAP_CUSTOM2_F32:
+            {
+                ggml_custom2_op_f32_t fun;
+                memcpy(&fun, tensor->op_params, sizeof(fun));
+                ggml_compute_forward_map_custom2_f32(params, tensor, fun);
+            }
+            break;
+        case GGML_OP_MAP_CUSTOM3_F32:
+            {
+                ggml_custom3_op_f32_t fun;
+                memcpy(&fun, tensor->op_params, sizeof(fun));
+                ggml_compute_forward_map_custom3_f32(params, tensor, fun);
+            }
+            break;
+        case GGML_OP_MAP_CUSTOM1:
+            {
+                ggml_compute_forward_map_custom1(params, tensor);
+            }
+            break;
+        case GGML_OP_MAP_CUSTOM2:
+            {
+                ggml_compute_forward_map_custom2(params, tensor);
+            }
+            break;
+        case GGML_OP_MAP_CUSTOM3:
+            {
+                ggml_compute_forward_map_custom3(params, tensor);
+            }
+            break;
+        case GGML_OP_CROSS_ENTROPY_LOSS:
+            {
+                ggml_compute_forward_cross_entropy_loss(params, tensor);
+            }
+            break;
+        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+            {
+                ggml_compute_forward_cross_entropy_loss_back(params, tensor);
+            }
+            break;
+        case GGML_OP_OPT_STEP_ADAMW:
+            {
+                ggml_compute_forward_opt_step_adamw(params, tensor);
+            }
+            break;
+        case GGML_OP_NONE:
+            {
+                // nop
+            } break;
+        case GGML_OP_COUNT:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// Android's libc implementation "bionic" does not support setting affinity
+#if defined(__gnu_linux__)
+static void set_numa_thread_affinity(int thread_n) {
+    if (!ggml_is_numa()) {
+        return;
+    }
+
+    int node_num;
+    int rv;
+    size_t setsize = CPU_ALLOC_SIZE(g_state.numa.total_cpus);
+
+    switch(g_state.numa.numa_strategy) {
+        case GGML_NUMA_STRATEGY_DISTRIBUTE:
+            // run thread on node_num thread_n / (threads per node)
+            node_num = thread_n % g_state.numa.n_nodes;
+            break;
+        case GGML_NUMA_STRATEGY_ISOLATE:
+            // run thread on current_node
+            node_num = g_state.numa.current_node;
+            break;
+        case GGML_NUMA_STRATEGY_NUMACTL:
+            // use the cpuset that numactl gave us
+            rv = pthread_setaffinity_np(pthread_self(), setsize, &g_state.numa.cpuset);
+            if (rv) {
+                fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n",strerror(rv));
+            }
+            return;
+        default:
+            return;
+    }
+
+    struct ggml_numa_node * node = &g_state.numa.nodes[node_num];
+
+    cpu_set_t * cpus = CPU_ALLOC(g_state.numa.total_cpus);
+    CPU_ZERO_S(setsize, cpus);
+    for (size_t i = 0; i < node->n_cpus; ++i) {
+        CPU_SET_S(node->cpus[i], setsize, cpus);
+    }
+
+    rv = pthread_setaffinity_np(pthread_self(), setsize, cpus);
+    if (rv) {
+        fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n", strerror(rv));
+    }
+
+    CPU_FREE(cpus);
+}
+
+static void clear_numa_thread_affinity(void) {
+    if (!ggml_is_numa()) {
+        return;
+    }
+
+    size_t setsize = CPU_ALLOC_SIZE(g_state.numa.total_cpus);
+
+    cpu_set_t * cpus = CPU_ALLOC(g_state.numa.total_cpus);
+    CPU_ZERO_S(setsize, cpus);
+    for (unsigned i = 0; i < g_state.numa.total_cpus; ++i) {
+        CPU_SET_S(i, setsize, cpus);
+    }
+
+    int rv = pthread_setaffinity_np(pthread_self(), setsize, cpus);
+    if (rv) {
+        fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n", strerror(rv));
+    }
+
+    CPU_FREE(cpus);
+}
+#else
+// TODO: Windows etc.
+// (the linux implementation may also work on BSD, someone should test)
+static void set_numa_thread_affinity(int thread_n) { UNUSED(thread_n);  }
+static void clear_numa_thread_affinity(void) {}
+#endif
+
+static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
+    int n_tasks = 0;
+
+    if (ggml_is_empty(node)) {
+        // no need to multi-thread a no-op
+        n_tasks = 1;
+        return n_tasks;
+    }
+
+    switch (node->op) {
+        case GGML_OP_CPY:
+        case GGML_OP_DUP:
+        case GGML_OP_CONT:
+        case GGML_OP_ADD:
+        case GGML_OP_ADD1:
+        case GGML_OP_ACC:
+            {
+                n_tasks = n_threads;
+            } break;
+        case GGML_OP_SUB:
+        case GGML_OP_SQR:
+        case GGML_OP_SQRT:
+        case GGML_OP_LOG:
+        case GGML_OP_SIN:
+        case GGML_OP_COS:
+        case GGML_OP_SUM:
+        case GGML_OP_SUM_ROWS:
+        case GGML_OP_MEAN:
+        case GGML_OP_ARGMAX:
+            {
+                n_tasks = 1;
+            } break;
+        case GGML_OP_COUNT_EQUAL:
+            {
+                n_tasks = n_threads;
+            } break;
+        case GGML_OP_REPEAT:
+        case GGML_OP_REPEAT_BACK:
+        case GGML_OP_LEAKY_RELU:
+            {
+                n_tasks = 1;
+            } break;
+        case GGML_OP_UNARY:
+            switch (ggml_get_unary_op(node)) {
+                case GGML_UNARY_OP_ABS:
+                case GGML_UNARY_OP_SGN:
+                case GGML_UNARY_OP_NEG:
+                case GGML_UNARY_OP_STEP:
+                case GGML_UNARY_OP_TANH:
+                case GGML_UNARY_OP_ELU:
+                case GGML_UNARY_OP_RELU:
+                case GGML_UNARY_OP_SIGMOID:
+                case GGML_UNARY_OP_HARDSWISH:
+                case GGML_UNARY_OP_HARDSIGMOID:
+                case GGML_UNARY_OP_EXP:
+                    {
+                        n_tasks = 1;
+                    } break;
+
+                case GGML_UNARY_OP_GELU:
+                case GGML_UNARY_OP_GELU_QUICK:
+                case GGML_UNARY_OP_SILU:
+                    {
+                        n_tasks = n_threads;
+                    } break;
+                default:
+                    GGML_ABORT("fatal error");
+            }
+            break;
+        case GGML_OP_SILU_BACK:
+        case GGML_OP_MUL:
+        case GGML_OP_DIV:
+        case GGML_OP_NORM:
+        case GGML_OP_RMS_NORM:
+        case GGML_OP_RMS_NORM_BACK:
+        case GGML_OP_GROUP_NORM:
+        case GGML_OP_CONCAT:
+        case GGML_OP_MUL_MAT:
+        case GGML_OP_MUL_MAT_ID:
+        case GGML_OP_OUT_PROD:
+            {
+                n_tasks = n_threads;
+            } break;
+        case GGML_OP_GET_ROWS:
+            {
+                // FIXME: get_rows can use additional threads, but the cost of launching additional threads
+                // decreases performance with GPU offloading
+                //n_tasks = n_threads;
+                n_tasks = 1;
+            } break;
+        case GGML_OP_SCALE:
+        case GGML_OP_SET:
+        case GGML_OP_RESHAPE:
+        case GGML_OP_VIEW:
+        case GGML_OP_PERMUTE:
+        case GGML_OP_TRANSPOSE:
+        case GGML_OP_GET_ROWS_BACK:
+        case GGML_OP_DIAG:
+            {
+                n_tasks = 1;
+            } break;
+        case GGML_OP_DIAG_MASK_ZERO:
+        case GGML_OP_DIAG_MASK_INF:
+        case GGML_OP_SOFT_MAX_BACK:
+        case GGML_OP_ROPE:
+        case GGML_OP_ROPE_BACK:
+        case GGML_OP_ADD_REL_POS:
+            {
+                n_tasks = n_threads;
+            } break;
+        case GGML_OP_CLAMP:
+            {
+                n_tasks = 1; //TODO
+            } break;
+        case GGML_OP_SOFT_MAX:
+            {
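+                // parallelism is per row, so never use more threads than rows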
+                n_tasks = MIN(n_threads, ggml_nrows(node->src[0]));
+            } break;
+        case GGML_OP_IM2COL:
+        case GGML_OP_IM2COL_BACK:
+        case GGML_OP_CONV_TRANSPOSE_1D:
+        case GGML_OP_CONV_TRANSPOSE_2D:
+            {
+                n_tasks = n_threads;
+            } break;
+        case GGML_OP_POOL_1D:
+        case GGML_OP_POOL_2D:
+        case GGML_OP_POOL_2D_BACK:
+            {
+                n_tasks = 1;
+            } break;
+        case GGML_OP_UPSCALE:
+        case GGML_OP_PAD:
+        case GGML_OP_UNPAD:
+        case GGML_OP_ARANGE:
+        case GGML_OP_TIMESTEP_EMBEDDING:
+        case GGML_OP_ARGSORT:
+        case GGML_OP_FLASH_ATTN_EXT:
+        case GGML_OP_FLASH_ATTN_BACK:
+        case GGML_OP_SSM_CONV:
+        case GGML_OP_SSM_SCAN:
+            {
+                n_tasks = n_threads;
+            } break;
+        case GGML_OP_WIN_PART:
+        case GGML_OP_WIN_UNPART:
+        case GGML_OP_GET_REL_POS:
+        case GGML_OP_RWKV_WKV6:
+        case GGML_OP_MAP_UNARY:
+        case GGML_OP_MAP_BINARY:
+        case GGML_OP_MAP_CUSTOM1_F32:
+        case GGML_OP_MAP_CUSTOM2_F32:
+        case GGML_OP_MAP_CUSTOM3_F32:
+            {
+                n_tasks = 1;
+            } break;
+        case GGML_OP_MAP_CUSTOM1:
+            {
+                struct ggml_map_custom1_op_params p;
+                memcpy(&p, node->op_params, sizeof(p));
+                if (p.n_tasks == GGML_N_TASKS_MAX) {
+                    n_tasks = n_threads;
+                } else {
+                    n_tasks = MIN(p.n_tasks, n_threads);
+                }
+            } break;
+        case GGML_OP_MAP_CUSTOM2:
+            {
+                struct ggml_map_custom2_op_params p;
+                memcpy(&p, node->op_params, sizeof(p));
+                if (p.n_tasks == GGML_N_TASKS_MAX) {
+                    n_tasks = n_threads;
+                } else {
+                    n_tasks = MIN(p.n_tasks, n_threads);
+                }
+            } break;
+        case GGML_OP_MAP_CUSTOM3:
+            {
+                struct ggml_map_custom3_op_params p;
+                memcpy(&p, node->op_params, sizeof(p));
+                if (p.n_tasks == GGML_N_TASKS_MAX) {
+                    n_tasks = n_threads;
+                } else {
+                    n_tasks = MIN(p.n_tasks, n_threads);
+                }
+            } break;
+        case GGML_OP_CROSS_ENTROPY_LOSS:
+        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+        case GGML_OP_OPT_STEP_ADAMW:
+            {
+                n_tasks = n_threads;
+            } break;
+        case GGML_OP_NONE:
+            {
+                n_tasks = 1;
+            } break;
+        case GGML_OP_COUNT:
+            {
+                GGML_ABORT("fatal error");
+            }
+        default:
+            {
+                fprintf(stderr, "%s: op not implemented: ", __func__);
+                if (node->op < GGML_OP_COUNT) {
+                    fprintf(stderr, "%s\n", ggml_op_name(node->op));
+                } else {
+                    fprintf(stderr, "%d\n", node->op);
+                }
+                GGML_ABORT("fatal error");
+            }
+    }
+
+    assert(n_tasks > 0);
+
+    return n_tasks;
+}
+
+static thread_ret_t ggml_graph_compute_secondary_thread(void* data);
+
+#if defined(_WIN32)
+#include "windows.h"
+
+// TODO: support > 64 CPUs
+static bool ggml_thread_apply_affinity(const bool * mask) {
+    HANDLE    h = GetCurrentThread();
+    uint64_t  bitmask = 0ULL;
+
+    assert(GGML_MAX_N_THREADS >= 64);
+
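+    // pack the first 64 entries of the bool mask into a 64-bit affinity bitmask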
+    for (int32_t i = 0; i < 8; i++) {
+        int32_t idx = i * 8;
+        uint8_t val = 0;
+        val |= mask[idx + 0] << 0;
+        val |= mask[idx + 1] << 1;
+        val |= mask[idx + 2] << 2;
+        val |= mask[idx + 3] << 3;
+        val |= mask[idx + 4] << 4;
+        val |= mask[idx + 5] << 5;
+        val |= mask[idx + 6] << 6;
+        val |= mask[idx + 7] << 7;
+        bitmask |= (uint64_t)val << idx;
+    }
+
+    for (int32_t i = 64; i < GGML_MAX_N_THREADS; i++) {
+        if (mask[i]) {
+            fprintf(stderr, "warn: setting thread-affinity for > 64 CPUs isn't supported on windows!\n");
+            break;
+        }
+    }
+
+    DWORD_PTR m = (DWORD_PTR)bitmask;
+
+    m = SetThreadAffinityMask(h, m);
+
+    return m != 0;
+}
+
+static bool ggml_thread_apply_priority(int32_t prio) {
+    // Note that on Windows the Process Priority Class must be updated in order to set Thread priority.
+    // This is up to the applications.
+    DWORD p = THREAD_PRIORITY_NORMAL;
+    switch (prio) {
+        case GGML_SCHED_PRIO_NORMAL:   p = THREAD_PRIORITY_NORMAL;        break;
+        case GGML_SCHED_PRIO_MEDIUM:   p = THREAD_PRIORITY_ABOVE_NORMAL;  break;
+        case GGML_SCHED_PRIO_HIGH:     p = THREAD_PRIORITY_HIGHEST;       break;
+        case GGML_SCHED_PRIO_REALTIME: p = THREAD_PRIORITY_TIME_CRITICAL; break;
+    }
+
+    if (prio == GGML_SCHED_PRIO_NORMAL) {
+        // Keep inherited policy/priority
+        return true;
+    }
+
+    if (!SetThreadPriority(GetCurrentThread(), p)) {
+        fprintf(stderr, "warn: failed to set thread priority %d : (%d)\n", prio, (int) GetLastError());
+        return false;
+    }
+
+    return true;
+}
+
+#elif defined(__APPLE__)
+#include <sys/types.h>
+#include <sys/resource.h>
+
+static bool ggml_thread_apply_affinity(const bool * mask) {
+    // Not supported on Apple platforms
+    UNUSED(mask);
+    return true;
+}
+
+static bool ggml_thread_apply_priority(int32_t prio) {
+    struct sched_param p;
+    int32_t policy = SCHED_OTHER;
+    switch (prio) {
+        case GGML_SCHED_PRIO_NORMAL:   policy = SCHED_OTHER; p.sched_priority = 0;  break;
+        case GGML_SCHED_PRIO_MEDIUM:   policy = SCHED_FIFO;  p.sched_priority = 40; break;
+        case GGML_SCHED_PRIO_HIGH:     policy = SCHED_FIFO;  p.sched_priority = 80; break;
+        case GGML_SCHED_PRIO_REALTIME: policy = SCHED_FIFO;  p.sched_priority = 90; break;
+    }
+
+    if (prio == GGML_SCHED_PRIO_NORMAL) {
+        // Keep inherited policy/priority
+        return true;
+    }
+
+    int32_t err = pthread_setschedparam(pthread_self(), policy, &p);
+    if (err != 0) {
+        fprintf(stderr, "warn: failed to set thread priority %d : %s (%d)\n", prio, strerror(err), err);
+        return false;
+    }
+
+    return true;
+}
+
+#elif defined(__gnu_linux__)
+// TODO: this may not work on BSD, to be verified
+
+static bool ggml_thread_apply_affinity(const bool * mask) {
+    cpu_set_t cpuset;
+    int err;
+
+    CPU_ZERO(&cpuset);
+
+    for (uint32_t i = 0; i < GGML_MAX_N_THREADS; i++) {
+        if (mask[i]) {
+            GGML_PRINT_DEBUG("Thread %lx: adding %d to cpuset\n", pthread_self(), i);
+            CPU_SET(i, &cpuset);
+        }
+    }
+
+#ifdef __ANDROID__
+    err = sched_setaffinity(0, sizeof(cpuset), &cpuset);
+    if (err < 0) {
+        err = errno;
+    }
+#else
+    err = pthread_setaffinity_np(pthread_self(), sizeof(cpuset), &cpuset);
+#endif
+    if (err != 0) {
+        fprintf(stderr, "warn: failed to set affinity mask 0x%llx : %s (%d)\n", (unsigned long long)mask, strerror(err), err);
+        return false;
+    }
+
+    return true;
+}
+
+static bool ggml_thread_apply_priority(int32_t prio) {
+    struct sched_param p;
+    int32_t policy = SCHED_OTHER;
+    switch (prio) {
+        case GGML_SCHED_PRIO_NORMAL:   policy = SCHED_OTHER; p.sched_priority = 0;  break;
+        case GGML_SCHED_PRIO_MEDIUM:   policy = SCHED_FIFO;  p.sched_priority = 40; break;
+        case GGML_SCHED_PRIO_HIGH:     policy = SCHED_FIFO;  p.sched_priority = 80; break;
+        case GGML_SCHED_PRIO_REALTIME: policy = SCHED_FIFO;  p.sched_priority = 90; break;
+    }
+
+    if (prio == GGML_SCHED_PRIO_NORMAL) {
+        // Keep inherited policy/priority
+        return true;
+    }
+
+    int32_t err = pthread_setschedparam(pthread_self(), policy, &p);
+    if (err != 0) {
+        fprintf(stderr, "warn: failed to set thread priority %d : %s (%d)\n", prio, strerror(err), err);
+        return false;
+    }
+
+    return true;
+}
+
+#else // unsupported platforms
+
+static bool ggml_thread_apply_affinity(const bool * mask) {
+    UNUSED(mask);
+    return true;
+}
+
+static bool ggml_thread_apply_priority(int32_t prio) {
+    UNUSED(prio);
+    return true;
+}
+
+#endif
+
+static bool ggml_thread_cpumask_is_valid(const bool * mask) {
+    for (int i = 0; i < GGML_MAX_N_THREADS; i++) {
+        if (mask[i]) { return true; }
+    }
+    return false;
+}
+
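+// non-strict: every thread gets a copy of the full global mask; strict: each
+// thread is pinned to the next single allowed CPU, rotating through the mask
+// via *iter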
+static void ggml_thread_cpumask_next(const bool * global_mask, bool * local_mask, bool strict, int32_t* iter) {
+    if (!strict) {
+        memcpy(local_mask, global_mask, GGML_MAX_N_THREADS);
+        return;
+    } else {
+        memset(local_mask, 0, GGML_MAX_N_THREADS);
+        int32_t base_idx = *iter;
+        for (int32_t i = 0; i < GGML_MAX_N_THREADS; i++) {
+            int32_t idx = base_idx + i;
+            if (idx >= GGML_MAX_N_THREADS) {
+                // Just a cheaper modulo
+                idx -= GGML_MAX_N_THREADS;
+            }
+            if (global_mask[idx]) {
+                local_mask[idx] = 1;
+                *iter = idx + 1;
+                return;
+            }
+        }
+    }
+}
+
+void ggml_threadpool_free(struct ggml_threadpool* threadpool) {
+    if (!threadpool) return;
+
+    const int n_threads = threadpool->n_threads_max;
+
+#ifndef GGML_USE_OPENMP
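+    // without OpenMP the pool owns persistent worker threads that must be
+    // woken up and joined before the pool memory is freed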
+    struct ggml_compute_state* workers = threadpool->workers;
+
+    ggml_mutex_lock(&threadpool->mutex);
+
+    threadpool->stop = true;
+    threadpool->pause = false;
+
+    ggml_cond_broadcast(&threadpool->cond);
+    ggml_mutex_unlock(&threadpool->mutex);
+
+    for (int j = 1; j < n_threads; j++) {
+        int32_t rc = ggml_thread_join(workers[j].thrd, NULL);
+        GGML_ASSERT(rc == GGML_EXIT_SUCCESS || rc == GGML_EXIT_ABORTED);
+        UNUSED(rc);
+    }
+
+    ggml_mutex_destroy(&threadpool->mutex);
+    ggml_cond_destroy(&threadpool->cond);
+#endif // GGML_USE_OPENMP
+
+    const size_t workers_size = sizeof(struct ggml_compute_state) * n_threads;
+    ggml_aligned_free(threadpool->workers, workers_size);
+    ggml_aligned_free(threadpool, sizeof(struct ggml_threadpool));
+}
+
+#ifndef GGML_USE_OPENMP
+// pause/resume must be called under mutex
+static void ggml_threadpool_pause_locked(struct ggml_threadpool * threadpool) {
+    GGML_PRINT_DEBUG("Pausing threadpool\n");
+    threadpool->pause = true;
+    ggml_cond_broadcast(&threadpool->cond);
+}
+
+static void ggml_threadpool_resume_locked(struct ggml_threadpool * threadpool) {
+    GGML_PRINT_DEBUG("Resuming threadpool\n");
+    threadpool->pause = false;
+    ggml_cond_broadcast(&threadpool->cond);
+}
+#endif
+
+void ggml_threadpool_pause(struct ggml_threadpool * threadpool) {
+#ifndef GGML_USE_OPENMP
+    ggml_mutex_lock(&threadpool->mutex);
+    if (!threadpool->pause) {
+       ggml_threadpool_pause_locked(threadpool);
+    }
+    ggml_mutex_unlock(&threadpool->mutex);
+#else
+    UNUSED(threadpool);
+#endif
+}
+
+void ggml_threadpool_resume(struct ggml_threadpool * threadpool) {
+#ifndef GGML_USE_OPENMP
+    ggml_mutex_lock(&threadpool->mutex);
+    if (threadpool->pause) {
+       ggml_threadpool_resume_locked(threadpool);
+    }
+    ggml_mutex_unlock(&threadpool->mutex);
+#else
+    UNUSED(threadpool);
+#endif
+}
+
+struct ggml_cplan ggml_graph_plan(
+          const struct ggml_cgraph * cgraph,
+                               int   n_threads,
+            struct ggml_threadpool * threadpool) {
+
+    if (threadpool == NULL) {
+        //GGML_PRINT_DEBUG("Threadpool is not specified. Will create a disposable threadpool : n_threads %d\n", n_threads);
+    }
+    if (n_threads <= 0) {
+        n_threads = threadpool ? threadpool->n_threads_max : GGML_DEFAULT_N_THREADS;
+    }
+
+    size_t work_size = 0;
+
+    struct ggml_cplan cplan;
+    memset(&cplan, 0, sizeof(struct ggml_cplan));
+
+    int max_tasks = 1;
+
+    // thread scheduling for the different operations + work buffer size estimation
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        struct ggml_tensor * node = cgraph->nodes[i];
+
+        const int n_tasks = ggml_get_n_tasks(node, n_threads);
+
+        max_tasks = MAX(max_tasks, n_tasks);
+
+        size_t cur = 0;
+
+        switch (node->op) {
+            case GGML_OP_CPY:
+            case GGML_OP_DUP:
+                {
+                    if (ggml_is_quantized(node->type) ||
+                        // F16 -> BF16 and BF16 -> F16 copies go through intermediate F32
+                        (node->src[0]->type == GGML_TYPE_F16  && node->src[1] && node->src[1]->type == GGML_TYPE_BF16) ||
+                        (node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16)) {
+                        cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
+                    }
+                } break;
+            case GGML_OP_ADD:
+            case GGML_OP_ADD1:
+                {
+                    if (ggml_is_quantized(node->src[0]->type)) {
+                        cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
+                    }
+                } break;
+            case GGML_OP_ACC:
+                {
+                    if (ggml_is_quantized(node->src[0]->type)) {
+                        cur = ggml_type_size(GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks;
+                    }
+                } break;
+            case GGML_OP_COUNT_EQUAL:
+                {
+                    cur = ggml_type_size(node->type)*n_tasks;
+                } break;
+            case GGML_OP_MUL_MAT:
+                {
+#if defined(__AMX_INT8__) && defined(__AVX512VNNI__)
+                    if (node->src[0]->buffer && ggml_backend_amx_buft_is_amx(node->src[0]->buffer->buft)) {
+                        cur = ggml_backend_amx_desired_wsize(node);
+                    }
+#endif
+                    const enum ggml_type vec_dot_type = type_traits_cpu[node->src[0]->type].vec_dot_type;
+
+                    if (node->src[1]->type != vec_dot_type) {
+                        size_t cur2 = ggml_row_size(vec_dot_type, ggml_nelements(node->src[1]));
+                        cur = MAX(cur, cur2);
+                    }
+                } break;
+            case GGML_OP_MUL_MAT_ID:
+                {
+                    cur = 0;
+                    const struct ggml_tensor * src0 = node->src[0];
+                    const struct ggml_tensor * src1 = node->src[1];
+                    const enum ggml_type vec_dot_type = type_traits_cpu[src0->type].vec_dot_type;
+                    if (src1->type != vec_dot_type) {
+                        cur += ggml_row_size(vec_dot_type, ggml_nelements(src1));
+                    }
+                    const int n_as = src0->ne[2];
+                    cur += GGML_PAD(cur, sizeof(int64_t));       // align
+                    cur += n_as * sizeof(int64_t);               // matrix_row_counts
+                    cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows
+                } break;
+            case GGML_OP_OUT_PROD:
+                {
+                    if (ggml_is_quantized(node->src[0]->type)) {
+                        cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
+                    }
+                } break;
+            case GGML_OP_SOFT_MAX:
+            case GGML_OP_ROPE:
+                {
+                    cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
+                } break;
+            case GGML_OP_CONV_TRANSPOSE_1D:
+                {
+                    GGML_ASSERT(node->src[0]->ne[3] == 1);
+                    GGML_ASSERT(node->src[1]->ne[2] == 1);
+                    GGML_ASSERT(node->src[1]->ne[3] == 1);
+
+                    const int64_t ne00 = node->src[0]->ne[0];  // K
+                    const int64_t ne01 = node->src[0]->ne[1];  // Cout
+                    const int64_t ne02 = node->src[0]->ne[2];  // Cin
+
+                    const int64_t ne10 = node->src[1]->ne[0];  // L
+                    const int64_t ne11 = node->src[1]->ne[1];  // Cin
+
+                    if ((node->src[0]->type == GGML_TYPE_F16 ||
+                         node->src[0]->type == GGML_TYPE_BF16) &&
+                        node->src[1]->type == GGML_TYPE_F32) {
+                        cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02;
+                        cur += sizeof(ggml_fp16_t)*ne10*ne11;
+                    } else if (node->src[0]->type == GGML_TYPE_F32 &&
+                               node->src[1]->type == GGML_TYPE_F32) {
+                        cur += sizeof(float)*ne00*ne01*ne02;
+                        cur += sizeof(float)*ne10*ne11;
+                    } else {
+                        GGML_ABORT("fatal error");
+                    }
+                } break;
+            case GGML_OP_CONV_TRANSPOSE_2D:
+                {
+                    const int64_t ne00 = node->src[0]->ne[0]; // W
+                    const int64_t ne01 = node->src[0]->ne[1]; // H
+                    const int64_t ne02 = node->src[0]->ne[2]; // Channels Out
+                    const int64_t ne03 = node->src[0]->ne[3]; // Channels In
+
+                    const int64_t ne10 = node->src[1]->ne[0]; // W
+                    const int64_t ne11 = node->src[1]->ne[1]; // H
+                    const int64_t ne12 = node->src[1]->ne[2]; // Channels In
+
+                    cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
+                    cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
+                } break;
+            case GGML_OP_FLASH_ATTN_EXT:
+                {
+                    const int64_t ne00 = node->src[0]->ne[0]; // D
+
+                    cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
+                } break;
+            case GGML_OP_FLASH_ATTN_BACK:
+                {
+                    const int64_t    D = node->src[0]->ne[0];
+                    const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
+                    const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
+                    if (node->src[1]->type == GGML_TYPE_F32) {
+                        cur  = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
+                        cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
+                    } else if (node->src[1]->type == GGML_TYPE_F16) {
+                        cur  = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
+                        cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
+                    } else if (node->src[1]->type == GGML_TYPE_BF16) {
+                        cur  = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
+                        cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
+                    }
+                } break;
+
+            case GGML_OP_CROSS_ENTROPY_LOSS:
+                {
+                    cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
+                } break;
+            case GGML_OP_COUNT:
+                {
+                    GGML_ABORT("fatal error");
+                }
+            default:
+                break;
+        }
+
+        work_size = MAX(work_size, cur);
+    }
+
+    if (work_size > 0) {
+        work_size += CACHE_LINE_SIZE*(n_threads);
+    }
+
+    cplan.threadpool = threadpool;
+    cplan.n_threads  = MIN(max_tasks, n_threads);
+    cplan.work_size  = work_size;
+    cplan.work_data  = NULL;
+
+    return cplan;
+}
+
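+// Worker entry point for one graph computation: every participating thread walks the
+// node list, computes its share of each node, and synchronizes on a barrier per node.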
+static thread_ret_t ggml_graph_compute_thread(void * data) {
+    struct ggml_compute_state * state = (struct ggml_compute_state *) data;
+    struct ggml_threadpool    * tp    = state->threadpool;
+
+    const struct ggml_cgraph * cgraph = tp->cgraph;
+    const struct ggml_cplan  * cplan  = tp->cplan;
+
+    set_numa_thread_affinity(state->ith);
+
+    struct ggml_compute_params params = {
+        /*.ith       =*/ state->ith,
+        /*.nth       =*/ atomic_load_explicit(&tp->n_threads_cur, memory_order_relaxed),
+        /*.wsize     =*/ cplan->work_size,
+        /*.wdata     =*/ cplan->work_data,
+        /*.threadpool=*/ tp,
+    };
+
+    for (int node_n = 0; node_n < cgraph->n_nodes && !tp->abort; node_n++) {
+        struct ggml_tensor * node = cgraph->nodes[node_n];
+
+        ggml_compute_forward(&params, node);
+
+        if (state->ith == 0 && cplan->abort_callback &&
+                cplan->abort_callback(cplan->abort_callback_data)) {
+            tp->abort = true;
+            tp->ec    = GGML_STATUS_ABORTED;
+        }
+
+        ggml_barrier(state->threadpool);
+    }
+
+    return 0;
+}
+
+#ifndef GGML_USE_OPENMP
+
+// check if thread is active
+static inline bool ggml_graph_compute_thread_active(struct ggml_compute_state * state) {
+    struct ggml_threadpool * threadpool = state->threadpool;
+    int n_threads = atomic_load_explicit(&threadpool->n_threads_cur, memory_order_relaxed);
+    return (state->ith < n_threads);
+}
+
+// check if thread is ready to proceed (exit from polling or sleeping)
+static inline bool ggml_graph_compute_thread_ready(struct ggml_compute_state * state) {
+    struct ggml_threadpool * threadpool = state->threadpool;
+
+    if (state->pending || threadpool->stop || threadpool->pause) { return true; }
+
+    // check for new graph/work
+    int new_graph = atomic_load_explicit(&threadpool->n_graph, memory_order_relaxed);
+    if (new_graph != state->last_graph) {
+        state->pending    = ggml_graph_compute_thread_active(state);
+        state->last_graph = new_graph;
+    }
+
+    return state->pending;
+}
+
+// sync thread state after polling
+static inline void ggml_graph_compute_thread_sync(struct ggml_compute_state * state) {
+    // TSAN doesn't support standalone fence yet, we use a dummy read-modify-write instead
+    #ifdef GGML_TSAN_ENABLED
+    atomic_fetch_add_explicit(&state->threadpool->n_graph, 0, memory_order_seq_cst);
+    #else
+    atomic_thread_fence(memory_order_seq_cst);
+    #endif
+    UNUSED(state);
+}
+
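+// Busy-poll for new work for a bounded number of rounds before the caller falls back
+// to sleeping on the condition variable (see ggml_graph_compute_check_for_work below).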
+static inline bool ggml_graph_compute_poll_for_work(struct ggml_compute_state * state) {
+    struct ggml_threadpool * threadpool = state->threadpool;
+
+    // Skip polling for unused threads
+    if (!ggml_graph_compute_thread_active(state)) {
+        return state->pending;
+    }
+
+    // This seems to make 0 ... 100 a decent range for the polling level across modern processors.
+    // Perhaps we could adjust it dynamically based on load and other factors.
+    const uint64_t n_rounds = 1024UL * 128 * threadpool->poll;
+
+    for (uint64_t i=0; !ggml_graph_compute_thread_ready(state) && i < n_rounds; i++) {
+        // No new work. Keep polling.
+        ggml_thread_cpu_relax();
+    }
+
+    return state->pending;
+}
+
+static inline bool ggml_graph_compute_check_for_work(struct ggml_compute_state * state) {
+    struct ggml_threadpool * threadpool = state->threadpool;
+
+    if (ggml_graph_compute_poll_for_work(state)) {
+        ggml_graph_compute_thread_sync(state);
+        return state->pending;
+    }
+
+    ggml_mutex_lock_shared(&threadpool->mutex);
+    while (!ggml_graph_compute_thread_ready(state)) {
+        // No new work. Wait for the signal.
+        GGML_PRINT_DEBUG("thread #%d waiting for work (sleeping)\n", state->ith);
+        ggml_cond_wait(&threadpool->cond, &threadpool->mutex);
+    }
+    ggml_mutex_unlock_shared(&threadpool->mutex);
+
+    return state->pending;
+}
+
+static thread_ret_t ggml_graph_compute_secondary_thread(void* data) {
+    struct ggml_compute_state * state = (struct ggml_compute_state *) data;
+    struct ggml_threadpool * threadpool = state->threadpool;
+
+    ggml_thread_apply_priority(threadpool->prio);
+    if (ggml_thread_cpumask_is_valid(state->cpumask)) {
+        ggml_thread_apply_affinity(state->cpumask);
+    }
+
+    while (true) {
+        // Check if we need to sleep
+        while (threadpool->pause) {
+            GGML_PRINT_DEBUG("thread #%d inside pause loop\n", state->ith);
+            ggml_mutex_lock_shared(&threadpool->mutex);
+            if (threadpool->pause) {
+                ggml_cond_wait(&threadpool->cond, &threadpool->mutex);
+            }
+            GGML_PRINT_DEBUG("thread #%d resuming after wait\n", state->ith);
+            ggml_mutex_unlock_shared(&threadpool->mutex);
+        }
+
+        // This needs to be checked after the cond_wait
+        if (threadpool->stop) break;
+
+        // Check if there is new work
+        // The main thread is the only one that can dispatch new work
+
+        ggml_graph_compute_check_for_work(state);
+        if (state->pending) {
+            state->pending = false;
+
+            ggml_graph_compute_thread(state);
+        }
+    }
+
+    return (thread_ret_t) 0;
+}
+
+// Start processing new graph
+static void ggml_graph_compute_kickoff(struct ggml_threadpool * threadpool, int n_threads)
+{
+    // Always take the mutex here because the worker threads are doing hybrid poll/wait
+
+    ggml_mutex_lock(&threadpool->mutex);
+
+    GGML_PRINT_DEBUG("threadpool: n_threads_cur %d n_threads %d\n", threadpool->n_threads_cur, n_threads);
+
+    // Update the number of active threads
+    atomic_store_explicit(&threadpool->n_threads_cur, n_threads, memory_order_relaxed);
+
+    // Indicate the graph is ready to be processed
+    // We need the full seq-cst fence here because of the polling threads (used in thread_sync)
+    atomic_fetch_add_explicit(&threadpool->n_graph, 1, memory_order_seq_cst);
+
+    if (threadpool->pause) {
+       // Update main thread prio and affinity to match the threadpool settings
+       ggml_thread_apply_priority(threadpool->prio);
+       if (ggml_thread_cpumask_is_valid(threadpool->workers[0].cpumask)) {
+           ggml_thread_apply_affinity(threadpool->workers[0].cpumask);
+       }
+
+       // resume does cond broadcast
+       ggml_threadpool_resume_locked(threadpool);
+    } else {
+       ggml_cond_broadcast(&threadpool->cond);
+    }
+
+    ggml_mutex_unlock(&threadpool->mutex);
+}
+
+#endif // GGML_USE_OPENMP
+
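+// Create a threadpool. In non-OpenMP builds this spawns n_threads-1 secondary workers;
+// slot 0 is reserved for the thread that dispatches graphs (the "main" thread).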
+static struct ggml_threadpool * ggml_threadpool_new_impl(
+    struct ggml_threadpool_params * tpp,
+               struct ggml_cgraph * cgraph,
+                struct ggml_cplan * cplan) {
+
+    struct ggml_threadpool * threadpool =
+        ggml_aligned_malloc(sizeof(struct ggml_threadpool));
+    {
+        threadpool->cgraph           = cgraph;
+        threadpool->cplan            = cplan;
+        threadpool->n_graph          = 0;
+        threadpool->n_barrier        = 0;
+        threadpool->n_barrier_passed = 0;
+        threadpool->current_chunk    = 0;
+        threadpool->stop             = false;
+        threadpool->pause            = tpp->paused;
+        threadpool->abort            = false;
+        threadpool->workers          = NULL;
+        threadpool->n_threads_max    = tpp->n_threads;
+        threadpool->n_threads_cur    = tpp->n_threads;
+        threadpool->poll             = tpp->poll;
+        threadpool->prio             = tpp->prio;
+        threadpool->ec               = GGML_STATUS_SUCCESS;
+    }
+
+    // Allocate and init workers state
+    const size_t workers_size = sizeof(struct ggml_compute_state) * tpp->n_threads;
+    struct ggml_compute_state * workers = ggml_aligned_malloc(workers_size);
+
+    memset(workers, 0, workers_size);
+    for (int j = 0; j < tpp->n_threads; j++) {
+        workers[j].threadpool = threadpool;
+        workers[j].ith        = j;
+    }
+
+    threadpool->workers = workers;
+
+#ifndef GGML_USE_OPENMP
+    ggml_mutex_init(&threadpool->mutex);
+    ggml_cond_init(&threadpool->cond);
+
+    // Spin up the threads for all workers and update their CPU placements.
+    // Place the main thread last (towards the higher numbered CPU cores).
+
+    int32_t cpumask_iter = 0;
+
+    for (int j = 1; j < tpp->n_threads; j++) {
+        ggml_thread_cpumask_next(tpp->cpumask, workers[j].cpumask, tpp->strict_cpu, &cpumask_iter);
+
+        int32_t rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_secondary_thread, &workers[j]);
+        GGML_ASSERT(rc == 0);
+    }
+
+    ggml_thread_cpumask_next(tpp->cpumask, workers[0].cpumask, tpp->strict_cpu, &cpumask_iter);
+
+    if (!threadpool->pause) {
+        // Update main thread prio and affinity at the start, otherwise we'll do it in resume
+        ggml_thread_apply_priority(threadpool->prio);
+        if (ggml_thread_cpumask_is_valid(threadpool->workers[0].cpumask)) {
+            ggml_thread_apply_affinity(threadpool->workers[0].cpumask);
+        }
+    }
+#endif // GGML_USE_OPENMP
+
+    return threadpool;
+}
+
+struct ggml_threadpool * ggml_threadpool_new(struct ggml_threadpool_params * tpp) {
+    return ggml_threadpool_new_impl(tpp, NULL, NULL);
+}
+
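+// Compute a graph using the threadpool from the cplan, creating a disposable pool when
+// none was provided. The calling thread always participates in the computation.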
+enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
+    ggml_cpu_init();
+
+    GGML_ASSERT(cplan);
+    GGML_ASSERT(cplan->n_threads > 0);
+    GGML_ASSERT(cplan->work_size == 0 || cplan->work_data != NULL);
+
+    int n_threads                               = cplan->n_threads;
+    struct ggml_threadpool * threadpool = cplan->threadpool;
+
+    bool disposable_threadpool = false;
+
+    if (threadpool == NULL) {
+        //GGML_PRINT_DEBUG("Threadpool is not specified. Will create a disposable threadpool : n_threads %d\n", n_threads);
+        disposable_threadpool = true;
+
+        struct ggml_threadpool_params ttp = ggml_threadpool_params_default(n_threads);
+        threadpool = ggml_threadpool_new_impl(&ttp, cgraph, cplan);
+    } else {
+        // Reset some of the parameters that need resetting
+        // No worker threads should be accessing the parameters below at this stage
+        threadpool->cgraph           = cgraph;
+        threadpool->cplan            = cplan;
+        threadpool->current_chunk    = 0;
+        threadpool->abort            = false;
+        threadpool->ec               = GGML_STATUS_SUCCESS;
+    }
+
+#ifdef GGML_USE_OPENMP
+    if (n_threads > 1) {
+        #pragma omp parallel num_threads(n_threads)
+        {
+            #pragma omp single
+            {
+                // update n_threads to the actual number of threads that OpenMP provided
+                n_threads = omp_get_num_threads();
+                atomic_store_explicit(&threadpool->n_threads_cur, n_threads, memory_order_relaxed);
+            }
+
+            ggml_graph_compute_thread(&threadpool->workers[omp_get_thread_num()]);
+        }
+    } else {
+        atomic_store_explicit(&threadpool->n_threads_cur, 1, memory_order_relaxed);
+        ggml_graph_compute_thread(&threadpool->workers[0]);
+    }
+#else
+    if (n_threads > threadpool->n_threads_max) {
+        GGML_LOG_WARN("cplan requested more threads (%d) than available (%d)\n", n_threads, threadpool->n_threads_max);
+        n_threads = threadpool->n_threads_max;
+    }
+
+    // Kick all threads to start the new graph
+    ggml_graph_compute_kickoff(threadpool, n_threads);
+
+    // This is a work thread too
+    ggml_graph_compute_thread(&threadpool->workers[0]);
+#endif
+
+    // don't leave affinity set on the main thread
+    clear_numa_thread_affinity();
+
+    enum ggml_status ret = threadpool->ec;
+
+    if (disposable_threadpool) {
+        ggml_threadpool_free(threadpool);
+    }
+
+    return ret;
+}
+
+enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads) {
+    struct ggml_cplan cplan = ggml_graph_plan(cgraph, n_threads, NULL);
+
+    cplan.work_data = (uint8_t *)ggml_new_buffer(ctx, cplan.work_size);
+
+    return ggml_graph_compute(cgraph, &cplan);
+}
+
+
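+// CPU feature reporting: the x86 checks below reflect compile-time ISA flags, while the
+// ARM checks consult runtime-detected features (see ggml_init_arm_arch_features).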
+int ggml_cpu_has_avx(void) {
+#if defined(__AVX__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_avx_vnni(void) {
+#if defined(__AVXVNNI__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_avx2(void) {
+#if defined(__AVX2__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_avx512(void) {
+#if defined(__AVX512F__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_avx512_vbmi(void) {
+#if defined(__AVX512VBMI__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_avx512_vnni(void) {
+#if defined(__AVX512VNNI__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_avx512_bf16(void) {
+#if defined(__AVX512BF16__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_amx_int8(void) {
+#if defined(__AMX_INT8__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_fma(void) {
+#if defined(__FMA__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_arm_fma(void) {
+#if defined(__ARM_FEATURE_FMA)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_riscv_v(void) {
+#if defined(__riscv_v_intrinsic)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_f16c(void) {
+#if defined(__F16C__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_fp16_va(void) {
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_wasm_simd(void) {
+#if defined(__wasm_simd128__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_llamafile(void) {
+#if defined(GGML_USE_LLAMAFILE)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_sse3(void) {
+#if defined(__SSE3__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_ssse3(void) {
+#if defined(__SSSE3__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_vsx(void) {
+#if defined(__POWER9_VECTOR__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_neon(void) {
+#if defined(__ARM_ARCH) && defined(__ARM_NEON)
+    return ggml_arm_arch_features.has_neon;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_dotprod(void) {
+#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_DOTPROD)
+    return ggml_arm_arch_features.has_dotprod;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_sve(void) {
+#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_SVE)
+    return ggml_arm_arch_features.has_sve;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_matmul_int8(void) {
+#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_MATMUL_INT8)
+    return ggml_arm_arch_features.has_i8mm;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_get_sve_cnt(void) {
+#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_SVE)
+    return ggml_arm_arch_features.sve_cnt;
+#else
+    return 0;
+#endif
+}
+
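+// One-time CPU backend initialization: builds the f16 conversion and activation lookup
+// tables and probes ARM architecture features. Safe to call more than once.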
+void ggml_cpu_init(void) {
+    // needed to initialize f16 tables
+    {
+        struct ggml_init_params params = { 0, NULL, false };
+        struct ggml_context * ctx = ggml_init(params);
+        ggml_free(ctx);
+    }
+
+    ggml_critical_section_start();
+
+    static bool is_first_call = true;
+
+    if (is_first_call) {
+        // initialize the GELU and Quick GELU f16 lookup tables
+        {
+            const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
+
+            for (int i = 0; i < (1 << 16); ++i) {
+                union {
+                    uint16_t u16;
+                    ggml_fp16_t fp16;
+                } u = {i};
+                float f = GGML_FP16_TO_FP32(u.fp16);
+                ggml_table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
+                ggml_table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f));
+            }
+
+            const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
+
+            GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0);
+        }
+
+#if defined(__ARM_ARCH)
+        ggml_init_arm_arch_features();
+#endif
+
+        is_first_call = false;
+    }
+
+    ggml_critical_section_end();
+}

+ 741 - 0
llama/ggml-cpu.cpp

@@ -0,0 +1,741 @@
+/**
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "ggml-backend.h"
+#include "ggml-backend-impl.h"
+#include "ggml-cpu.h"
+#include "ggml-cpu-aarch64.h"
+#include "ggml-impl.h"
+#include "amx.h"
+#include <cctype>
+#include <string>
+#include <vector>
+
+#if defined(__APPLE__)
+#include <sys/types.h>
+#include <sys/sysctl.h>
+#endif
+
+#if defined(_WIN32)
+#define WIN32_LEAN_AND_MEAN
+#ifndef NOMINMAX
+    #define NOMINMAX
+#endif
+#include <windows.h>
+#endif
+
+// ggml-backend interface
+
+#ifdef GGML_USE_CPU_HBM
+
+// buffer type HBM
+
+#include <hbwmalloc.h>
+
+static const char * ggml_backend_cpu_hbm_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
+    return "CPU_HBM";
+
+    GGML_UNUSED(buft);
+}
+
+static void ggml_backend_cpu_hbm_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+    hbw_free(buffer->context);
+}
+
+static ggml_backend_buffer_t ggml_backend_cpu_hbm_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+    void * ptr;
+    int result = hbw_posix_memalign(&ptr, ggml_backend_cpu_buffer_type_get_alignment(buft), size);
+    if (result != 0) {
+        GGML_LOG_ERROR("failed to allocate HBM buffer of size %zu\n", size);
+        return NULL;
+    }
+
+    ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
+    buffer->buft = buft;
+    buffer->iface.free_buffer = ggml_backend_cpu_hbm_buffer_free_buffer;
+
+    return buffer;
+}
+
+ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void) {
+    static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_hbm = {
+        /* .iface    = */ {
+            /* .get_name         = */ ggml_backend_cpu_hbm_buffer_type_get_name,
+            /* .alloc_buffer     = */ ggml_backend_cpu_hbm_buffer_type_alloc_buffer,
+            /* .get_alignment    = */ ggml_backend_cpu_buffer_type_get_alignment,
+            /* .get_max_size     = */ NULL, // defaults to SIZE_MAX
+            /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes
+            /* .is_host          = */ ggml_backend_cpu_buffer_type_is_host,
+        },
+        /* .context  = */ NULL,
+    };
+
+    return &ggml_backend_cpu_buffer_type_hbm;
+}
+#endif
+
+// buffer type AARCH64
+
+static void ggml_backend_cpu_aarch64_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+    tensor->extra = (void *)ggml_aarch64_get_optimal_repack_type(tensor); // NOLINT
+
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_cpu_aarch64_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    GGML_ASSERT(offset == 0);
+    GGML_ASSERT(size == ggml_nbytes(tensor));
+
+    enum ggml_type repack_type = (enum ggml_type)(intptr_t)tensor->extra;
+
+    ggml_aarch64_repack_tensor(tensor, repack_type, data, size);
+
+    GGML_UNUSED(buffer);
+}
+
+static const char * ggml_backend_cpu_aarch64_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
+    return "CPU_AARCH64";
+
+    GGML_UNUSED(buft);
+}
+
+static ggml_backend_buffer_t ggml_backend_cpu_aarch64_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+    auto * buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
+
+    if (buffer == NULL) {
+        return NULL;
+    }
+
+    buffer->buft = buft;
+    buffer->iface.init_tensor = ggml_backend_cpu_aarch64_buffer_init_tensor;
+    buffer->iface.set_tensor = ggml_backend_cpu_aarch64_buffer_set_tensor;
+
+    return buffer;
+}
+
+ggml_backend_buffer_type_t ggml_backend_cpu_aarch64_buffer_type(void) {
+    static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_aarch64 = {
+        /* .iface    = */ {
+            /* .get_name         = */ ggml_backend_cpu_aarch64_buffer_type_get_name,
+            /* .alloc_buffer     = */ ggml_backend_cpu_aarch64_buffer_type_alloc_buffer,
+            /* .get_alignment    = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
+            /* .get_max_size     = */ NULL, // defaults to SIZE_MAX
+            /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes
+            /* .is_host          = */ NULL,
+        },
+        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
+        /* .context = */ NULL,
+    };
+
+    return &ggml_backend_cpu_buffer_type_aarch64;
+}
+
+bool ggml_backend_cpu_buft_is_aarch64(ggml_backend_buffer_type_t buft) {
+    return buft == ggml_backend_cpu_aarch64_buffer_type();
+}
+
+static ggml_backend_buffer_type_t * ggml_backend_cpu_get_extra_bufts(ggml_backend_dev_t device) {
+    static std::vector<ggml_backend_buffer_type_t> bufts = []() {
+        std::vector<ggml_backend_buffer_type_t> bufts;
+
+#if defined(__AMX_INT8__) && defined(__AVX512VNNI__)
+        if (ggml_backend_amx_buffer_type()) {
+            bufts.push_back(ggml_backend_amx_buffer_type());
+        }
+#endif
+
+#ifdef GGML_USE_CPU_AARCH64
+        if (ggml_backend_cpu_aarch64_buffer_type()) {
+            bufts.push_back(ggml_backend_cpu_aarch64_buffer_type());
+        }
+#endif
+
+        bufts.push_back(NULL);
+
+        return bufts;
+    }();
+
+    return bufts.data();
+
+    GGML_UNUSED(device);
+}
+
+// CPU backend - backend (stream)
+
+struct ggml_backend_cpu_context {
+    int                 n_threads;
+    ggml_threadpool_t   threadpool;
+
+    uint8_t *           work_data;
+    size_t              work_size;
+
+    ggml_abort_callback abort_callback;
+    void *              abort_callback_data;
+};
+
+static const char * ggml_backend_cpu_get_name(ggml_backend_t backend) {
+    return "CPU";
+
+    GGML_UNUSED(backend);
+}
+
+static void ggml_backend_cpu_free(ggml_backend_t backend) {
+    struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
+    delete[] cpu_ctx->work_data;
+    delete cpu_ctx;
+    delete backend;
+}
+
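+// a graph plan bundles the cplan (threading info + work buffer) with a copy of the graph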
+struct ggml_backend_plan_cpu {
+    struct ggml_cplan cplan;
+    struct ggml_cgraph cgraph;
+};
+
+static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(ggml_backend_t backend, const struct ggml_cgraph * cgraph) {
+    struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
+
+    struct ggml_backend_plan_cpu * cpu_plan = new ggml_backend_plan_cpu;
+
+    cpu_plan->cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads, cpu_ctx->threadpool);
+    cpu_plan->cgraph = *cgraph; // FIXME: deep copy
+
+    if (cpu_plan->cplan.work_size > 0) {
+        cpu_plan->cplan.work_data = new uint8_t[cpu_plan->cplan.work_size];
+        if (cpu_plan->cplan.work_data == NULL) {
+            delete cpu_plan;
+            return NULL;
+        }
+    }
+
+    cpu_plan->cplan.abort_callback      = cpu_ctx->abort_callback;
+    cpu_plan->cplan.abort_callback_data = cpu_ctx->abort_callback_data;
+
+    return cpu_plan;
+}
+
+static void ggml_backend_cpu_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
+    struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan;
+
+    delete[] cpu_plan->cplan.work_data;
+    delete cpu_plan;
+
+    GGML_UNUSED(backend);
+}
+
+static enum ggml_status ggml_backend_cpu_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
+    struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan;
+
+    return ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan);
+
+    GGML_UNUSED(backend);
+}
+
+static enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+    struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
+
+    struct ggml_cplan cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads, cpu_ctx->threadpool);
+
+    if (cpu_ctx->work_size < cplan.work_size) {
+        delete[] cpu_ctx->work_data;
+        cpu_ctx->work_data = new uint8_t[cplan.work_size];
+        if (cpu_ctx->work_data == NULL) {
+            cpu_ctx->work_size = 0;
+            return GGML_STATUS_ALLOC_FAILED;
+        }
+        cpu_ctx->work_size = cplan.work_size;
+    }
+    cplan.work_data = (uint8_t *)cpu_ctx->work_data;
+
+    cplan.abort_callback      = cpu_ctx->abort_callback;
+    cplan.abort_callback_data = cpu_ctx->abort_callback_data;
+
+    return ggml_graph_compute(cgraph, &cplan);
+}
+
+static const struct ggml_backend_i ggml_backend_cpu_i = {
+    /* .get_name                = */ ggml_backend_cpu_get_name,
+    /* .free                    = */ ggml_backend_cpu_free,
+    /* .set_tensor_async        = */ NULL,
+    /* .get_tensor_async        = */ NULL,
+    /* .cpy_tensor_async        = */ NULL,
+    /* .synchronize             = */ NULL,
+    /* .graph_plan_create       = */ ggml_backend_cpu_graph_plan_create,
+    /* .graph_plan_free         = */ ggml_backend_cpu_graph_plan_free,
+    /* .graph_plan_update       = */ NULL,
+    /* .graph_plan_compute      = */ ggml_backend_cpu_graph_plan_compute,
+    /* .graph_compute           = */ ggml_backend_cpu_graph_compute,
+    /* .event_record            = */ NULL,
+    /* .event_wait              = */ NULL,
+};
+
+static ggml_guid_t ggml_backend_cpu_guid(void) {
+    static ggml_guid guid = { 0xaa, 0x67, 0xc7, 0x43, 0x96, 0xe6, 0xa3, 0x8a, 0xe3, 0xaf, 0xea, 0x92, 0x36, 0xbc, 0xfc, 0x89 };
+    return &guid;
+}
+
+ggml_backend_t ggml_backend_cpu_init(void) {
+    // initialize CPU backend now to avoid slowing the first graph computation
+    ggml_cpu_init();
+
+    struct ggml_backend_cpu_context * ctx = new ggml_backend_cpu_context;
+    if (ctx == NULL) {
+        return NULL;
+    }
+
+    ctx->n_threads           = GGML_DEFAULT_N_THREADS;
+    ctx->threadpool          = NULL;
+    ctx->work_data           = NULL;
+    ctx->work_size           = 0;
+    ctx->abort_callback      = NULL;
+    ctx->abort_callback_data = NULL;
+
+    ggml_backend_t cpu_backend = new ggml_backend {
+        /* .guid      = */ ggml_backend_cpu_guid(),
+        /* .interface = */ ggml_backend_cpu_i,
+        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
+        /* .context   = */ ctx,
+    };
+
+    if (cpu_backend == NULL) {
+        delete ctx;
+        return NULL;
+    }
+
+    return cpu_backend;
+}
+
+bool ggml_backend_is_cpu(ggml_backend_t backend) {
+    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_cpu_guid());
+}
+
+void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads) {
+    GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));
+
+    struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
+    ctx->n_threads = n_threads;
+}
+
+void ggml_backend_cpu_set_threadpool(ggml_backend_t backend_cpu, ggml_threadpool_t threadpool) {
+    GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));
+
+    struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
+
+    if (ctx->threadpool && ctx->threadpool != threadpool) {
+        // already had a different threadpool, pause/suspend it before switching
+        ggml_threadpool_pause(ctx->threadpool);
+    }
+    ctx->threadpool = threadpool;
+}
+
+void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data) {
+    GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));
+
+    struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
+    ctx->abort_callback = abort_callback;
+    ctx->abort_callback_data = abort_callback_data;
+}
+
+// CPU backend - device
+
+struct ggml_backend_cpu_device_context {
+    std::string description = "CPU";
+
+    ggml_backend_cpu_device_context() {
+#ifdef __APPLE__
+        size_t len = 0;
+        if (!sysctlbyname("machdep.cpu.brand_string", NULL, &len, NULL, 0)) {
+            description.resize(len);
+            sysctlbyname("machdep.cpu.brand_string", &description[0], &len, NULL, 0); // NOLINT
+        }
+#elif defined(__linux__)
+        FILE * f = fopen("/proc/cpuinfo", "r");
+        if (f) {
+            char buf[1024];
+            while (fgets(buf, sizeof(buf), f)) {
+                if (strncmp(buf, "model name", 10) == 0) {
+                    char * p = strchr(buf, ':');
+                    if (p) {
+                        p++;
+                        while (std::isspace(*p)) {
+                            p++;
+                        }
+                        while (std::isspace(p[strlen(p) - 1])) {
+                            p[strlen(p) - 1] = '\0';
+                        }
+                        description = p;
+                        break;
+                    }
+                }
+            }
+            fclose(f);
+        }
+#elif defined(_WIN32)
+        HKEY hKey;
+        if (RegOpenKeyEx(HKEY_LOCAL_MACHINE,
+                        TEXT("HARDWARE\\DESCRIPTION\\System\\CentralProcessor\\0"),
+                        0,
+                        KEY_READ,
+                        &hKey) == ERROR_SUCCESS) {
+            DWORD cpu_brand_size = 0;
+            if (RegQueryValueExA(hKey,
+                                TEXT("ProcessorNameString"),
+                                NULL,
+                                NULL,
+                                NULL,
+                                &cpu_brand_size) == ERROR_SUCCESS) {
+                description.resize(cpu_brand_size);
+                if (RegQueryValueExA(hKey,
+                                    TEXT("ProcessorNameString"),
+                                    NULL,
+                                    NULL,
+                                    (LPBYTE)&description[0], // NOLINT
+                                    &cpu_brand_size) == ERROR_SUCCESS) {
+                    if (description.find('\0') != std::string::npos) {
+                        description.resize(description.find('\0'));
+                    }
+                }
+            }
+            RegCloseKey(hKey);
+        }
+#endif
+    }
+};
+
+static const char * ggml_backend_cpu_device_get_name(ggml_backend_dev_t dev) {
+    return "CPU";
+
+    GGML_UNUSED(dev);
+}
+
+static const char * ggml_backend_cpu_device_get_description(ggml_backend_dev_t dev) {
+    struct ggml_backend_cpu_device_context * ctx = (struct ggml_backend_cpu_device_context *)dev->context;
+
+    return ctx->description.c_str();
+}
+
+static void ggml_backend_cpu_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
+    // TODO
+    *free = 0;
+    *total = 0;
+
+    GGML_UNUSED(dev);
+}
+
+static enum ggml_backend_dev_type ggml_backend_cpu_device_get_type(ggml_backend_dev_t dev) {
+    return GGML_BACKEND_DEVICE_TYPE_CPU;
+
+    GGML_UNUSED(dev);
+}
+
+static void ggml_backend_cpu_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
+    props->name        = ggml_backend_cpu_device_get_name(dev);
+    props->description = ggml_backend_cpu_device_get_description(dev);
+    props->type        = ggml_backend_cpu_device_get_type(dev);
+    ggml_backend_cpu_device_get_memory(dev, &props->memory_free, &props->memory_total);
+    props->caps = {
+        /* .async                 = */ false,
+        /* .host_buffer           = */ false,
+        /* .buffer_from_host_ptr  = */ true,
+        /* .events                = */ false,
+    };
+}
+
+static ggml_backend_t ggml_backend_cpu_device_init_backend(ggml_backend_dev_t dev, const char * params) {
+    return ggml_backend_cpu_init();
+
+    GGML_UNUSED(dev);
+    GGML_UNUSED(params);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_cpu_device_get_buffer_type(ggml_backend_dev_t dev) {
+    return ggml_backend_cpu_buffer_type();
+
+    GGML_UNUSED(dev);
+}
+
+static ggml_backend_buffer_t ggml_backend_cpu_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
+    return ggml_backend_cpu_buffer_from_ptr(ptr, size);
+
+    GGML_UNUSED(dev);
+    GGML_UNUSED(max_tensor_size);
+}
+
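+// Report which ops this device supports; tensors living in AMX or aarch64 repacked
+// buffers are only usable by the matrix multiplications they were prepared for.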
+static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
+    const struct ggml_tensor * src0 = op->src[0];
+    const struct ggml_tensor * src1 = op->src[1];
+
+    if (op->op == GGML_OP_NONE || op->op == GGML_OP_RESHAPE || op->op == GGML_OP_VIEW || op->op == GGML_OP_PERMUTE || op->op == GGML_OP_TRANSPOSE) {
+        return true;
+    }
+
+    if (src0 && src0->buffer && ggml_backend_cpu_buft_is_aarch64(src0->buffer->buft)) {
+        if (op->op != GGML_OP_MUL_MAT || src0->type == ggml_aarch64_get_optimal_repack_type(src0)) {
+            return false;
+        }
+    }
+
+#if defined(__AMX_INT8__) && defined(__AVX512VNNI__)
+    if (src0 && src0->buffer && ggml_backend_amx_buft_is_amx(src0->buffer->buft)) {
+        return ggml_backend_amx_device_supports_op(op);
+    }
+    for (int i = 1; i < GGML_MAX_SRC; i++) {
+        if (op->src[i] && op->src[i]->buffer && ggml_backend_amx_buft_is_amx(op->src[i]->buffer->buft)) {
+            return false;
+        }
+    }
+#endif
+
+    for (int i = 1; i < GGML_MAX_SRC; i++) {
+        if (op->src[i] && op->src[i]->buffer && ggml_backend_cpu_buft_is_aarch64(op->src[i]->buffer->buft)) {
+            return false;
+        }
+    }
+
+    switch (op->op) {
+        case GGML_OP_CPY:
+            return
+                op->type != GGML_TYPE_IQ2_XXS &&
+                op->type != GGML_TYPE_IQ2_XS  &&
+                op->type != GGML_TYPE_IQ1_S   &&
+                op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float
+        case GGML_OP_MUL_MAT:
+            return src1->type == GGML_TYPE_F32 || src1->type == ggml_get_type_traits_cpu(src0->type)->vec_dot_type;
+        case GGML_OP_ROPE_BACK:
+            return op->src[2] == NULL && (op->op_params[2] & 4) == 0;
+        case GGML_OP_IM2COL_BACK:
+            return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
+        case GGML_OP_OUT_PROD:
+            return (src0->type == GGML_TYPE_F32 || ggml_is_quantized(src0->type)) && src1->type == GGML_TYPE_F32;
+        default:
+            return true;
+    }
+
+    GGML_UNUSED(dev);
+}
+
+static bool ggml_backend_cpu_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
+    bool supported = ggml_backend_buft_is_host(buft) || ggml_backend_cpu_buft_is_aarch64(buft);
+
+#if defined(__AMX_INT8__) && defined(__AVX512VNNI__)
+    supported = supported || ggml_backend_amx_buft_is_amx(buft);
+#endif
+
+    return supported;
+
+    GGML_UNUSED(dev);
+}
+
+static const struct ggml_backend_device_i ggml_backend_cpu_device_i = {
+    /* .get_name             = */ ggml_backend_cpu_device_get_name,
+    /* .get_description      = */ ggml_backend_cpu_device_get_description,
+    /* .get_memory           = */ ggml_backend_cpu_device_get_memory,
+    /* .get_type             = */ ggml_backend_cpu_device_get_type,
+    /* .get_props            = */ ggml_backend_cpu_device_get_props,
+    /* .init_backend         = */ ggml_backend_cpu_device_init_backend,
+    /* .get_buffer_type      = */ ggml_backend_cpu_device_get_buffer_type,
+    /* .get_host_buffer_type = */ NULL,
+    /* .buffer_from_host_ptr = */ ggml_backend_cpu_device_buffer_from_host_ptr,
+    /* .supports_op          = */ ggml_backend_cpu_device_supports_op,
+    /* .supports_buft        = */ ggml_backend_cpu_device_supports_buft,
+    /* .offload_op           = */ NULL,
+    /* .event_new            = */ NULL,
+    /* .event_free           = */ NULL,
+    /* .event_synchronize    = */ NULL,
+};
+
+// CPU backend - backend (reg)
+
+static const char * ggml_backend_cpu_reg_get_name(ggml_backend_reg_t reg) {
+    return "CPU";
+
+    GGML_UNUSED(reg);
+}
+
+static size_t ggml_backend_cpu_reg_get_device_count(ggml_backend_reg_t reg) {
+    return 1;
+
+    GGML_UNUSED(reg);
+}
+
+static ggml_backend_dev_t ggml_backend_cpu_reg_get_device(ggml_backend_reg_t reg, size_t index) {
+    GGML_ASSERT(index == 0);
+
+    static ggml_backend_cpu_device_context ctx;
+    static ggml_backend_device ggml_backend_cpu_device = {
+        /* .iface   = */ ggml_backend_cpu_device_i,
+        /* .reg     = */ reg,
+        /* .context = */ &ctx,
+    };
+
+    return &ggml_backend_cpu_device;
+}
+
+// This is intended to replace the ggml_cpu_has_* functions when loading the CPU backend dynamically,
+// and additionally to allow other backends to expose their own list of features that applications can query using the same API.
+static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t reg) {
+    static std::vector<ggml_backend_feature> features = []() {
+        ggml_cpu_init();
+
+        std::vector<ggml_backend_feature> features;
+        if (ggml_cpu_has_sse3()) {
+            features.push_back({ "SSE3", "1" });
+        }
+        if (ggml_cpu_has_ssse3()) {
+            features.push_back({ "SSSE3", "1" });
+        }
+        if (ggml_cpu_has_avx()) {
+            features.push_back({ "AVX", "1" });
+        }
+        if (ggml_cpu_has_avx_vnni()) {
+            features.push_back({ "AVX_VNNI", "1" });
+        }
+        if (ggml_cpu_has_avx2()) {
+            features.push_back({ "AVX2", "1" });
+        }
+        if (ggml_cpu_has_f16c()) {
+            features.push_back({ "F16C", "1" });
+        }
+        if (ggml_cpu_has_fma()) {
+            features.push_back({ "FMA", "1" });
+        }
+        if (ggml_cpu_has_avx512()) {
+            features.push_back({ "AVX512", "1" });
+        }
+        if (ggml_cpu_has_avx512_vbmi()) {
+            features.push_back({ "AVX512_VBMI", "1" });
+        }
+        if (ggml_cpu_has_avx512_vnni()) {
+            features.push_back({ "AVX512_VNNI", "1" });
+        }
+        if (ggml_cpu_has_avx512_bf16()) {
+            features.push_back({ "AVX512_BF16", "1" });
+        }
+        if (ggml_cpu_has_amx_int8()) {
+            features.push_back({ "AMX_INT8", "1" });
+        }
+        if (ggml_cpu_has_neon()) {
+            features.push_back({ "NEON", "1" });
+        }
+        if (ggml_cpu_has_arm_fma()) {
+            features.push_back({ "ARM_FMA", "1" });
+        }
+        if (ggml_cpu_has_fp16_va()) {
+            features.push_back({ "FP16_VA", "1" });
+        }
+        if (ggml_cpu_has_matmul_int8()) {
+            features.push_back({ "MATMUL_INT8", "1" });
+        }
+        if (ggml_cpu_has_sve()) {
+            features.push_back({ "SVE", "1" });
+        }
+        if (ggml_cpu_get_sve_cnt() > 0) {
+            static std::string sve_cnt = std::to_string(ggml_cpu_get_sve_cnt());
+            features.push_back({ "SVE_CNT", sve_cnt.c_str() });
+        }
+        if (ggml_cpu_has_riscv_v()) {
+            features.push_back({ "RISCV_V", "1" });
+        }
+        if (ggml_cpu_has_vsx()) {
+            features.push_back({ "VSX", "1" });
+        }
+        if (ggml_cpu_has_wasm_simd()) {
+            features.push_back({ "WASM_SIMD", "1" });
+        }
+        if (ggml_cpu_has_llamafile()) {
+            features.push_back({ "LLAMAFILE", "1" });
+        }
+        // TODO: rename this
+    #ifdef GGML_USE_CPU_AARCH64
+        features.push_back({ "AARCH64_REPACK", "1" });
+    #endif
+
+        features.push_back({ nullptr, nullptr });
+
+        return features;
+    }();
+
+    return features.data();
+
+    GGML_UNUSED(reg);
+}
+
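+// Resolve optional backend entry points by name; returns NULL for unknown symbols.
+// This is how dynamically loading applications discover features such as
+// ggml_backend_set_n_threads without linking against the CPU backend directly.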
+static void * ggml_backend_cpu_get_proc_address(ggml_backend_reg_t reg, const char * name) {
+    if (strcmp(name, "ggml_backend_set_n_threads") == 0) {
+        return (void *)ggml_backend_cpu_set_n_threads;
+    }
+    if (strcmp(name, "ggml_backend_dev_get_extra_bufts") == 0) {
+        return (void *)ggml_backend_cpu_get_extra_bufts;
+    }
+    if (strcmp(name, "ggml_backend_get_features") == 0) {
+        return (void *)ggml_backend_cpu_get_features;
+    }
+    if (strcmp(name, "ggml_backend_set_abort_callback") == 0) {
+        return (void *)ggml_backend_cpu_set_abort_callback;
+    }
+    if (strcmp(name, "ggml_backend_cpu_numa_init") == 0) {
+        return (void *)ggml_numa_init;
+    }
+    if (strcmp(name, "ggml_backend_cpu_is_numa") == 0) {
+        return (void *)ggml_is_numa;
+    }
+
+    // threadpool - TODO: move to ggml-base
+    if (strcmp(name, "ggml_threadpool_new") == 0) {
+        return (void *)ggml_threadpool_new;
+    }
+    if (strcmp(name, "ggml_threadpool_free") == 0) {
+        return (void *)ggml_threadpool_free;
+    }
+    if (strcmp(name, "ggml_backend_cpu_set_threadpool") == 0) {
+        return (void *)ggml_backend_cpu_set_threadpool;
+    }
+
+    return NULL;
+
+    GGML_UNUSED(reg);
+}
+
+static const struct ggml_backend_reg_i ggml_backend_cpu_reg_i = {
+    /* .get_name         = */ ggml_backend_cpu_reg_get_name,
+    /* .get_device_count = */ ggml_backend_cpu_reg_get_device_count,
+    /* .get_device       = */ ggml_backend_cpu_reg_get_device,
+    /* .get_proc_address = */ ggml_backend_cpu_get_proc_address,
+};
+
+ggml_backend_reg_t ggml_backend_cpu_reg(void) {
+    // init CPU feature detection
+    ggml_cpu_init();
+
+    static struct ggml_backend_reg ggml_backend_cpu_reg = {
+        /* .api_version = */ GGML_BACKEND_API_VERSION,
+        /* .iface       = */ ggml_backend_cpu_reg_i,
+        /* .context     = */ NULL,
+    };
+
+    return &ggml_backend_cpu_reg;
+}
+
+GGML_BACKEND_DL_IMPL(ggml_backend_cpu_reg)

+ 178 - 0
llama/ggml-cpu.h

@@ -0,0 +1,178 @@
+/**
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include "ggml.h"
+#include "ggml-backend.h"
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+    // the compute plan that needs to be prepared for ggml_graph_compute()
+    // since https://github.com/ggerganov/ggml/issues/287
+    struct ggml_cplan {
+        size_t    work_size; // size of work buffer, calculated by `ggml_graph_plan()`
+        uint8_t * work_data; // work buffer, to be allocated by caller before calling to `ggml_graph_compute()`
+
+        int n_threads;
+        struct ggml_threadpool * threadpool;
+
+        // abort ggml_graph_compute when true
+        ggml_abort_callback abort_callback;
+        void *              abort_callback_data;
+    };
+
+    // numa strategies
+    enum ggml_numa_strategy {
+        GGML_NUMA_STRATEGY_DISABLED   = 0,
+        GGML_NUMA_STRATEGY_DISTRIBUTE = 1,
+        GGML_NUMA_STRATEGY_ISOLATE    = 2,
+        GGML_NUMA_STRATEGY_NUMACTL    = 3,
+        GGML_NUMA_STRATEGY_MIRROR     = 4,
+        GGML_NUMA_STRATEGY_COUNT
+    };
+
+    GGML_BACKEND_API void    ggml_numa_init(enum ggml_numa_strategy numa); // call once for better performance on NUMA systems
+    GGML_BACKEND_API bool    ggml_is_numa(void); // true if init detected that system has >1 NUMA node
+
+    GGML_BACKEND_API struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value);
+    GGML_BACKEND_API struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value);
+
+    GGML_BACKEND_API struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value);
+    GGML_BACKEND_API struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);
+
+    GGML_BACKEND_API int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i);
+    GGML_BACKEND_API void    ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value);
+
+    GGML_BACKEND_API int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3);
+    GGML_BACKEND_API void    ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value);
+
+    GGML_BACKEND_API float   ggml_get_f32_1d(const struct ggml_tensor * tensor, int i);
+    GGML_BACKEND_API void    ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value);
+
+    GGML_BACKEND_API float   ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3);
+    GGML_BACKEND_API void    ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value);
+
+    GGML_BACKEND_API struct ggml_threadpool *      ggml_threadpool_new           (struct ggml_threadpool_params  * params);
+    GGML_BACKEND_API void                          ggml_threadpool_free          (struct ggml_threadpool * threadpool);
+    GGML_BACKEND_API int                           ggml_threadpool_get_n_threads (struct ggml_threadpool * threadpool);
+    GGML_BACKEND_API void                          ggml_threadpool_pause         (struct ggml_threadpool * threadpool);
+    GGML_BACKEND_API void                          ggml_threadpool_resume        (struct ggml_threadpool * threadpool);
+
+    // ggml_graph_plan() has to be called before ggml_graph_compute()
+    // when plan.work_size > 0, caller must allocate memory for plan.work_data
+    GGML_BACKEND_API struct ggml_cplan ggml_graph_plan(
+                  const struct ggml_cgraph * cgraph,
+                                       int   n_threads, /* = GGML_DEFAULT_N_THREADS */
+                    struct ggml_threadpool * threadpool /* = NULL */ );
+    GGML_BACKEND_API enum ggml_status  ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
+
+    // same as ggml_graph_compute() but the work data is allocated as a part of the context
+    // note: the drawback of this API is that you must ensure that the context has enough memory for the work data
+    GGML_BACKEND_API enum ggml_status  ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads);
+
+    //
+    // system info
+    //
+
+    // x86
+    GGML_BACKEND_API int ggml_cpu_has_sse3       (void);
+    GGML_BACKEND_API int ggml_cpu_has_ssse3      (void);
+    GGML_BACKEND_API int ggml_cpu_has_avx        (void);
+    GGML_BACKEND_API int ggml_cpu_has_avx_vnni   (void);
+    GGML_BACKEND_API int ggml_cpu_has_avx2       (void);
+    GGML_BACKEND_API int ggml_cpu_has_f16c       (void);
+    GGML_BACKEND_API int ggml_cpu_has_fma        (void);
+    GGML_BACKEND_API int ggml_cpu_has_avx512     (void);
+    GGML_BACKEND_API int ggml_cpu_has_avx512_vbmi(void);
+    GGML_BACKEND_API int ggml_cpu_has_avx512_vnni(void);
+    GGML_BACKEND_API int ggml_cpu_has_avx512_bf16(void);
+    GGML_BACKEND_API int ggml_cpu_has_amx_int8   (void);
+    // ARM
+    GGML_BACKEND_API int ggml_cpu_has_neon       (void);
+    GGML_BACKEND_API int ggml_cpu_has_arm_fma    (void);
+    GGML_BACKEND_API int ggml_cpu_has_fp16_va    (void);
+    GGML_BACKEND_API int ggml_cpu_has_dotprod    (void);
+    GGML_BACKEND_API int ggml_cpu_has_matmul_int8(void);
+    GGML_BACKEND_API int ggml_cpu_has_sve        (void);
+    GGML_BACKEND_API int ggml_cpu_get_sve_cnt    (void);  // sve vector length in bytes
+    // other
+    GGML_BACKEND_API int ggml_cpu_has_riscv_v    (void);
+    GGML_BACKEND_API int ggml_cpu_has_vsx        (void);
+    GGML_BACKEND_API int ggml_cpu_has_wasm_simd  (void);
+    GGML_BACKEND_API int ggml_cpu_has_llamafile  (void);
+
+    // Internal types and functions exposed for tests and benchmarks
+
+    typedef void (*ggml_from_float_to_mat_t)
+                                     (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nr, int64_t k, int64_t bs);
+    typedef void (*ggml_vec_dot_t)  (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, size_t bx,
+                                       const void * GGML_RESTRICT y, size_t by, int nrc);
+    typedef void (*ggml_gemv_t)     (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x,
+                                       const void * GGML_RESTRICT y, int nr, int nc);
+    typedef void (*ggml_gemm_t)     (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x,
+                                       const void * GGML_RESTRICT y, int nr, int nc);
+
+    struct ggml_type_traits_cpu {
+        ggml_from_float_t        from_float;
+        ggml_from_float_to_mat_t from_float_to_mat;
+        ggml_vec_dot_t           vec_dot;
+        enum ggml_type           vec_dot_type;
+        int64_t                  nrows; // number of rows to process simultaneously
+        int64_t                  ncols; // number of columns to process simultaneously
+        ggml_gemv_t              gemv;
+        ggml_gemm_t              gemm;
+    };
+
+    GGML_BACKEND_API const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type);
+
+    GGML_BACKEND_API void ggml_cpu_init(void);
+
+    //
+    // CPU backend
+    //
+
+    GGML_BACKEND_API ggml_backend_t ggml_backend_cpu_init(void);
+
+    GGML_BACKEND_API bool ggml_backend_is_cpu                (ggml_backend_t backend);
+    GGML_BACKEND_API void ggml_backend_cpu_set_n_threads     (ggml_backend_t backend_cpu, int n_threads);
+    GGML_BACKEND_API void ggml_backend_cpu_set_threadpool    (ggml_backend_t backend_cpu, ggml_threadpool_t threadpool);
+    GGML_BACKEND_API void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data);
+
+    GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void);
+
+#ifdef GGML_USE_CPU_HBM
+    GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void);
+#endif
+
+    GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cpu_aarch64_buffer_type(void);
+    GGML_BACKEND_API bool ggml_backend_cpu_buft_is_aarch64(ggml_backend_buffer_type_t buft);
+
+#ifdef __cplusplus
+}
+#endif
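
For reference, the plan/compute contract documented in this header looks like the
following in host code. This is a minimal sketch, assuming graph was already built
in a ggml context; the function name run_graph is illustrative:

    #include <stdint.h>
    #include <stdlib.h>
    #include "ggml.h"
    #include "ggml-cpu.h"

    static enum ggml_status run_graph(struct ggml_cgraph * graph, int n_threads) {
        // ggml_graph_plan() computes work_size; work_data is left to the caller
        struct ggml_cplan plan = ggml_graph_plan(graph, n_threads, /*threadpool =*/ NULL);

        uint8_t * work = NULL;
        if (plan.work_size > 0) {
            work = (uint8_t *) malloc(plan.work_size); // caller-owned work buffer
            plan.work_data = work;
        }

        enum ggml_status status = ggml_graph_compute(graph, &plan);

        free(work);
        return status;
    }

ggml_graph_compute_with_ctx() performs the same allocation inside the context, at
the cost of having to size the context for the work data up front.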

+ 17 - 19
llama/ggml-cuda.h

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *
@@ -29,7 +29,11 @@
 #include "ggml.h"
 #include "ggml-backend.h"
 
-#ifdef GGML_USE_HIPBLAS
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+#ifdef GGML_USE_HIP
 #define GGML_CUDA_NAME "ROCm"
 #define GGML_CUBLAS_NAME "hipBLAS"
 #elif defined(GGML_USE_MUSA)
@@ -39,37 +43,31 @@
 #define GGML_CUDA_NAME "CUDA"
 #define GGML_CUBLAS_NAME "cuBLAS"
 #endif
-
-#ifdef  __cplusplus
-extern "C" {
-#endif
-
 #define GGML_CUDA_MAX_DEVICES       16
 
 // backend API
-GGML_API GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device);
+GGML_BACKEND_API ggml_backend_t ggml_backend_cuda_init(int device);
 
-GGML_API GGML_CALL bool ggml_backend_is_cuda(ggml_backend_t backend);
+GGML_BACKEND_API bool ggml_backend_is_cuda(ggml_backend_t backend);
 
 // device buffer
-GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device);
+GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device);
 
 // split tensor buffer that splits matrices by rows across multiple devices
-GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(const float * tensor_split);
+GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(int main_device, const float * tensor_split);
 
 // pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
-GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void);
+GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void);
 
-GGML_API GGML_CALL int ggml_backend_cuda_reg_devices();
+GGML_BACKEND_API int  ggml_backend_cuda_get_device_count(void);
+GGML_BACKEND_API void ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size);
+GGML_BACKEND_API void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total);
 
-GGML_API GGML_CALL int  ggml_backend_cuda_get_device_count(void);
-GGML_API GGML_CALL void ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size);
-GGML_API GGML_CALL void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total);
+GGML_BACKEND_API bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size);
+GGML_BACKEND_API void ggml_backend_cuda_unregister_host_buffer(void * buffer);
 
-GGML_API GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size);
-GGML_API GGML_CALL void ggml_backend_cuda_unregister_host_buffer(void * buffer);
+GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cuda_reg(void);
 
-GGML_API void ggml_backend_cuda_log_set_callback(ggml_log_callback log_callback, void * user_data);
 #ifdef  __cplusplus
 }
 #endif
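
Apart from the GGML_BACKEND_API rename and the new main_device parameter on
ggml_backend_cuda_split_buffer_type(), these entry points are used as before. A
minimal host-side sketch of device enumeration and backend creation, with error
handling elided:

    #include <stdio.h>
    #include "ggml-cuda.h"

    int main(void) {
        const int n_dev = ggml_backend_cuda_get_device_count();
        for (int i = 0; i < n_dev; i++) {
            char desc[128];
            size_t free_mem, total_mem;
            ggml_backend_cuda_get_device_description(i, desc, sizeof(desc));
            ggml_backend_cuda_get_device_memory(i, &free_mem, &total_mem);
            printf("device %d: %s (%zu of %zu bytes free)\n", i, desc, free_mem, total_mem);
        }

        ggml_backend_t backend = ggml_backend_cuda_init(0); // device 0
        // ... build a graph and run it with ggml_backend_graph_compute() ...
        ggml_backend_free(backend);
        return 0;
    }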

+ 1 - 1
llama/ggml-cuda/acc.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/acc.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/arange.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/arange.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

+ 117 - 0
llama/ggml-cuda/argmax.cu

@@ -0,0 +1,117 @@
+/**
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include <algorithm>
+#include <cstdint>
+
+#include "argmax.cuh"
+#include "common.cuh"
+#include "sum.cuh"
+
+static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __restrict__ dst, const int64_t ncols) {
+    const int64_t row = blockIdx.x;
+
+    float maxval = -FLT_MAX;
+    int   argmax = -1;
+    const float * rowx = x + row * ncols;
+
+    for (int32_t col = threadIdx.x; col < ncols; col += blockDim.x) {
+        const float val = rowx[col];
+        if (val > maxval) {
+            maxval = val;
+            argmax = col;
+        }
+    }
+
+#pragma unroll
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
+        const int   col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
+        if (val > maxval) {
+            maxval = val;
+            argmax = col;
+        }
+    }
+
+    const int n_warps = blockDim.x / WARP_SIZE;
+    const int lane_id = threadIdx.x % WARP_SIZE;
+    const int warp_id = threadIdx.x / WARP_SIZE;
+    if (n_warps > 1) {
+        constexpr int    max_warps = 1024 / WARP_SIZE;
+        __shared__ float shared_maxval[max_warps];
+        __shared__ int   shared_argmax[max_warps];
+        if (lane_id == 0) {
+            shared_maxval[warp_id] = maxval;
+            shared_argmax[warp_id] = argmax;
+        }
+
+        __syncthreads();
+
+        if (warp_id == 0) {
+            if (lane_id < n_warps) {
+                maxval = shared_maxval[lane_id];
+                argmax = shared_argmax[lane_id];
+            }
+#pragma unroll
+            for (int offset = 16; offset > 0; offset >>= 1) {
+                const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
+                const int   col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
+                if (val > maxval) {
+                    maxval = val;
+                    argmax = col;
+                }
+            }
+        }
+    }
+
+    if (warp_id == 0 && lane_id == 0) {
+        dst[row] = argmax;
+    }
+}
+
+void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_I32);
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    const int64_t ne00  = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    const float * src0_d = (const float *) src0->data;
+    int32_t     * dst_d  = (int32_t     *) dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    const int64_t num_blocks = nrows;
+    const int64_t num_threads = std::min<int64_t>(1024, (ne00 + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE);
+    const dim3 blocks_dim(num_threads, 1, 1);
+    const dim3 blocks_num(num_blocks, 1, 1);
+
+    argmax_f32<<<blocks_num, blocks_dim, 0, stream>>>(src0_d, dst_d, ne00);
+}
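
The launch geometry uses one block per row and rounds the column count up to a
whole number of 32-lane warps, capped at the 1024-thread block limit; the shuffle
loop then reduces within each warp and the shared-memory pass combines at most
1024/32 = 32 warp results. A small sketch of just the block-size computation, with
an illustrative helper name:

    #include <algorithm>
    #include <cstdint>

    // round ncols up to a multiple of the 32-lane warp, capped at 1024 threads,
    // mirroring the num_threads computation in ggml_cuda_argmax() above
    static int64_t argmax_block_size(int64_t ncols) {
        const int64_t warp_size = 32;
        return std::min<int64_t>(1024, (ncols + warp_size - 1) / warp_size * warp_size);
    }

    // argmax_block_size(100)  == 128
    // argmax_block_size(5000) == 1024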

+ 29 - 0
llama/ggml-cuda/argmax.cuh

@@ -0,0 +1,29 @@
+/**
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

+ 1 - 1
llama/ggml-cuda/argsort.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/argsort.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/binbcast.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/binbcast.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/clamp.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/clamp.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

+ 63 - 40
llama/ggml-cuda/common.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *
@@ -32,7 +32,7 @@
 #include <cstdint>
 #include <memory>
 
-#if defined(GGML_USE_HIPBLAS)
+#if defined(GGML_USE_HIP)
 #define GGML_COMMON_DECL_HIP
 #define GGML_COMMON_IMPL_HIP
 #else
@@ -52,13 +52,13 @@
 #include <string>
 #include <vector>
 
-#if defined(GGML_USE_HIPBLAS)
+#if defined(GGML_USE_HIP)
 #include "vendors/hip.h"
 #elif defined(GGML_USE_MUSA)
 #include "vendors/musa.h"
 #else
 #include "vendors/cuda.h"
-#endif // defined(GGML_USE_HIPBLAS)
+#endif // defined(GGML_USE_HIP)
 
 #define STRINGIZE_IMPL(...) #__VA_ARGS__
 #define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__)
@@ -73,9 +73,20 @@
 #define CC_TURING     750
 #define CC_AMPERE     800
 #define CC_OFFSET_AMD 1000000
-#define CC_RDNA1      (CC_OFFSET_AMD + 1010)
-#define CC_RDNA2      (CC_OFFSET_AMD + 1030)
-#define CC_RDNA3      (CC_OFFSET_AMD + 1100)
+
+// GCN/CDNA, wave size is 64
+#define CC_GCN4       (CC_OFFSET_AMD + 803)  // Tonga, Fiji, Polaris, minimum for fast fp16
+#define CC_VEGA       (CC_OFFSET_AMD + 900)  // Vega56/64, minimum for fp16 dual issue
+#define CC_VEGA20     (CC_OFFSET_AMD + 906)  // MI50/Radeon VII, minimum for dp4a
+#define CC_CDNA       (CC_OFFSET_AMD + 908)  // MI100, minimum for MFMA, acc registers
+#define CC_CDNA2      (CC_OFFSET_AMD + 910)  // MI210, minimum for acc register renaming
+#define CC_CDNA3      (CC_OFFSET_AMD + 942)  // MI300
+
+// RDNA removes MFMA, dp4a, xnack, acc registers; wave size is 32
+#define CC_RDNA1      (CC_OFFSET_AMD + 1010) // RX 5000
+#define CC_RDNA2      (CC_OFFSET_AMD + 1030) // RX 6000, minimum for dp4a
+#define CC_RDNA3      (CC_OFFSET_AMD + 1100) // RX 7000, minimum for WMMA
+
 #define CC_QY1        210
 #define CC_QY2        220
 
@@ -123,7 +134,7 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in
 
 #define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublas_get_error_str)
 
-#if !defined(GGML_USE_HIPBLAS)
+#if !defined(GGML_USE_HIP)
 static const char * cu_get_error_str(CUresult err) {
     const char * err_str;
     cuGetErrorString(err, &err_str);
@@ -146,21 +157,21 @@ typedef float dfloat; // dequantize float
 typedef float2 dfloat2;
 #endif // GGML_CUDA_F16
 
-#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
+#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
 #define FP16_AVAILABLE
-#endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
+#endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
 
 #if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
 #define FAST_FP16_AVAILABLE
 #endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
 
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
 #define FP16_MMA_AVAILABLE
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
 
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
 #define INT8_MMA_AVAILABLE
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
 
 #if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= CC_QY1)
 #define FLASH_ATTN_AVAILABLE
@@ -182,14 +193,14 @@ static constexpr bool int8_mma_available(const int cc) {
 static __device__ void no_device_code(
     const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
 
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
     printf("%s:%d: ERROR: HIP kernel %s has no device code compatible with HIP arch %d.\n",
            file_name, line, function_name, arch);
     GGML_UNUSED(arch_list);
 #else
     printf("%s:%d: ERROR: CUDA kernel %s has no device code compatible with CUDA arch %d. ggml-cuda.cu was compiled for: %s\n",
            file_name, line, function_name, arch, arch_list);
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
     __trap();
 
     GGML_UNUSED(no_device_code); // suppress unused function warning
@@ -201,19 +212,31 @@ static __device__ void no_device_code(
 #define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.")
 #endif // __CUDA_ARCH__
 
+static __device__ __forceinline__ int warp_reduce_sum(int x) {
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE
+    return __reduce_add_sync(0xffffffff, x);
+#else
+#pragma unroll
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, offset, 32);
+    }
+    return x;
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE
+}
+
 static __device__ __forceinline__ float warp_reduce_sum(float x) {
 #pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, offset, 32);
     }
     return x;
 }
 
 static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
 #pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
-        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        a.x += __shfl_xor_sync(0xffffffff, a.x, offset, 32);
+        a.y += __shfl_xor_sync(0xffffffff, a.y, offset, 32);
     }
     return a;
 }
@@ -221,21 +244,21 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
 static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 #ifdef FP16_AVAILABLE
 
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 #pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        const half2 a_other = __shfl_xor_sync(0xffffffff, a, mask, 32);
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        const half2 a_other = __shfl_xor_sync(0xffffffff, a, offset, 32);
         reinterpret_cast<half&>(a.x) +=  __low2half(a_other);
         reinterpret_cast<half&>(a.y) += __high2half(a_other);
     }
     return a;
 #else
 #pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, 32));
     }
     return a;
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 
 #else
     NO_DEVICE_CODE;
@@ -245,8 +268,8 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 
 static __device__ __forceinline__ float warp_reduce_max(float x) {
 #pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
     }
     return x;
 }
@@ -254,11 +277,11 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
 static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
 #ifdef FP16_AVAILABLE
 
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
     return __float2half(fmaxf(__half2float(a), __half2float(b)));
 #else
     return __hmax(a, b);
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
 
 #else
    NO_DEVICE_CODE;
@@ -268,7 +291,7 @@ static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b
 }
 
 static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 
 #if CUDART_VERSION >= CUDART_HMAX
     return __hmax2(a, b);
@@ -283,20 +306,20 @@ static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const hal
     GGML_UNUSED(a);
     GGML_UNUSED(b);
     NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 }
 
 static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
 #pragma unroll
-   for (int mask = 16; mask > 0; mask >>= 1) {
-       x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+   for (int offset = 16; offset > 0; offset >>= 1) {
+       x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
    }
    return x;
 #else
    GGML_UNUSED(x);
    NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
 }
 
 #if CUDART_VERSION < CUDART_HMASK
@@ -308,7 +331,7 @@ static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half
 #endif // CUDART_VERSION < CUDART_HMASK
 
 static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 #if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(RDNA2)
     c = __builtin_amdgcn_sdot4(a, b, c, false);
 #elif defined(RDNA3)
@@ -334,7 +357,7 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
 #endif
     return c;
 
-#else // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#else // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 
 #if __CUDA_ARCH__ >= MIN_CC_DP4A
     return __dp4a(a, b, c);
@@ -344,7 +367,7 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
     return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3];
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 }
 
 // TODO: move to ggml-common.h
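
The renamed loop variable makes the reduction pattern easier to read: with offsets
16, 8, 4, 2, 1, each __shfl_xor_sync() pairs every lane with the lane whose index
differs in exactly one bit, so after five butterfly steps all 32 lanes hold the
full warp sum (or maximum). A standalone CUDA sketch of the pattern, as a toy test
kernel that is not part of the vendored code:

    #include <cstdio>

    __global__ void warp_sum_demo(int * out) {
        int x = threadIdx.x; // lane i contributes i, so the total is 0 + 1 + ... + 31 = 496
    #pragma unroll
        for (int offset = 16; offset > 0; offset >>= 1) {
            x += __shfl_xor_sync(0xffffffff, x, offset, 32);
        }
        if (threadIdx.x == 0) {
            *out = x; // every lane now holds 496; lane 0 writes it back
        }
    }

    int main() {
        int * d_out = nullptr;
        int   h_out = 0;
        cudaMalloc(&d_out, sizeof(int));
        warp_sum_demo<<<1, 32>>>(d_out);
        cudaMemcpy(&h_out, d_out, sizeof(int), cudaMemcpyDeviceToHost);
        printf("warp sum = %d\n", h_out); // prints 496
        cudaFree(d_out);
        return 0;
    }

On Ampere and newer the new int overload maps directly onto the
__reduce_add_sync() intrinsic, which performs the same reduction in hardware.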

+ 1 - 1
llama/ggml-cuda/concat.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/concat.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/conv-transpose-1d.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/conv-transpose-1d.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/convert.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/convert.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

+ 90 - 0
llama/ggml-cuda/count-equal.cu

@@ -0,0 +1,90 @@
+/**
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+#include "count-equal.cuh"
+
+#include <cstdint>
+
+template <typename T>
+static __global__ void count_equal(const T * __restrict__ x, const T * __restrict__ y, int64_t * __restrict__ dst, const int64_t dk, const int64_t k) {
+    const int64_t i0 = (int64_t) blockIdx.x*dk;
+    const int64_t i1 = min(i0 + dk, k);
+
+    int nequal = 0;
+
+    for (int64_t i = i0 + threadIdx.x; i < i1; i += WARP_SIZE) {
+        const T xi = x[i];
+        const T yi = y[i];
+        nequal += xi == yi;
+    }
+
+    nequal = warp_reduce_sum(nequal);
+
+    if (threadIdx.x != 0) {
+        return;
+    }
+
+    atomicAdd((int *) dst, nequal);
+}
+
+void ggml_cuda_count_equal(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src0->type == src1->type);
+    GGML_ASSERT( dst->type == GGML_TYPE_I64);
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src1));
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+
+    int64_t * dst_d  = (int64_t *) dst->data;
+
+    cudaStream_t stream = ctx.stream();
+    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
+
+    const int64_t ne = ggml_nelements(src0);
+    GGML_ASSERT(ne < (1 << 30) && "atomicAdd implementation only supports int");
+    const int64_t dne = GGML_PAD((ne + 4*nsm - 1) / (4*nsm), CUDA_COUNT_EQUAL_CHUNK_SIZE);
+
+    CUDA_CHECK(cudaMemsetAsync(dst_d, 0, ggml_nbytes(dst), stream));
+
+    const dim3 blocks_dim(WARP_SIZE, 1, 1);
+    const dim3 blocks_num(std::min((int64_t)4*nsm, (ne + CUDA_COUNT_EQUAL_CHUNK_SIZE - 1)/CUDA_COUNT_EQUAL_CHUNK_SIZE), 1, 1);
+
+    switch (src0->type) {
+        case GGML_TYPE_I32: {
+            const int * src0_d = (const int *) src0->data;
+            const int * src1_d = (const int *) src1->data;
+            count_equal<<<blocks_num, blocks_dim, 0, stream>>>(src0_d, src1_d, dst_d, dne, ne);
+        } break;
+        default:
+            GGML_ASSERT(false);
+            break;
+    }
+}
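
The grid math above targets roughly four single-warp blocks per streaming
multiprocessor: each block scans a 128-aligned slice of dne elements and adds its
warp-reduced count with a single atomicAdd. A small sketch of the same arithmetic
under assumed hardware numbers (nsm = 80, ne = 1000000), with pad() standing in
for GGML_PAD:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    // round x up to a multiple of n, like GGML_PAD
    static int64_t pad(int64_t x, int64_t n) { return (x + n - 1) / n * n; }

    int main() {
        const int64_t ne    = 1000000; // elements to compare (assumed)
        const int64_t nsm   = 80;      // SM count of the device (assumed)
        const int64_t chunk = 128;     // CUDA_COUNT_EQUAL_CHUNK_SIZE

        const int64_t dne    = pad((ne + 4*nsm - 1) / (4*nsm), chunk);
        const int64_t blocks = std::min<int64_t>(4*nsm, (ne + chunk - 1) / chunk);
        printf("slice per block = %lld elements, grid = %lld blocks\n",
               (long long) dne, (long long) blocks);
        // -> slice per block = 3200 elements, grid = 320 blocks
        return 0;
    }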

+ 31 - 0
llama/ggml-cuda/count-equal.cuh

@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+#define CUDA_COUNT_EQUAL_CHUNK_SIZE 128
+
+void ggml_cuda_count_equal(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

+ 1 - 1
llama/ggml-cuda/cpy.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

+ 2 - 2
llama/ggml-cuda/cpy.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *
@@ -26,7 +26,7 @@
 
 #include "common.cuh"
 
-#define CUDA_CPY_BLOCK_SIZE 32
+#define CUDA_CPY_BLOCK_SIZE 64
 
 void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1);
 

+ 1 - 1
llama/ggml-cuda/cross-entropy-loss.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/cross-entropy-loss.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/dequantize.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/diagmask.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/diagmask.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

+ 0 - 709
llama/ggml-cuda/dmmv.cu

@@ -1,709 +0,0 @@
-/**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
- *
- * MIT License
- *
- * Copyright (c) 2023-2024 The ggml authors
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to deal
- * in the Software without restriction, including without limitation the rights
- * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- * copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#include "dmmv.cuh"
-#include "dequantize.cuh"
-#include "convert.cuh"
-
-#ifndef K_QUANTS_PER_ITERATION
-#define K_QUANTS_PER_ITERATION 2
-#else
-static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
-#endif
-
-static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
-
-    static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
-
-    const int row = blockIdx.x*blockDim.y + threadIdx.y;
-    if (row > nrows) return;
-
-    const int num_blocks_per_row = ncols / QK_K;
-    const int ib0 = row*num_blocks_per_row;
-
-    const block_q2_K * x = (const block_q2_K *)vx + ib0;
-
-    float tmp = 0; // partial sum for thread in warp
-
-    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...15
-    const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1
-
-    const int step = 16/K_QUANTS_PER_ITERATION;
-
-    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
-    const int in = tid - step*im;                        // 0...15 or 0...7
-
-    const int l0 = K_QUANTS_PER_ITERATION*in;            // 0...15 or 0...14 in steps of 2
-    const int q_offset = 32*im + l0;
-    const int s_offset = 8*im;
-    const int y_offset = 128*im + l0;
-
-    uint32_t aux[4];
-    const uint8_t * d = (const uint8_t *)aux;
-    const uint8_t * m = (const uint8_t *)(aux + 2);
-
-    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
-
-        const float   * y = yy + i * QK_K + y_offset;
-        const uint8_t * q = x[i].qs + q_offset;
-
-        const float dall = __low2half(x[i].dm);
-        const float dmin = __high2half(x[i].dm);
-
-        const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset);
-        aux[0] = a[0] & 0x0f0f0f0f;
-        aux[1] = a[1] & 0x0f0f0f0f;
-        aux[2] = (a[0] >> 4) & 0x0f0f0f0f;
-        aux[3] = (a[1] >> 4) & 0x0f0f0f0f;
-
-        float sum1 = 0, sum2 = 0;
-        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
-            sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3)
-                  + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3)
-                  + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3)
-                  + y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3)
-                  + y[l+16] * d[1] * ((q[l+16] >> 0) & 3)
-                  + y[l+48] * d[3] * ((q[l+16] >> 2) & 3)
-                  + y[l+80] * d[5] * ((q[l+16] >> 4) & 3)
-                  +y[l+112] * d[7] * ((q[l+16] >> 6) & 3);
-            sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[ l+96] * m[6]
-                  + y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7];
-
-        }
-        tmp += dall * sum1 - dmin * sum2;
-
-    }
-
-    // sum up partial sums and write back result
-    tmp = warp_reduce_sum(tmp);
-
-    if (threadIdx.x == 0) {
-        dst[row] = tmp;
-    }
-}
-
-static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
-
-    const int row = blockIdx.x*blockDim.y + threadIdx.y;
-    if (row > nrows) return;
-
-    const int num_blocks_per_row = ncols / QK_K;
-    const int ib0 = row*num_blocks_per_row;
-
-    const block_q3_K * x = (const block_q3_K *)vx + ib0;
-
-    float tmp = 0; // partial sum for thread in warp
-
-    const uint16_t kmask1 = 0x0303;
-    const uint16_t kmask2 = 0x0f0f;
-
-    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
-    const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1
-
-    const int n  = K_QUANTS_PER_ITERATION;               // iterations in the inner loop
-    const int step = 16/K_QUANTS_PER_ITERATION;
-    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
-    const int in = tid - step*im;                        // 0....15 or 0...7
-
-    const uint8_t m = 1 << (4*im);
-
-    const int l0 = n*in;                                 // 0...15 or 0...14 in steps of 2
-    const int q_offset =  32*im + l0;
-    const int y_offset = 128*im + l0;
-
-    uint16_t utmp[4];
-    const int8_t * s = (const int8_t *)utmp;
-
-    const uint16_t s_shift = 4*im;
-
-    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
-
-        const float   * y  = yy + i * QK_K + y_offset;
-        const uint8_t * q = x[i].qs + q_offset;
-        const uint8_t * h = x[i].hmask + l0;
-
-        const uint16_t * a = (const uint16_t *)x[i].scales;
-        utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4);
-        utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4);
-        utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4);
-        utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4);
-
-        const float d = x[i].d;
-
-        float sum = 0;
-        for (int l = 0; l < n; ++l) {
-            sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4))
-                 + y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4))
-                 + y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4))
-                 + y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4));
-            sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4))
-                 + y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4))
-                 + y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4))
-                + y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4));
-        }
-        tmp += d * sum;
-
-    }
-
-    // sum up partial sums and write back result
-    tmp = warp_reduce_sum(tmp);
-
-    if (threadIdx.x == 0) {
-        dst[row] = tmp;
-    }
-}
-
-static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
-
-    const int row = blockIdx.x*blockDim.y + threadIdx.y;
-    if (row > nrows) return;
-    const int num_blocks_per_row = ncols / QK_K;
-    const int ib0 = row*num_blocks_per_row;
-
-    const block_q4_K * x = (const block_q4_K *)vx + ib0;
-
-    const uint16_t kmask1 = 0x3f3f;
-    const uint16_t kmask2 = 0x0f0f;
-    const uint16_t kmask3 = 0xc0c0;
-
-    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
-    const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1
-
-    const int step = 8/K_QUANTS_PER_ITERATION;           // 8 or 4
-
-    const int il  = tid/step;                            // 0...3
-    const int ir  = tid - step*il;                       // 0...7 or 0...3
-    const int n   = 2 * K_QUANTS_PER_ITERATION;          // 2 or 4
-
-    const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
-    const int in = il%2;
-
-    const int l0 = n*(2*ir + in);
-    const int q_offset = 32*im + l0;
-    const int y_offset = 64*im + l0;
-
-    uint16_t aux[4];
-    const uint8_t * sc = (const uint8_t *)aux;
-
-#if K_QUANTS_PER_ITERATION == 2
-    uint32_t q32[4];
-    const uint8_t * q4 = (const uint8_t *)q32;
-#else
-    uint16_t q16[4];
-    const uint8_t * q4 = (const uint8_t *)q16;
-#endif
-
-    float tmp = 0; // partial sum for thread in warp
-
-    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
-
-        const float   * y1 = yy + i*QK_K + y_offset;
-        const float   * y2 = y1 + 128;
-
-        const float dall = __low2half(x[i].dm);
-        const float dmin = __high2half(x[i].dm);
-
-        const uint16_t * a = (const uint16_t *)x[i].scales;
-        aux[0] = a[im+0] & kmask1;
-        aux[1] = a[im+2] & kmask1;
-        aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
-        aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
-
-#if K_QUANTS_PER_ITERATION == 2
-        const uint32_t * q1 = (const uint32_t *)(x[i].qs + q_offset);
-        const uint32_t * q2 = q1 + 16;
-
-        q32[0] = q1[0] & 0x0f0f0f0f;
-        q32[1] = q1[0] & 0xf0f0f0f0;
-        q32[2] = q2[0] & 0x0f0f0f0f;
-        q32[3] = q2[0] & 0xf0f0f0f0;
-
-        float4 s = {0.f, 0.f, 0.f, 0.f};
-        float smin = 0;
-        for (int l = 0; l < 4; ++l) {
-            s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+ 4];
-            s.z += y2[l] * q4[l+8]; s.w += y2[l+32] * q4[l+12];
-            smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
-        }
-        tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
-#else
-        const uint16_t * q1 = (const uint16_t *)(x[i].qs + q_offset);
-        const uint16_t * q2 = q1 + 32;
-
-        q16[0] = q1[0] & 0x0f0f;
-        q16[1] = q1[0] & 0xf0f0;
-        q16[2] = q2[0] & 0x0f0f;
-        q16[3] = q2[0] & 0xf0f0;
-
-        float4 s = {0.f, 0.f, 0.f, 0.f};
-        float smin = 0;
-        for (int l = 0; l < 2; ++l) {
-            s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+2];
-            s.z += y2[l] * q4[l+4]; s.w += y2[l+32] * q4[l+6];
-            smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
-        }
-        tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
-#endif
-
-    }
-
-    // sum up partial sums and write back result
-    tmp = warp_reduce_sum(tmp);
-
-    if (tid == 0) {
-        dst[row] = tmp;
-    }
-}
-
-static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols) {
-
-    const int row = blockIdx.x;
-    const int num_blocks_per_row = ncols / QK_K;
-    const int ib0 = row*num_blocks_per_row;
-
-    const block_q5_K * x = (const block_q5_K *)vx + ib0;
-
-    float tmp = 0; // partial sum for thread in warp
-
-    const uint16_t kmask1 = 0x3f3f;
-    const uint16_t kmask2 = 0x0f0f;
-    const uint16_t kmask3 = 0xc0c0;
-
-    const int tid = threadIdx.x/2;  // 0...15
-    const int ix  = threadIdx.x%2;
-
-    const int il  = tid/4;     // 0...3
-    const int ir  = tid - 4*il;// 0...3
-    const int n   = 2;
-
-    const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
-    const int in = il%2;
-
-    const int l0 = n*(2*ir + in);
-    const int q_offset = 32*im + l0;
-    const int y_offset = 64*im + l0;
-
-    const uint8_t hm1  = 1 << (2*im);
-    const uint8_t hm2  = hm1 << 4;
-
-    uint16_t aux[4];
-    const uint8_t * sc = (const uint8_t *)aux;
-
-    uint16_t q16[8];
-    const uint8_t * q4 = (const uint8_t *)q16;
-
-    for (int i = ix; i < num_blocks_per_row; i += 2) {
-
-        const uint8_t * ql1 = x[i].qs + q_offset;
-        const uint8_t * qh  = x[i].qh + l0;
-        const float   * y1  = yy + i*QK_K + y_offset;
-        const float   * y2  = y1 + 128;
-
-        const float dall = __low2half(x[i].dm);
-        const float dmin = __high2half(x[i].dm);
-
-        const uint16_t * a = (const uint16_t *)x[i].scales;
-        aux[0] = a[im+0] & kmask1;
-        aux[1] = a[im+2] & kmask1;
-        aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
-        aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
-
-        float4 sum = {0.f, 0.f, 0.f, 0.f};
-        float smin = 0;
-        const uint16_t * q1 = (const uint16_t *)ql1;
-        const uint16_t * q2 = q1 + 32;
-        q16[0] = q1[0] & 0x0f0f;
-        q16[1] = q1[8] & 0x0f0f;
-        q16[2] = (q1[0] >> 4) & 0x0f0f;
-        q16[3] = (q1[8] >> 4) & 0x0f0f;
-        q16[4] = q2[0] & 0x0f0f;
-        q16[5] = q2[8] & 0x0f0f;
-        q16[6] = (q2[0] >> 4) & 0x0f0f;
-        q16[7] = (q2[8] >> 4) & 0x0f0f;
-        for (int l = 0; l < n; ++l) {
-            sum.x += y1[l+ 0] * (q4[l +0] + (qh[l+ 0] & (hm1 << 0) ? 16 : 0))
-                   + y1[l+16] * (q4[l +2] + (qh[l+16] & (hm1 << 0) ? 16 : 0));
-            sum.y += y1[l+32] * (q4[l +4] + (qh[l+ 0] & (hm1 << 1) ? 16 : 0))
-                   + y1[l+48] * (q4[l +6] + (qh[l+16] & (hm1 << 1) ? 16 : 0));
-            sum.z += y2[l+ 0] * (q4[l +8] + (qh[l+ 0] & (hm2 << 0) ? 16 : 0))
-                   + y2[l+16] * (q4[l+10] + (qh[l+16] & (hm2 << 0) ? 16 : 0));
-            sum.w += y2[l+32] * (q4[l+12] + (qh[l+ 0] & (hm2 << 1) ? 16 : 0))
-                   + y2[l+48] * (q4[l+14] + (qh[l+16] & (hm2 << 1) ? 16 : 0));
-            smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3]
-                  + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
-        }
-        tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin;
-    }
-
-    // sum up partial sums and write back result
-    tmp = warp_reduce_sum(tmp);
-
-    if (threadIdx.x == 0) {
-        dst[row] = tmp;
-    }
-}
-
-static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
-
-    static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
-
-    const int row = blockIdx.x*blockDim.y + threadIdx.y;
-    if (row > nrows) return;
-
-    const int num_blocks_per_row = ncols / QK_K;
-    const int ib0 = row*num_blocks_per_row;
-
-    const block_q6_K * x = (const block_q6_K *)vx + ib0;
-
-    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
-    const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0, 1
-
-    const int step = 16/K_QUANTS_PER_ITERATION;          // 16 or 8
-
-    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
-    const int in = tid - step*im;                        // 0...15 or 0...7
-
-#if K_QUANTS_PER_ITERATION == 1
-    const int l0 = K_QUANTS_PER_ITERATION*in;            // 0...15
-    const int is = 0;
-#else
-    const int l0 = 4 * in;                               // 0, 4, 8, ..., 28
-    const int is = in / 4;
-#endif
-    const int ql_offset = 64*im + l0;
-    const int qh_offset = 32*im + l0;
-    const int s_offset  =  8*im + is;
-    const int y_offset = 128*im + l0;
-
-    float tmp = 0; // partial sum for thread in warp
-
-    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
-
-        const float   * y  = yy + i * QK_K + y_offset;
-        const uint8_t * ql = x[i].ql + ql_offset;
-        const uint8_t * qh = x[i].qh + qh_offset;
-        const int8_t  * s  = x[i].scales + s_offset;
-
-        const float d = x[i].d;
-
-#if K_QUANTS_PER_ITERATION == 1
-        float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
-                  + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
-                  + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
-                  + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32)
-                  + y[64] * s[4] * d * ((int8_t)((ql[ 0]  >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32)
-                  + y[80] * s[5] * d * ((int8_t)((ql[16]  >> 4) | ((qh[16] & 0x30) >> 0)) - 32)
-                  + y[96] * s[6] * d * ((int8_t)((ql[32]  >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
-                  +y[112] * s[7] * d * ((int8_t)((ql[48]  >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);
-        tmp += sum;
-#else
-        float sum = 0;
-        for (int l = 0; l < 4; ++l) {
-            sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
-                 + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32)
-                 + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32)
-                 + y[l+96] * s[6] * d * ((int8_t)((ql[l+32]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32);
-        }
-        tmp += sum;
-#endif
-
-    }
-
-    // sum up partial sums and write back result
-    tmp = warp_reduce_sum(tmp);
-
-    if (tid == 0) {
-        dst[row] = tmp;
-    }
-}
-
-static __device__ void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
-    const half * x = (const half *) vx;
-
-    // automatic half -> float type cast if dfloat == float
-    v.x = x[ib + iqs + 0];
-    v.y = x[ib + iqs + 1];
-}
-
-static constexpr __device__ dequantize_kernel_t get_dequantize_kernel(ggml_type type) {
-    return type == GGML_TYPE_Q4_0 ? dequantize_q4_0 :
-        type == GGML_TYPE_Q4_1 ? dequantize_q4_1 :
-        type == GGML_TYPE_Q5_0 ? dequantize_q5_0 :
-        type == GGML_TYPE_Q5_1 ? dequantize_q5_1 :
-        type == GGML_TYPE_Q8_0 ? dequantize_q8_0 :
-        type == GGML_TYPE_F16 ? convert_f16 :
-        nullptr;
-}
-
-template <ggml_type type>
-static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
-    constexpr int qk = ggml_cuda_type_traits<type>::qk; // quantized weights per x block
-    constexpr int qr = ggml_cuda_type_traits<type>::qr; // number of quantized weights per data value in x block
-    constexpr dequantize_kernel_t dequantize_kernel = get_dequantize_kernel(type);
-
-    const int64_t row = (int64_t)blockIdx.x*blockDim.y + threadIdx.y;
-
-    if (row >= nrows) {
-        return;
-    }
-
-    const int tid = threadIdx.x;
-
-    const int iter_stride = 2*GGML_CUDA_DMMV_X;
-    const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter
-    const int y_offset = qr == 1 ? 1 : qk/2;
-
-// partial sum for each thread
-#ifdef GGML_CUDA_F16
-    half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics
-#else
-    float tmp = 0.0f;
-#endif // GGML_CUDA_F16
-
-    for (int i = 0; i < ncols; i += iter_stride) {
-        const int col = i + vals_per_iter*tid;
-        const int64_t ib = ((int64_t)row*ncols + col)/qk; // x block index
-        const int iqs = (col%qk)/qr; // x quant index
-        const int iybs = col - col%qk; // y block start index
-
-// processing >2 values per i iter is faster for fast GPUs
-#pragma unroll
-        for (int j = 0; j < vals_per_iter; j += 2) {
-            // process 2 vals per j iter
-
-            // dequantize
-            // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
-            dfloat2 v;
-            dequantize_kernel(vx, ib, iqs + j/qr, v);
-
-            // matrix multiplication
-            // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
-#ifdef GGML_CUDA_F16
-            tmp += __hmul2(v, {
-                y[iybs + iqs + j/qr + 0],
-                y[iybs + iqs + j/qr + y_offset]
-            });
-#else
-            tmp += v.x * y[iybs + iqs + j/qr + 0];
-            tmp += v.y * y[iybs + iqs + j/qr + y_offset];
-#endif // GGML_CUDA_F16
-        }
-    }
-
-    // sum up partial sums and write back result
-    tmp = warp_reduce_sum(tmp);
-
-    if (tid == 0) {
-#ifdef GGML_CUDA_F16
-        dst[row] = tmp.x + tmp.y;
-#else
-        dst[row] = tmp;
-#endif // GGML_CUDA_F16
-    }
-}
-
-static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    dequantize_mul_mat_vec<GGML_TYPE_Q4_0>
-        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
-}
-
-static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    dequantize_mul_mat_vec<GGML_TYPE_Q4_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
-}
-
-static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    dequantize_mul_mat_vec<GGML_TYPE_Q5_0>
-        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
-}
-
-static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    dequantize_mul_mat_vec<GGML_TYPE_Q5_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
-}
-
-static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    dequantize_mul_mat_vec<GGML_TYPE_Q8_0>
-        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
-}
-
-static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
-    const int block_num_y = (nrows + ny - 1) / ny;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(32, ny, 1);
-    dequantize_mul_mat_vec_q2_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
-}
-
-static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const int ny = 2 / K_QUANTS_PER_ITERATION;
-    const int block_num_y = (nrows + ny - 1) / ny;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(32, ny, 1);
-    dequantize_mul_mat_vec_q3_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
-}
-
-static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const int ny = 2 / K_QUANTS_PER_ITERATION;
-    const int block_num_y = (nrows + ny - 1) / ny;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(32, ny, 1);
-    dequantize_mul_mat_vec_q4_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
-}
-
-static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const dim3 block_dims(32, 1, 1);
-    dequantize_mul_mat_vec_q5_k<<<nrows, block_dims, 0, stream>>>(vx, y, dst, ncols);
-}
-
-static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const int ny = 2 / K_QUANTS_PER_ITERATION;
-    const int block_num_y = (nrows + ny - 1) / ny;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(32, ny, 1);
-    dequantize_mul_mat_vec_q6_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
-}
-
-static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    dequantize_mul_mat_vec<GGML_TYPE_F16>
-        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
-}
-
-void ggml_cuda_op_dequantize_mul_mat_vec(
-    ggml_backend_cuda_context & ctx,
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
-    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
-    const int64_t src1_padded_row_size, cudaStream_t stream) {
-    GGML_UNUSED(ctx);
-    const int64_t ne00 = src0->ne[0];
-    const int64_t row_diff = row_high - row_low;
-
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
-    // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
-#ifdef GGML_CUDA_F16
-    ggml_cuda_pool_alloc<half> src1_dfloat_a(ctx.pool());
-    half * src1_dfloat = nullptr; // dfloat == half
-
-    bool src1_convert_f16 =
-        src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
-        src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
-        src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
-
-    if (src1_convert_f16) {
-        src1_dfloat = src1_dfloat_a.alloc(ne00);
-        const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
-        GGML_ASSERT(to_fp16_cuda != nullptr);
-        to_fp16_cuda(src1_ddf_i, src1_dfloat, ne00, stream);
-    }
-#else
-    const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i; // dfloat == float, no conversion
-#endif // GGML_CUDA_F16
-
-    switch (src0->type) {
-        case GGML_TYPE_Q4_0:
-            dequantize_mul_mat_vec_q4_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
-            break;
-        case GGML_TYPE_Q4_1:
-            dequantize_mul_mat_vec_q4_1_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
-            break;
-        case GGML_TYPE_Q5_0:
-            dequantize_mul_mat_vec_q5_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
-            break;
-        case GGML_TYPE_Q5_1:
-            dequantize_mul_mat_vec_q5_1_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
-            break;
-        case GGML_TYPE_Q8_0:
-            dequantize_mul_mat_vec_q8_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
-            break;
-        case GGML_TYPE_Q2_K:
-            dequantize_mul_mat_vec_q2_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
-            break;
-        case GGML_TYPE_Q3_K:
-            dequantize_mul_mat_vec_q3_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
-            break;
-        case GGML_TYPE_Q4_K:
-            dequantize_mul_mat_vec_q4_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
-            break;
-        case GGML_TYPE_Q5_K:
-            dequantize_mul_mat_vec_q5_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
-            break;
-        case GGML_TYPE_Q6_K:
-            dequantize_mul_mat_vec_q6_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
-            break;
-        case GGML_TYPE_F16:
-            convert_mul_mat_vec_f16_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
-            break;
-        default:
-            GGML_ABORT("fatal error");
-            break;
-    }
-
-    GGML_UNUSED(src1);
-    GGML_UNUSED(dst);
-    GGML_UNUSED(src1_ddq_i);
-    GGML_UNUSED(src1_ncols);
-    GGML_UNUSED(src1_padded_row_size);
-}
-
-bool ggml_cuda_dmmv_type_supported(ggml_type src0_type) {
-    return src0_type == GGML_TYPE_Q4_0 || src0_type == GGML_TYPE_Q4_1 ||
-        src0_type == GGML_TYPE_Q5_0 || src0_type == GGML_TYPE_Q5_1 ||
-        src0_type == GGML_TYPE_Q8_0 || src0_type == GGML_TYPE_Q2_K ||
-        src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q4_K ||
-        src0_type == GGML_TYPE_Q5_K || src0_type == GGML_TYPE_Q6_K ||
-        src0_type == GGML_TYPE_F16;
-}
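
The deleted dmmv.cu path above assigns one warp per output row: each lane accumulates a strided slice of the dot product, and warp_reduce_sum combines the partial sums before lane 0 writes the result. A minimal standalone sketch of that pattern, assuming WARP_SIZE == 32 and leaving out the on-the-fly dequantization (an illustration, not the vendored kernel):

__global__ void row_dot_f32(const float * x, const float * y, float * dst, const int ncols) {
    const int row = blockIdx.x;
    float tmp = 0.0f;
    // each lane accumulates a strided slice of the row
    for (int col = threadIdx.x; col < ncols; col += 32) {
        tmp += x[(size_t) row*ncols + col] * y[col];
    }
    // butterfly reduction across the warp, as warp_reduce_sum does
    for (int offset = 16; offset > 0; offset >>= 1) {
        tmp += __shfl_xor_sync(0xFFFFFFFF, tmp, offset, 32);
    }
    if (threadIdx.x == 0) {
        dst[row] = tmp;
    }
}

Launched as row_dot_f32<<<nrows, 32, 0, stream>>>(x, y, dst, ncols), this reproduces the tid == 0 write-back visible at the top of the hunk.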

+ 3 - 3
llama/ggml-cuda/fattn-common.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *
@@ -543,9 +543,9 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
 }
 
 template<int D, int parallel_blocks> // D == head size
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_combine_results(
         const float  * __restrict__ VKQ_parts,
         const float2 * __restrict__ VKQ_meta,
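
The GGML_USE_HIPBLAS -> GGML_USE_HIP rename recurs through the rest of this diff; the new name reflects that HIP is now its own backend define rather than a side effect of building against hipBLAS. The guarded construct itself is unchanged: __launch_bounds__ is emitted only for non-AMD targets, presumably because the occupancy hint hurts scheduling when compiling through HIP for AMD hardware. The pattern in isolation, under that assumption:

#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(256, 1) // cap threads per block; omitted when compiling for AMD via HIP
#endif
__global__ void guarded_kernel(float * dst) {
    dst[blockIdx.x*blockDim.x + threadIdx.x] = 0.0f;
}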

+ 4 - 4
llama/ggml-cuda/fattn-tile-f16.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *
@@ -31,9 +31,9 @@
 #define FATTN_KQ_STRIDE_TILE_F16 64
 
 template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_tile_ext_f16(
         const char * __restrict__ Q,
         const char * __restrict__ K,
@@ -285,7 +285,7 @@ static __global__ void flash_attn_tile_ext_f16(
         }
 
         half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
-        kqsum_j = warp_reduce_sum(kqsum_j);
+        kqsum_j = warp_reduce_sum((float)kqsum_j);
 
 #pragma unroll
         for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
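
The (float) casts introduced here and in the fattn-vec kernels below make the call resolve to the float overload of warp_reduce_sum instead of leaving a bare half to implicit conversion, so the sum is accumulated in full precision on every target. A hedged sketch of the resulting pattern (the real overloads live in common.cuh; the butterfly is the same one sketched after the dmmv hunk above):

#include <cuda_fp16.h>

static __device__ float warp_reduce_sum_f32(float x) {
    for (int offset = 16; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(0xFFFFFFFF, x, offset, 32);
    }
    return x;
}

static __device__ float combine_kqsum(half kqsum_j) {
    // cast first, reduce in float; mirrors the kqsum_j handling above
    return warp_reduce_sum_f32((float) kqsum_j);
}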

+ 1 - 1
llama/ggml-cuda/fattn-tile-f16.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

+ 3 - 3
llama/ggml-cuda/fattn-tile-f32.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *
@@ -31,9 +31,9 @@
 #define FATTN_KQ_STRIDE_TILE_F32 32
 
 template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_tile_ext_f32(
         const char * __restrict__ Q,
         const char * __restrict__ K,

+ 1 - 1
llama/ggml-cuda/fattn-tile-f32.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

+ 6 - 7
llama/ggml-cuda/fattn-vec-f16.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *
@@ -28,9 +28,9 @@
 #include "fattn-common.cuh"
 
 template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_vec_ext_f16(
         const char * __restrict__ Q,
         const char * __restrict__ K,
@@ -222,7 +222,7 @@ static __global__ void flash_attn_vec_ext_f16(
 #pragma unroll
             for (int j = 0; j < ncols; ++j) {
                 half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
-                sum = warp_reduce_sum(sum);
+                sum = warp_reduce_sum((float)sum);
 
                 if (use_logit_softcap) {
                     sum = logit_softcap*tanhf(sum);
@@ -246,7 +246,6 @@ static __global__ void flash_attn_vec_ext_f16(
         for (int j = 0; j < ncols; ++j) {
             half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j];
 
-            kqmax_new_j = warp_reduce_max(kqmax_new_j);
             if (threadIdx.x == 0) {
                 kqmax_shared[j][threadIdx.y] = kqmax_new_j;
             }
@@ -291,7 +290,7 @@ static __global__ void flash_attn_vec_ext_f16(
 
 #pragma unroll
     for (int j = 0; j < ncols; ++j) {
-        kqsum[j] = warp_reduce_sum(kqsum[j]);
+        kqsum[j] = warp_reduce_sum((float)kqsum[j]);
         if (threadIdx.x == 0) {
             kqsum_shared[j][threadIdx.y] = kqsum[j];
         }
@@ -306,7 +305,7 @@ static __global__ void flash_attn_vec_ext_f16(
         }
 
         kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
-        kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
+        kqsum[j_VKQ] = warp_reduce_sum((float)kqsum[j_VKQ]);
 
         half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ]));
         if (parallel_blocks == 1) {

+ 3 - 4
llama/ggml-cuda/fattn-vec-f32.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *
@@ -28,9 +28,9 @@
 #include "fattn-common.cuh"
 
 template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_vec_ext_f32(
         const char * __restrict__ Q,
         const char * __restrict__ K,
@@ -232,7 +232,6 @@ static __global__ void flash_attn_vec_ext_f32(
         for (int j = 0; j < ncols; ++j) {
             float kqmax_new_j = kqmax_new_arr[j];
 
-            kqmax_new_j = warp_reduce_max(kqmax_new_j);
             if (threadIdx.x == 0) {
                 kqmax_shared[j][threadIdx.y] = kqmax_new_j;
             }

+ 3 - 3
llama/ggml-cuda/fattn-wmma-f16.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *
@@ -33,9 +33,9 @@
 
 // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
 template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_logit_softcap>
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_ext_f16(
         const char * __restrict__ Q,
         const char * __restrict__ K,

+ 6 - 6
llama/ggml-cuda/fattn.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *
@@ -39,9 +39,9 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
     const ggml_tensor * KQV = dst;
     const ggml_tensor * Q   = dst->src[0];
 
-    const int32_t precision = KQV->op_params[3];
+    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
-    if (precision != GGML_PREC_DEFAULT) {
+    if (prec != GGML_PREC_DEFAULT) {
         if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
             constexpr int cols_per_block = 16;
             switch (Q->ne[0]) {
@@ -327,11 +327,11 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
 
     ggml_cuda_set_device(ctx.device);
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
-    const int32_t precision = KQV->op_params[3];
+    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
     // On AMD the tile kernels perform poorly, use the vec kernel instead:
     if (cc >= CC_OFFSET_AMD) {
-        if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
+        if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
         } else {
             ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
@@ -358,7 +358,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     }
 
     if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
-        if (precision == GGML_PREC_DEFAULT) {
+        if (prec == GGML_PREC_DEFAULT) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
             return;
         } else if(Q->ne[0] <= 128) {

+ 1 - 1
llama/ggml-cuda/fattn.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/getrows.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/getrows.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

File diff suppressed because it is too large
+ 144 - 333
llama/ggml-cuda/ggml-cuda.cu


+ 4 - 4
llama/ggml-cuda/im2col.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *
@@ -117,9 +117,9 @@ void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const int64_t OH = is_2D ? dst->ne[2] : 1;
     const int64_t OW =         dst->ne[1];
 
-    const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
-    const int64_t batch = src1->ne[3];
-    const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32
+    const size_t  delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
+    const int64_t batch        = src1->ne[is_2D ? 3 : 2];
+    const size_t  batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
 
     if(dst->type == GGML_TYPE_F16) {
         im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);
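
The corrected lines follow ggml's stride convention: ne[] holds element counts and nb[] holds byte strides, so dividing nb by 4 converts to float32 elements, and for the 1D (non-2D) case the batch dimension of src1 is dimension 2 rather than 3. A small self-contained check of that arithmetic, assuming a contiguous f32 tensor:

#include <cstdio>

int main() {
    const long ne[4] = {8, 8, 3, 2};   // W, H, C, N
    long nb[4] = {4, 0, 0, 0};         // ggml nb[] are byte strides
    for (int i = 1; i < 4; ++i) {
        nb[i] = nb[i-1]*ne[i-1];       // contiguous layout
    }
    const bool is_2D = false;          // 1D im2col: batch lives in dim 2
    const long batch        = ne[is_2D ? 3 : 2];
    const long batch_offset = nb[is_2D ? 3 : 2] / 4; // bytes -> f32 elements
    printf("batch = %ld, batch_offset = %ld elements\n", batch, batch_offset);
    return 0;
}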

+ 1 - 1
llama/ggml-cuda/im2col.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/mma.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

+ 3 - 5
llama/ggml-cuda/mmq.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *
@@ -34,8 +34,6 @@ void ggml_cuda_op_mul_mat_q(
 
     const int64_t ne00 = src0->ne[0];
 
-    const int64_t nb01 = src0->nb[1];
-
     const int64_t ne10 = src1->ne[0];
     const int64_t ne11 = src1->ne[1];
     GGML_ASSERT(ne10 % QK8_1 == 0);
@@ -43,7 +41,7 @@ void ggml_cuda_op_mul_mat_q(
     const int64_t ne0 = dst->ne[0];
 
     const int64_t row_diff = row_high - row_low;
-    const int64_t stride00 = nb01 / ggml_type_size(src0->type);
+    const int64_t stride00 = ne00 / ggml_blck_size(src0->type);
 
     int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;
@@ -176,5 +174,5 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         return cc < CC_VOLTA || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
     }
 
-    return cc < CC_RDNA3 || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
+    return (cc < CC_RDNA3 && cc != CC_CDNA && cc != CC_VEGA20) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
 }
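
The stride00 change above replaces a byte-stride-derived value (nb01 / ggml_type_size) with one computed from the element count: the MMQ path indexes src0 in whole quant blocks, so for a contiguous row the stride is elements per row divided by the quant block size. A worked example under that reading:

// Q4_0 packs 32 elements per block (ggml_blck_size(GGML_TYPE_Q4_0) == 32)
const long ne00     = 4096;        // elements per row
const long blck     = 32;          // block size of the quant type
const long stride00 = ne00 / blck; // 128 blocks per row

This presumes contiguous rows, consistent with the "ggml_cuda_op provides single, contiguous matrices" comment in mmv.cu further down.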

+ 14 - 14
llama/ggml-cuda/mmq.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *
@@ -126,9 +126,9 @@ static constexpr __device__ int get_mmq_x_max_device() {
     return 128;
 #else // INT8_MMA_AVAILABLE
 
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
     return 128;
-#else // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#else // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 
 #if __CUDA_ARCH__ >= CC_VOLTA
 #ifdef GGML_CUDA_FORCE_MMQ
@@ -141,7 +141,7 @@ static constexpr __device__ int get_mmq_x_max_device() {
     return 64;
 #endif // __CUDA_ARCH__ >= CC_VOLTA
 
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 #endif // INT8_MMA_AVAILABLE
 }
 
@@ -150,7 +150,7 @@ static constexpr int get_mmq_y_host(const int cc) {
 }
 
 static constexpr __device__ int get_mmq_y_device() {
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 #if defined(RDNA1)
     return 64;
 #else
@@ -162,7 +162,7 @@ static constexpr __device__ int get_mmq_y_device() {
 #else
     return 64;
 #endif // __CUDA_ARCH__ >= CC_VOLTA
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 }
 
 #define MMQ_DP4A_TXS_Q4_0    tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_0   + mmq_y/QI4_0,     0}
@@ -2595,17 +2595,17 @@ static __device__ void mul_mat_q_process_tile(
 // The mul_mat_q kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598
 
 template <ggml_type type, int mmq_x, int nwarps, bool need_check>
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-#if defined(RDNA3) || defined(RDNA2)
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
     __launch_bounds__(WARP_SIZE*nwarps, 2)
-#endif // defined(RDNA3) || defined(RDNA2)
+#endif // defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
 #else
 #if __CUDA_ARCH__ >= CC_VOLTA
     __launch_bounds__(WARP_SIZE*nwarps, 1)
 #else
     __launch_bounds__(WARP_SIZE*nwarps, 2)
 #endif // __CUDA_ARCH__ >= CC_VOLTA
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 static __global__ void mul_mat_q(
     const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
     const int ne00, const int ne01, const int stride01, const int ne10, const int ne11, const int stride11, const int ne0) {
@@ -2620,7 +2620,7 @@ static __global__ void mul_mat_q(
     constexpr int mmq_y = get_mmq_y_device();
 
     // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
-#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA
+#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA
     {
         constexpr bool fixup = false;
         mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
@@ -2628,7 +2628,7 @@ static __global__ void mul_mat_q(
                 blockIdx.x, blockIdx.y, 0, ne00/qk);
         return;
     }
-#endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA
+#endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA
 
     const     int64_t blocks_per_ne00 = ne00 / qk;
     constexpr int     blocks_per_iter = MMQ_ITER_K / qk;
@@ -2791,14 +2791,14 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
 
     const int shmem = mmq_get_shmem<type>(mmq_x, mmq_y, cc);
 
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
     static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
     if (!shmem_limit_raised[id]) {
         CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
         CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>,  cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
         shmem_limit_raised[id] = true;
     }
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 
     const int nty = (args.ne01 + mmq_y - 1) / mmq_y;
     const int ntx = (args.ne11 + mmq_x - 1) / mmq_x;
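
For the "stream-k" comment in this hunk (arXiv:2301.03598): rather than launching one block per output tile, a fixed number of workers split the total K-iteration space evenly, and tiles whose iterations straddle two workers are finished in a fixup pass over tmp_fixup. A hedged host-side sketch of the partitioning idea, not the vendored kernel:

#include <cstdio>

int main() {
    const int tiles = 10, iters_per_tile = 8, workers = 4;
    const int total = tiles*iters_per_tile;
    for (int w = 0; w < workers; ++w) {
        const int begin = total* w     /workers; // first K-iteration of this worker
        const int end   = total*(w + 1)/workers; // one past its last
        printf("worker %d: iters [%2d, %2d) -> tiles %d..%d\n",
               w, begin, end, begin/iters_per_tile, (end - 1)/iters_per_tile);
    }
    return 0;
}

Here each worker gets 20 iterations; tiles 2 and 7 end up split across two workers and need the fixup combination.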

+ 249 - 0
llama/ggml-cuda/mmv.cu

@@ -0,0 +1,249 @@
+/**
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+#include "mmv.cuh"
+
+template <typename type_acc, int block_size>
+static __global__ void mul_mat_vec(
+        const half * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
+        const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) {
+    const int64_t row     = blockIdx.x;
+    const int64_t channel = blockIdx.z;
+    const int     tid     = threadIdx.x;
+
+    x   += (channel/channel_ratio)*stride_channel_x + row*stride_row;
+    y   +=  channel               *stride_channel_y;
+    dst +=  channel               *stride_channel_dst;
+
+    const half2  * x2 = (const half2  *) x;
+    const float2 * y2 = (const float2 *) y;
+
+    extern __shared__ char data_mmv[];
+    float * buf_iw = (float *) data_mmv;
+
+    if (block_size > WARP_SIZE) {
+        if (tid < WARP_SIZE) {
+            buf_iw[tid] = 0.0f;
+        }
+        __syncthreads();
+    }
+
+    float sumf;
+
+    if (std::is_same<type_acc, float>::value) {
+        sumf = 0.0f;
+
+        for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
+            const float2 tmpx = __half22float2(x2[col2]);
+            const float2 tmpy = y2[col2];
+            sumf += tmpx.x * tmpy.x;
+            sumf += tmpx.y * tmpy.y;
+        }
+    } else {
+#ifdef FP16_AVAILABLE
+        half2 sumh2 = make_half2(0.0f, 0.0f);
+
+        for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
+            const float2 tmp = y2[col2];
+            sumh2 += x2[col2] * make_half2(tmp.x, tmp.y);
+        }
+
+        sumf = __low2float(sumh2) + __high2float(sumh2);
+#else
+        NO_DEVICE_CODE;
+#endif // FP16_AVAILABLE
+    }
+
+    sumf = warp_reduce_sum(sumf);
+
+    if (block_size > WARP_SIZE) {
+        buf_iw[tid/WARP_SIZE] = sumf;
+        __syncthreads();
+        if (tid >= WARP_SIZE) {
+        // >= so that no thread outside warp 0 (including tid == WARP_SIZE) reads past buf_iw
+            return;
+        }
+        sumf = buf_iw[tid];
+        sumf = warp_reduce_sum(sumf);
+    }
+
+    if (tid != 0) {
+        return;
+    }
+
+    dst[row] = sumf;
+}
+
+template <typename type_acc>
+static void launch_mul_mat_vec_cuda(
+        const half * x, const float * y, float * dst,
+        const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
+        cudaStream_t stream) {
+    GGML_ASSERT(ncols      % 2 == 0);
+    GGML_ASSERT(stride_row % 2 == 0);
+    GGML_ASSERT(nchannels_y % nchannels_x == 0);
+    const int64_t channel_ratio = nchannels_y / nchannels_x;
+
+    int64_t block_size_best = WARP_SIZE;
+    int64_t niter_best      = (ncols + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
+    for (int64_t block_size = 2*WARP_SIZE; block_size <= 256; block_size += WARP_SIZE) {
+        const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size);
+        if (niter < niter_best) {
+            niter_best      = niter;
+            block_size_best = block_size;
+        }
+    }
+
+    const int smem = WARP_SIZE*sizeof(float);
+    const dim3 block_nums(nrows, 1, nchannels_y);
+    const dim3 block_dims(block_size_best, 1, 1);
+    switch (block_size_best) {
+        case   32: {
+            mul_mat_vec<type_acc,  32><<<block_nums, block_dims, smem, stream>>>
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+        } break;
+        case   64: {
+            mul_mat_vec<type_acc,  64><<<block_nums, block_dims, smem, stream>>>
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+        } break;
+        case   96: {
+            mul_mat_vec<type_acc,  96><<<block_nums, block_dims, smem, stream>>>
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+        } break;
+        case  128: {
+            mul_mat_vec<type_acc, 128><<<block_nums, block_dims, smem, stream>>>
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+        } break;
+        case  160: {
+            mul_mat_vec<type_acc, 160><<<block_nums, block_dims, smem, stream>>>
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+        } break;
+        case  192: {
+            mul_mat_vec<type_acc, 192><<<block_nums, block_dims, smem, stream>>>
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+        } break;
+        case  224: {
+            mul_mat_vec<type_acc, 224><<<block_nums, block_dims, smem, stream>>>
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+        } break;
+        case  256: {
+            mul_mat_vec<type_acc, 256><<<block_nums, block_dims, smem, stream>>>
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+        } break;
+        default: {
+            GGML_ABORT("fatal error");
+        } break;
+    }
+}
+
+static void mul_mat_vec_cuda(
+        const half * x, const float * y, float * dst,
+        const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
+        enum ggml_prec prec, cudaStream_t stream) {
+    switch (prec) {
+        case GGML_PREC_DEFAULT: {
+            launch_mul_mat_vec_cuda<half>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
+                stride_channel_x, stride_channel_y, stride_channel_dst, stream);
+        } break;
+        case GGML_PREC_F32: {
+            launch_mul_mat_vec_cuda<float>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
+                stride_channel_x, stride_channel_y, stride_channel_dst, stream);
+        } break;
+    }
+}
+
+void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+
+    GGML_ASSERT(src1->ne[1] == 1);
+
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
+
+    const half  * src0_d = (const half  *) src0->data;
+    const float * src1_d = (const float *) src1->data;
+    float       *  dst_d = (float       *)  dst->data;
+
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne12 = src1->ne[2];
+    GGML_ASSERT(dst->ne[2] == ne12);
+
+    GGML_ASSERT(src0->ne[3] == 1);
+    GGML_ASSERT(src1->ne[3] == 1);
+    GGML_ASSERT( dst->ne[3] == 1);
+
+    const int64_t stride_row         = src0->nb[1] / ggml_type_size(src0->type);
+    const int64_t channel_stride_x   = src0->nb[2] / ggml_type_size(src0->type);
+    const int64_t channel_stride_y   = src1->nb[2] / ggml_type_size(src1->type);
+    const int64_t channel_stride_dst =  dst->nb[2] / ggml_type_size( dst->type);
+
+    mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12, channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
+}
+
+void ggml_cuda_op_mul_mat_vec(
+    ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, cudaStream_t stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t row_diff = row_high - row_low;
+
+    GGML_ASSERT(src1_ncols == 1);
+
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
+
+    // ggml_cuda_op provides single, contiguous matrices
+    const int64_t stride_row         = ne00;
+    const int64_t nchannels_x        = 1;
+    const int64_t nchannels_y        = 1;
+    const int64_t channel_stride_x   = 0;
+    const int64_t channel_stride_y   = 0;
+    const int64_t channel_stride_dst = 0;
+
+    mul_mat_vec_cuda((const half *) src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
+        nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
+
+    GGML_UNUSED(ctx);
+    GGML_UNUSED(src1);
+    GGML_UNUSED(dst);
+    GGML_UNUSED(src1_ddq_i);
+    GGML_UNUSED(src1_ncols);
+    GGML_UNUSED(src1_padded_row_size);
+}
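
The launch_mul_mat_vec_cuda heuristic above picks the multiple of WARP_SIZE (up to 256) that minimizes how many strided iterations each thread makes over the ncols/2 half2 elements. The same arithmetic as a standalone check, assuming WARP_SIZE == 32:

#include <cstdint>
#include <cstdio>

int main() {
    const int64_t WARP_SIZE = 32, ncols = 4096;
    int64_t best_bs = WARP_SIZE;
    int64_t best_it = (ncols + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
    for (int64_t bs = 2*WARP_SIZE; bs <= 256; bs += WARP_SIZE) {
        const int64_t it = (ncols + 2*bs - 1) / (2*bs); // ceil(ncols / (2*bs))
        if (it < best_it) { best_it = it; best_bs = bs; }
    }
    // for ncols = 4096 this settles on 256 threads doing 8 iterations each
    printf("block size %lld -> %lld iterations per thread\n",
           (long long) best_bs, (long long) best_it);
    return 0;
}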

+ 5 - 13
llama/ggml-cuda/dmmv.cuh → llama/ggml-cuda/mmv.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *
@@ -26,21 +26,13 @@
 
 #include "common.cuh"
 
-// dmmv = dequantize_mul_mat_vec
+// maximum number of src0 rows with which to use mul_mat_vec over cuBLAS if FP16 tensor cores are available
+#define MMV_MAX_ROWS 512
 
-// TODO: remove this?
-#ifndef GGML_CUDA_DMMV_X
-#define GGML_CUDA_DMMV_X 32
-#endif
+void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
 
-#ifndef GGML_CUDA_MMV_Y
-#define GGML_CUDA_MMV_Y 1
-#endif
-
-void ggml_cuda_op_dequantize_mul_mat_vec(
+void ggml_cuda_op_mul_mat_vec(
     ggml_backend_cuda_context & ctx,
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
     const int64_t src1_padded_row_size, cudaStream_t stream);
-
-bool ggml_cuda_dmmv_type_supported(ggml_type src0_type);

+ 6 - 6
llama/ggml-cuda/mmvq.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *
@@ -74,10 +74,10 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
 }
 
 template <ggml_type type, int ncols_y>
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 // tell the compiler to use as many registers as it wants, see nwarps definition below
 __launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void mul_mat_vec_q(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
@@ -88,13 +88,13 @@ static __global__ void mul_mat_vec_q(
 
     constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
 
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
     constexpr int nwarps              = 1;
     constexpr int rows_per_cuda_block = 1;
 #else
     constexpr int nwarps              = ncols_y <= 4 ? 4 : 2;
     constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
 
     const     int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
     const     int row0 = rows_per_cuda_block*blockIdx.x;
@@ -168,7 +168,7 @@ static void mul_mat_vec_q_cuda(
     int64_t nwarps = 1;
     int64_t rows_per_cuda_block = 1;
 
-    if (ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2
+    if (ggml_cuda_info().devices[id].cc < CC_CDNA || ggml_cuda_info().devices[id].cc == CC_RDNA1) { // NVIDIA and AMD older than RDNA2 but not CDNA
         switch(ncols_y) {
             case 1:
                 nwarps = 4;

+ 1 - 1
llama/ggml-cuda/mmvq.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/norm.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/norm.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

+ 34 - 36
llama/ggml-cuda/opt-step-adamw.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *
@@ -24,14 +24,14 @@
  * SOFTWARE.
  */
 
+#include "ggml-impl.h"
 #include "opt-step-adamw.cuh"
 
 #include <cstdint>
 
 static __global__ void opt_step_adamw_f32(
-    float * __restrict__ x, const float * __restrict__ g, float * __restrict__ g_m, float * __restrict__ g_v, const int64_t k,
-    const float alpha, const float beta1, const float beta2, const float eps, const float wd,
-    const float beta1h, const float beta2h) {
+    float * __restrict__ x, const float * __restrict__ g, float * __restrict__ g_m, float * __restrict__ g_v,
+    const float * __restrict__ pars, const int64_t k) {
 
     const int64_t i = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
 
@@ -39,6 +39,14 @@ static __global__ void opt_step_adamw_f32(
         return;
     }
 
+    const float alpha  = pars[0];
+    const float beta1  = pars[1];
+    const float beta2  = pars[2];
+    const float eps    = pars[3];
+    const float wd     = pars[4];
+    const float beta1h = pars[5];
+    const float beta2h = pars[6];
+
     const float gi = g[i];
     const float gmi = g_m[i]*beta1 +    gi*(1.0f - beta1);
     const float gvi = g_v[i]*beta2 + gi*gi*(1.0f - beta2);
@@ -49,58 +57,48 @@ static __global__ void opt_step_adamw_f32(
     const float mh =       gmi*beta1h;
     const float vh = sqrtf(gvi*beta2h) + eps;
 
-    x[i] = x[i]*(1.0f - alpha*wd) - mh/vh;
+    x[i] = x[i]*(1.0f - alpha*wd) - alpha*mh/vh;
 }
 
 static void opt_step_adamw_f32_cuda(
-    float * x, const float * g, float * g_m, float * g_v, const int64_t k,
-    const float alpha, const float beta1, const float beta2, const float eps, const float wd,
-    const float beta1h, const float beta2h, cudaStream_t stream) {
+    float * x, const float * g, float * g_m, float * g_v, const float * pars, const int64_t k, cudaStream_t stream) {
 
     const dim3 block_dims(CUDA_OPT_STEP_ADAMW_BLOCK_SIZE, 1, 1);
     const dim3 block_nums((k + CUDA_OPT_STEP_ADAMW_BLOCK_SIZE - 1) / CUDA_OPT_STEP_ADAMW_BLOCK_SIZE, 1, 1);
-    opt_step_adamw_f32<<<block_nums, block_dims, 0, stream>>>(x, g, g_m, g_v, k, alpha, beta1, beta2, eps, wd, beta1h, beta2h);
+    opt_step_adamw_f32<<<block_nums, block_dims, 0, stream>>>(x, g, g_m, g_v, pars, k);
 }
 
 void ggml_cuda_opt_step_adamw(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0        = dst->src[0];
-    const ggml_tensor * src0_grad   = dst->src[1];
-    const ggml_tensor * src0_grad_m = dst->src[2];
-    const ggml_tensor * src0_grad_v = dst->src[3];
-
-    GGML_ASSERT(src0->type        == GGML_TYPE_F32);
-    GGML_ASSERT(src0_grad->type   == GGML_TYPE_F32);
-    GGML_ASSERT(src0_grad_m->type == GGML_TYPE_F32);
-    GGML_ASSERT(src0_grad_v->type == GGML_TYPE_F32);
+    const ggml_tensor * src0         = dst->src[0];
+    const ggml_tensor * src0_grad    = dst->src[1];
+    const ggml_tensor * src0_grad_m  = dst->src[2];
+    const ggml_tensor * src0_grad_v  = dst->src[3];
+    const ggml_tensor * adamw_params = dst->src[4];
+
+    GGML_ASSERT(src0->type         == GGML_TYPE_F32);
+    GGML_ASSERT(src0_grad->type    == GGML_TYPE_F32);
+    GGML_ASSERT(src0_grad_m->type  == GGML_TYPE_F32);
+    GGML_ASSERT(src0_grad_v->type  == GGML_TYPE_F32);
+    GGML_ASSERT(adamw_params->type == GGML_TYPE_F32);
     GGML_ASSERT(ggml_is_contiguous(src0));
     GGML_ASSERT(ggml_is_contiguous(src0_grad));
     GGML_ASSERT(ggml_is_contiguous(src0_grad_m));
     GGML_ASSERT(ggml_is_contiguous(src0_grad_v));
+    GGML_ASSERT(ggml_is_contiguous(adamw_params));
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m));
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v));
+    GGML_ASSERT(ggml_nelements(adamw_params) == 7);
 
-    float       * src0_d        = (float       *) src0->data;
-    const float * src0_grad_d   = (const float *) src0_grad->data;
-    float       * src0_grad_m_d = (float       *) src0_grad_m->data;
-    float       * src0_grad_v_d = (float       *) src0_grad_v->data;
+    float       * src0_d         = (float       *) src0->data;
+    const float * src0_grad_d    = (const float *) src0_grad->data;
+    float       * src0_grad_m_d  = (float       *) src0_grad_m->data;
+    float       * src0_grad_v_d  = (float       *) src0_grad_v->data;
+    const float * adamw_params_d = (const float *) adamw_params->data;
 
     cudaStream_t stream = ctx.stream();
 
     const int64_t ne = ggml_nelements(src0);
 
-    int64_t iter;  memcpy(&iter,  &dst->op_params[0], sizeof(int64_t));
-    float   alpha; memcpy(&alpha, &dst->op_params[2], sizeof(float));
-    float   beta1; memcpy(&beta1, &dst->op_params[3], sizeof(float));
-    float   beta2; memcpy(&beta2, &dst->op_params[4], sizeof(float));
-    float   eps;   memcpy(&eps,   &dst->op_params[5], sizeof(float));
-    float   wd;    memcpy(&wd,    &dst->op_params[6], sizeof(float));
-
-    const float beta1h  = alpha/(1.0f - powf(beta1, iter));
-    const float beta2h  =  1.0f/(1.0f - powf(beta2, iter));
-
-    opt_step_adamw_f32_cuda(src0_d, src0_grad_d, src0_grad_m_d, src0_grad_v_d, ne, alpha, beta1, beta2, eps, wd, beta1h, beta2h, stream);
-
-    iter++;
-    memcpy(&dst->op_params[0], &iter, sizeof(int64_t));
+    opt_step_adamw_f32_cuda(src0_d, src0_grad_d, src0_grad_m_d, src0_grad_v_d, adamw_params_d, ne, stream);
 }
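
Read against the kernel above, pars packs (alpha, beta1, beta2, eps, wd, beta1h, beta2h), and with the bias-correction factors taken as beta1h = 1/(1 - beta1^t) and beta2h = 1/(1 - beta2^t) the update is the standard AdamW step; the previous version folded alpha into beta1h, which the explicit alpha*mh/vh now makes unnecessary. In LaTeX, with wd the weight decay:

\begin{aligned}
m_t &= \beta_1\, m_{t-1} + (1-\beta_1)\, g_t, \qquad
v_t  = \beta_2\, v_{t-1} + (1-\beta_2)\, g_t^2, \\
\hat{m}_t &= \beta_{1h}\, m_t = \frac{m_t}{1-\beta_1^{\,t}}, \qquad
\hat{v}_t = \beta_{2h}\, v_t = \frac{v_t}{1-\beta_2^{\,t}}, \\
\theta_t &= \theta_{t-1}\,(1 - \alpha\,\mathrm{wd}) - \alpha\,\frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}.
\end{aligned}

The iteration counter and the powf bias corrections now live with whichever code fills the adamw_params tensor, which is why the op no longer mutates dst->op_params at the end.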

+ 1 - 1
llama/ggml-cuda/opt-step-adamw.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/out-prod.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/out-prod.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/pad.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/pad.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/pool2d.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/pool2d.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

+ 5 - 5
llama/ggml-cuda/quantize.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *
@@ -95,8 +95,8 @@ static __global__ void quantize_mmq_q8_1(
 
     // Exchange max. abs. value between vals_per_scale/4 threads.
 #pragma unroll
-    for (int mask = vals_per_scale/8; mask > 0; mask >>= 1) {
-        amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
+    for (int offset = vals_per_scale/8; offset > 0; offset >>= 1) {
+        amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, offset, WARP_SIZE));
     }
 
     float sum;
@@ -105,8 +105,8 @@ static __global__ void quantize_mmq_q8_1(
 
         // Exchange and add the partial sums across vals_per_sum/4 threads.
 #pragma unroll
-        for (int mask = vals_per_sum/8; mask > 0; mask >>= 1) {
-            sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, WARP_SIZE);
+        for (int offset = vals_per_sum/8; offset > 0; offset >>= 1) {
+            sum += __shfl_xor_sync(0xFFFFFFFF, sum, offset, WARP_SIZE);
         }
     }
 

+ 1 - 1
llama/ggml-cuda/quantize.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/rope.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/rope.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/scale.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/scale.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/softmax.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/softmax.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
+ * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  *
  * MIT License
  *

Some files were not shown because too many files changed in this diff