@@ -1,5 +1,5 @@
/**
- * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
+ * llama.cpp - commit ba1cb19cdd0d92e012e0f6e009e0620f854b6afd - do not edit this file
*
* MIT License
*
@@ -24,23 +24,60 @@
* SOFTWARE.
*/

-#define GGML_COMMON_IMPL_C
+#define GGML_COMMON_IMPL_CPP
+#define GGML_COMMON_DECL_CPP
#include "ggml-common.h"
+#include "ggml-backend-impl.h"

#include "ggml-quants.h"
#include "ggml-impl.h"
#include "ggml-cpu.h"
#include "ggml-cpu-impl.h"
+#include "ggml-cpu-traits.h"

-#include <math.h>
-#include <string.h>
-#include <assert.h>
-#include <float.h>
-#include <stdlib.h> // for qsort
-#include <stdio.h> // for GGML_ASSERT
+#include <cmath>
+#include <cstring>
+#include <cassert>
+#include <cfloat>
+#include <cstdlib> // for qsort
+#include <cstdio> // for GGML_ASSERT

#include "ggml-cpu-aarch64.h"

+// TODO: move to include file?
+template <int K> constexpr int QK_0() {
+    if constexpr (K == 4) {
+        return QK4_0;
+    }
+    if constexpr (K == 8) {
+        return QK8_0;
+    }
+    return -1;
+}
+
+template <int K, int N> struct block {
+    ggml_half d[N];                         // deltas for N qK_0 blocks
+    int8_t    qs[(QK_0<K>() * N * K) / 8];  // quants for N qK_0 blocks
+};
+
+// control size
+static_assert(sizeof(block<4, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 2, "wrong block<4,4> size/padding");
+static_assert(sizeof(block<4, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<4,8> size/padding");
+static_assert(sizeof(block<8, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<8,4> size/padding");
+static_assert(sizeof(block<8, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong block<8,8> size/padding");
+
+using block_q4_0x4 = block<4, 4>;
+using block_q4_0x8 = block<4, 8>;
+using block_q8_0x4 = block<8, 4>;
+using block_q8_0x8 = block<8, 8>;
+
+struct block_iq4_nlx4 {
+    ggml_half d[4];            // deltas for 4 iq4_nl blocks
+    uint8_t   qs[QK4_NL * 2];  // nibbles / quants for 4 iq4_nl blocks
+};
+
+static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(ggml_half) + QK4_NL * 2, "wrong iq4_nlx4 block size/padding");
+
#if defined(__GNUC__)
#pragma GCC diagnostic ignored "-Woverlength-strings"
#elif defined(_MSC_VER)
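The templated block<K, N> type introduced above replaces the hand-written block_q4_0x4 / block_q4_0x8 / block_q8_0x4 / block_q8_0x8 structs: it stores one delta per interleaved sub-block followed by the bit-packed quants of N sub-blocks of QK_0<K>() K-bit weights, and the static_asserts pin the layout. A minimal standalone sketch of the same size reasoning, assuming QK4_0 == QK8_0 == 32 and a 16-bit ggml_half (the names here are illustrative, not part of the patch):

```cpp
#include <cstdint>

using ggml_half = uint16_t;              // assumption: ggml_half is stored as a 16-bit value
constexpr int QK4_0 = 32;                // assumption: 32 weights per q4_0 block
constexpr int QK8_0 = 32;                // assumption: 32 weights per q8_0 block

template <int K> constexpr int QK_0() {  // same helper as in the patch
    if constexpr (K == 4) { return QK4_0; }
    if constexpr (K == 8) { return QK8_0; }
    return -1;
}

template <int K, int N> struct block {
    ggml_half d[N];                          // one delta (scale) per interleaved sub-block
    int8_t    qs[(QK_0<K>() * N * K) / 8];   // N sub-blocks of QK_0 K-bit quants, bit-packed
};

// block<4, 4>: 4 half scales + 4 * 32 four-bit quants = 8 + 64 bytes, matching the static_asserts above
static_assert(sizeof(block<4, 4>) == 4 * sizeof(ggml_half) + (32 * 4 * 4) / 8, "unexpected layout");
```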
@@ -211,12 +248,12 @@ static inline __m256i mul_sum_i8_pairs_int32x8(const __m256i x, const __m256i y)
|
|
|
|
|
|
static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
|
|
|
|
|
|
-static void quantize_q8_0_4x4(const float * restrict x, void * restrict vy, int64_t k) {
|
|
|
+static void quantize_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
|
|
|
assert(QK8_0 == 32);
|
|
|
assert(k % QK8_0 == 0);
|
|
|
const int nb = k / QK8_0;
|
|
|
|
|
|
- block_q8_0x4 * restrict y = (block_q8_0x4 *) vy;
|
|
|
+ block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy;
|
|
|
|
|
|
#if defined(__ARM_NEON)
|
|
|
float32x4_t srcv[4][8];
|
|
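Throughout this file the C99 `restrict` qualifier is replaced by the GGML_RESTRICT macro because the sources are now compiled as C++, where `restrict` is not a keyword. A hedged sketch of the kind of portability macro this relies on (the real definition comes from the ggml headers and may differ):

```cpp
// Illustrative only - the actual macro is provided by the ggml headers.
#ifndef GGML_RESTRICT
#  if defined(_MSC_VER)
#    define GGML_RESTRICT __restrict
#  elif defined(__cplusplus)
#    define GGML_RESTRICT __restrict__   // GCC/Clang extension; plain 'restrict' is C-only
#  else
#    define GGML_RESTRICT restrict
#  endif
#endif

// Example use, mirroring the kernel signatures in this file: x and y are promised not to alias.
static void scale_row(const float * GGML_RESTRICT x, float * GGML_RESTRICT y, int n, float s) {
    for (int i = 0; i < n; i++) {
        y[i] = s * x[i];
    }
}
```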
@@ -305,12 +342,12 @@ static void quantize_q8_0_4x4(const float * restrict x, void * restrict vy, int6
|
|
|
#endif
|
|
|
}
|
|
|
|
|
|
-static void quantize_q8_0_4x8(const float * restrict x, void * restrict vy, int64_t k) {
|
|
|
+static void quantize_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
|
|
|
assert(QK8_0 == 32);
|
|
|
assert(k % QK8_0 == 0);
|
|
|
const int nb = k / QK8_0;
|
|
|
|
|
|
- block_q8_0x4 * restrict y = (block_q8_0x4 *) vy;
|
|
|
+ block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy;
|
|
|
|
|
|
#if defined(__ARM_NEON)
|
|
|
float32x4_t srcv[4][8];
|
|
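For reference, the arithmetic these quantize kernels implement per block is the standard q8_0 scheme: each group of QK8_0 = 32 floats is scaled by d = amax / 127 and rounded to signed 8-bit quants; the 4x4 and 4x8 variants only differ in how the quants of four rows are interleaved inside block_q8_0x4. A scalar sketch of a single block (an illustrative reference, not the interleaved layout):

```cpp
#include <cmath>
#include <cstdint>

// Quantize one block of 32 floats to q8_0: returns the delta d and writes 32 quants with x[i] ~= d * qs[i].
static float quantize_block_q8_0_ref(const float * x, int8_t * qs) {
    float amax = 0.0f;                               // absolute max of the block
    for (int i = 0; i < 32; i++) {
        amax = std::fmax(amax, std::fabs(x[i]));
    }
    const float d  = amax / 127.0f;                  // stored as a ggml_half in block_q8_0
    const float id = d != 0.0f ? 1.0f / d : 0.0f;
    for (int i = 0; i < 32; i++) {
        qs[i] = (int8_t) std::lroundf(x[i] * id);
    }
    return d;
}
```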
@@ -520,7 +557,7 @@ static void quantize_q8_0_4x8(const float * restrict x, void * restrict vy, int6
|
|
|
#endif
|
|
|
}
|
|
|
|
|
|
-void quantize_mat_q8_0(const float * restrict x, void * restrict vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) {
|
|
|
+static void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) {
|
|
|
assert(nrow == 4);
|
|
|
UNUSED(nrow);
|
|
|
if (blck_size_interleave == 4) {
|
|
@@ -532,7 +569,7 @@ void quantize_mat_q8_0(const float * restrict x, void * restrict vy, int64_t nro
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
|
|
|
+static void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
|
|
const int qk = QK8_0;
|
|
|
const int nb = n / qk;
|
|
|
const int ncols_interleaved = 4;
|
|
@@ -617,7 +654,7 @@ void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-void ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
|
|
|
+static void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
|
|
const int qk = QK8_0;
|
|
|
const int nb = n / qk;
|
|
|
const int ncols_interleaved = 4;
|
|
@@ -727,7 +764,7 @@ void ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void *
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
|
|
|
+static void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
|
|
const int qk = QK8_0;
|
|
|
const int nb = n / qk;
|
|
|
const int ncols_interleaved = 8;
|
|
@@ -1000,7 +1037,7 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
|
|
|
+static void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
|
|
const int qk = QK8_0;
|
|
|
const int nb = n / qk;
|
|
|
const int ncols_interleaved = 4;
|
|
@@ -1096,7 +1133,7 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const void
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
|
|
|
+static void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
|
|
const int qk = QK8_0;
|
|
|
const int nb = n / qk;
|
|
|
const int ncols_interleaved = 4;
|
|
@@ -1612,7 +1649,7 @@ void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-void ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
|
|
|
+static void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
|
|
const int qk = QK8_0;
|
|
|
const int nb = n / qk;
|
|
|
const int ncols_interleaved = 4;
|
|
@@ -2066,7 +2103,7 @@ void ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void *
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
|
|
|
+static void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
|
|
const int qk = QK8_0;
|
|
|
const int nb = n / qk;
|
|
|
const int ncols_interleaved = 8;
|
|
@@ -2586,31 +2623,31 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
|
|
|
const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31)
|
|
|
|
|
|
// Shuffle pattern one - right side input
|
|
|
- const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
|
|
|
- const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
|
|
|
+ const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
|
|
|
+ const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
|
|
|
|
|
|
- const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
|
|
|
- const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
|
|
|
+ const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
|
|
|
+ const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
|
|
|
|
|
|
- const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
|
|
|
- const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
|
|
|
+ const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
|
|
|
+ const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
|
|
|
|
|
|
- const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
|
|
|
- const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
|
|
|
+ const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
|
|
|
+ const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
|
|
|
|
|
|
// Shuffle pattern two - right side input
|
|
|
|
|
|
- const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
|
|
|
- const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
|
|
|
+ const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
|
|
|
+ const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
|
|
|
|
|
|
- const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
|
|
|
- const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
|
|
|
+ const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
|
|
|
+ const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
|
|
|
|
|
|
- const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
|
|
|
- const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
|
|
|
+ const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
|
|
|
+ const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
|
|
|
|
|
|
- const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
|
|
|
- const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
|
|
|
+ const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
|
|
|
+ const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
|
|
|
|
|
|
// Scale values - Load the weight scale values of two block_q4_0x8
|
|
|
const __m512 col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
|
|
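The only change in these AVX-512 hunks is the explicit (_MM_PERM_ENUM) cast: with GCC's C++ headers, _mm512_shuffle_epi32() takes an _MM_PERM_ENUM argument, and an integer literal no longer converts implicitly in C++, so the immediates must be cast even though their values are unchanged. The immediates themselves are ordinary 2-bit-per-element selectors, for example:

```cpp
// 136 == 0b10001000 == _MM_SHUFFLE(2, 0, 2, 0): each 128-bit lane becomes {e0, e2, e0, e2}
// 221 == 0b11011101 == _MM_SHUFFLE(3, 1, 3, 1): each 128-bit lane becomes {e1, e3, e1, e3}
//  78 == 0b01001110 == _MM_SHUFFLE(1, 0, 3, 2): swaps the two 64-bit halves of each 128-bit lane
static_assert(136 == ((2 << 6) | (0 << 4) | (2 << 2) | 0), "shuffle immediate check");
static_assert(221 == ((3 << 6) | (1 << 4) | (3 << 2) | 1), "shuffle immediate check");
static_assert( 78 == ((1 << 6) | (0 << 4) | (3 << 2) | 2), "shuffle immediate check");
```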
@@ -2644,31 +2681,31 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
|
|
|
|
|
|
// Shuffle pattern one - left side input
|
|
|
|
|
|
- const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
|
|
|
- const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
|
|
|
+ const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
|
|
|
+ const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
|
|
|
|
|
|
- const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
|
|
|
- const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
|
|
|
+ const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
|
|
|
+ const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
|
|
|
|
|
|
- const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
|
|
|
- const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
|
|
|
+ const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
|
|
|
+ const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
|
|
|
|
|
|
- const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
|
|
|
- const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
|
|
|
+ const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
|
|
|
+ const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
|
|
|
|
|
|
// Shuffle pattern two - left side input
|
|
|
|
|
|
- const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
|
|
|
- const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
|
|
|
+ const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
|
|
|
+ const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
|
|
|
|
|
|
- const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
|
|
|
- const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
|
|
|
+ const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
|
|
|
+ const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
|
|
|
|
|
|
- const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
|
|
|
- const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
|
|
|
+ const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
|
|
|
+ const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
|
|
|
|
|
|
- const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
|
|
|
- const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
|
|
|
+ const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
|
|
|
+ const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
|
|
|
|
|
|
// The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
|
|
|
// Resembles MMLAs into 2x2 matrices in ARM Version
|
|
@@ -2697,10 +2734,10 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
|
|
|
|
|
|
|
|
|
// Straighten out to make 4 row vectors
|
|
|
- __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, 78));
|
|
|
- __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01);
|
|
|
- __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, 78));
|
|
|
- __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11);
|
|
|
+ __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, (_MM_PERM_ENUM)78));
|
|
|
+ __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), iacc_mat_01);
|
|
|
+ __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78));
|
|
|
+ __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11);
|
|
|
|
|
|
// Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes
|
|
|
const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptrs[rp][b].d), loadMask), 68);
|
|
@@ -2779,31 +2816,31 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
|
|
|
const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31)
|
|
|
|
|
|
// Shuffle pattern one - right side input
|
|
|
- const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
|
|
|
- const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
|
|
|
+ const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
|
|
|
+ const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
|
|
|
|
|
|
- const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
|
|
|
- const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
|
|
|
+ const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
|
|
|
+ const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
|
|
|
|
|
|
- const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
|
|
|
- const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
|
|
|
+ const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
|
|
|
+ const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
|
|
|
|
|
|
- const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
|
|
|
- const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
|
|
|
+ const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
|
|
|
+ const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
|
|
|
|
|
|
// Shuffle pattern two - right side input
|
|
|
|
|
|
- const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
|
|
|
- const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
|
|
|
+ const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
|
|
|
+ const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
|
|
|
|
|
|
- const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
|
|
|
- const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
|
|
|
+ const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
|
|
|
+ const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
|
|
|
|
|
|
- const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
|
|
|
- const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
|
|
|
+ const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
|
|
|
+ const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
|
|
|
|
|
|
- const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
|
|
|
- const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
|
|
|
+ const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
|
|
|
+ const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
|
|
|
|
|
|
|
|
|
// Scale values - Load the weight scale values of two block_q4_0x8
|
|
@@ -2835,31 +2872,31 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
|
|
|
|
|
|
// Shuffle pattern one - left side input
|
|
|
|
|
|
- const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
|
|
|
- const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
|
|
|
+ const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
|
|
|
+ const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
|
|
|
|
|
|
- const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
|
|
|
- const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
|
|
|
+ const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
|
|
|
+ const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
|
|
|
|
|
|
- const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
|
|
|
- const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
|
|
|
+ const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
|
|
|
+ const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
|
|
|
|
|
|
- const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
|
|
|
- const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
|
|
|
+ const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
|
|
|
+ const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
|
|
|
|
|
|
// Shuffle pattern two - left side input
|
|
|
|
|
|
- const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
|
|
|
- const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
|
|
|
+ const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
|
|
|
+ const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
|
|
|
|
|
|
- const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
|
|
|
- const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
|
|
|
+ const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
|
|
|
+ const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
|
|
|
|
|
|
- const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
|
|
|
- const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
|
|
|
+ const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
|
|
|
+ const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
|
|
|
|
|
|
- const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
|
|
|
- const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
|
|
|
+ const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
|
|
|
+ const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
|
|
|
|
|
|
// The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
|
|
|
// Resembles MMLAs into 2x2 matrices in ARM Version
|
|
@@ -2888,10 +2925,10 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
|
|
|
|
|
|
|
|
|
// Straighten out to make 4 row vectors
|
|
|
- __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, 78));
|
|
|
- __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01);
|
|
|
- __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, 78));
|
|
|
- __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11);
|
|
|
+ __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, (_MM_PERM_ENUM)78));
|
|
|
+ __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), iacc_mat_01);
|
|
|
+ __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78));
|
|
|
+ __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11);
|
|
|
|
|
|
// Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes
|
|
|
const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptr[b].d), loadMask), 68);
|
|
@@ -3486,7 +3523,7 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
|
|
|
+static void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
|
|
const int qk = QK8_0;
|
|
|
const int nb = n / qk;
|
|
|
const int ncols_interleaved = 4;
|
|
@@ -3597,7 +3634,6 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const void
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-// FIXME: this code is duplicated from ggml-aarch64.c
|
|
|
static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) {
|
|
|
block_q4_0x4 out;
|
|
|
|
|
@@ -3667,20 +3703,20 @@ static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_in
|
|
|
return out;
|
|
|
}
|
|
|
|
|
|
-static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * restrict data, size_t data_size) {
|
|
|
+static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
|
|
|
GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
|
|
|
GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
|
|
|
+ constexpr int nrows_interleaved = 4;
|
|
|
|
|
|
block_q4_0x4 * dst = (block_q4_0x4 *)t->data;
|
|
|
const block_q4_0 * src = (const block_q4_0 *)data;
|
|
|
block_q4_0 dst_tmp[4];
|
|
|
- int nrow = t->ne[1]; // Number of rows
|
|
|
- int nrows_interleaved = 4;
|
|
|
+ int nrow = ggml_nrows(t);
|
|
|
int nblocks = t->ne[0] / QK4_0;
|
|
|
|
|
|
GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));
|
|
|
|
|
|
- if (nrow % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
|
|
|
+ if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
|
|
|
return -1;
|
|
|
}
|
|
|
|
|
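The row count fed into the data_size assertion now comes from ggml_nrows(t), i.e. the product of every dimension above ne[0], while the interleaving constraint is still checked against t->ne[1] directly. A small sketch of that relationship and of the precondition, using a simplified stand-in for ggml_tensor (illustrative only):

```cpp
#include <cstdint>

// Simplified stand-in for the fields of ggml_tensor used here (illustrative only).
struct toy_tensor {
    int64_t ne[4];   // ne[0] = row length, ne[1..3] = higher dimensions
};

// Mirrors what ggml_nrows() computes: every dimension except ne[0] contributes rows.
static int64_t toy_nrows(const toy_tensor & t) {
    return t.ne[1] * t.ne[2] * t.ne[3];
}

// The precondition checked above: full groups of nrows_interleaved rows, and ne[0] a multiple of 8.
static bool can_repack_q4_0x4(const toy_tensor & t) {
    constexpr int nrows_interleaved = 4;
    return t.ne[1] % nrows_interleaved == 0 && t.ne[0] % 8 == 0;
}
```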
@@ -3698,20 +3734,20 @@ static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block
|
|
|
GGML_UNUSED(data_size);
|
|
|
}
|
|
|
|
|
|
-static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor *t, int interleave_block, const void * restrict data, size_t data_size) {
|
|
|
+static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
|
|
|
GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
|
|
|
GGML_ASSERT(interleave_block == 8);
|
|
|
+ constexpr int nrows_interleaved = 8;
|
|
|
|
|
|
block_q4_0x8 * dst = (block_q4_0x8*)t->data;
|
|
|
const block_q4_0 * src = (const block_q4_0*) data;
|
|
|
block_q4_0 dst_tmp[8];
|
|
|
- int nrow = t->ne[1]; // Number of rows
|
|
|
- int nrows_interleaved = 8;
|
|
|
+ int nrow = ggml_nrows(t);
|
|
|
int nblocks = t->ne[0] / QK4_0;
|
|
|
|
|
|
GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));
|
|
|
|
|
|
- if (nrow % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
|
|
|
+ if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
|
|
|
return -1;
|
|
|
}
|
|
|
|
|
@@ -3738,16 +3774,18 @@ static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_s
|
|
|
|
|
|
const int end = QK4_NL * 2 / blck_size_interleave;
|
|
|
|
|
|
- if (blck_size_interleave == 8) {
|
|
|
- for (int i = 0; i < end; ++i) {
|
|
|
- int src_id = i % 4;
|
|
|
- int src_offset = (i / 4) * blck_size_interleave;
|
|
|
- int dst_offset = i * blck_size_interleave;
|
|
|
-
|
|
|
- // Using memcpy to avoid unaligned memory accesses
|
|
|
- memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t));
|
|
|
- }
|
|
|
- } else if (blck_size_interleave == 4) {
|
|
|
+ // TODO: this branch seems wrong
|
|
|
+ //if (blck_size_interleave == 8) {
|
|
|
+ // for (int i = 0; i < end; ++i) {
|
|
|
+ // int src_id = i % 4;
|
|
|
+ // int src_offset = (i / 4) * blck_size_interleave;
|
|
|
+ // int dst_offset = i * blck_size_interleave;
|
|
|
+
|
|
|
+ // // Using memcpy to avoid unaligned memory accesses
|
|
|
+ // memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t));
|
|
|
+ // }
|
|
|
+ //} else
|
|
|
+ if (blck_size_interleave == 4) {
|
|
|
for (int i = 0; i < end; ++i) {
|
|
|
int src_id = i % 4;
|
|
|
int src_offset = (i / 4) * blck_size_interleave;
|
|
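make_block_iq4_nlx4(), like make_block_q4_0x4() earlier in the file, interleaves the quant bytes of four source blocks chunk by chunk: destination chunk i receives chunk i/4 of source block i%4, with chunks of blck_size_interleave bytes. A generic sketch of that permutation (the free-standing signature and names are illustrative):

```cpp
#include <cstdint>
#include <cstring>

// Interleave the quant bytes of `nrows` source rows into one destination buffer,
// `chunk` bytes at a time (the blck_size_interleave of the packed layout).
static void interleave_rows(uint8_t * dst, const uint8_t * const * src_rows,
                            int nrows, int row_bytes, int chunk) {
    const int nchunks = (row_bytes * nrows) / chunk;
    for (int i = 0; i < nchunks; i++) {
        const int src_id     = i % nrows;              // which source row this chunk comes from
        const int src_offset = (i / nrows) * chunk;    // position of the chunk inside that row
        const int dst_offset = i * chunk;              // chunks are laid out back to back
        memcpy(dst + dst_offset, src_rows[src_id] + src_offset, chunk);
    }
}
```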
@@ -3762,20 +3800,21 @@ static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_s
|
|
|
return out;
|
|
|
}
|
|
|
|
|
|
-static int repack_iq4_nl_to_iq4_nl_4_bl(struct ggml_tensor * t, int interleave_block, const void * restrict data, size_t data_size) {
|
|
|
+static int repack_iq4_nl_to_iq4_nl_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
|
|
|
GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL);
|
|
|
- GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
|
|
|
+ //GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
|
|
|
+ GGML_ASSERT(interleave_block == 4);
|
|
|
|
|
|
block_iq4_nlx4 * dst = (block_iq4_nlx4 *)t->data;
|
|
|
const block_iq4_nl * src = (const block_iq4_nl *)data;
|
|
|
block_iq4_nl dst_tmp[4];
|
|
|
- int nrow = t->ne[1]; // Number of rows
|
|
|
+ int nrow = ggml_nrows(t);
|
|
|
int nrows_interleaved = 4;
|
|
|
int nblocks = t->ne[0] / QK4_0;
|
|
|
|
|
|
GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl));
|
|
|
|
|
|
- if (nrow % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
|
|
|
+ if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
|
|
|
return -1;
|
|
|
}
|
|
|
|
|
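The IQ4_NL repack and matmul paths all decode quants the same way as the scalar type: each 4-bit value indexes the 16-entry non-linear kvalues_iq4nl table and is scaled by the per-block delta. A scalar reference sketch of dequantizing one block (illustrative; it mirrors the table defined earlier in this file):

```cpp
#include <cstdint>

// Same non-linear codebook as kvalues_iq4nl earlier in this file.
static const int8_t kvalues_iq4nl_ref[16] = {
    -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
};

// Dequantize one IQ4_NL block: 32 weights packed as 16 bytes of nibbles plus one half-precision scale d.
static void dequant_iq4_nl_ref(float d, const uint8_t * qs, float * out) {
    for (int i = 0; i < 16; i++) {
        out[i +  0] = d * kvalues_iq4nl_ref[qs[i] & 0x0F];   // low nibble -> first half of the block
        out[i + 16] = d * kvalues_iq4nl_ref[qs[i] >> 4];     // high nibble -> second half of the block
    }
}
```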
@@ -3793,57 +3832,457 @@ static int repack_iq4_nl_to_iq4_nl_4_bl(struct ggml_tensor * t, int interleave_b
|
|
|
GGML_UNUSED(data_size);
|
|
|
}
|
|
|
|
|
|
-// Prepare for optimized kernels if applicable
|
|
|
-void ggml_aarch64_repack_tensor(struct ggml_tensor * cur, enum ggml_type repack_type, const void * restrict data, size_t data_size) {
|
|
|
- if (cur->type == repack_type) {
|
|
|
- memcpy(cur->data, data, data_size);
|
|
|
- return;
|
|
|
+namespace ggml::cpu::aarch64 {
|
|
|
+// repack
|
|
|
+template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
|
|
|
+int repack(struct ggml_tensor *, const void *, size_t);
|
|
|
+
|
|
|
+// TODO: generalise.
|
|
|
+template <> int repack<block_q4_0, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
|
|
|
+ return repack_q4_0_to_q4_0_4_bl(t, 4, data, data_size);
|
|
|
+}
|
|
|
+
|
|
|
+template <> int repack<block_q4_0, 8, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
|
|
|
+ return repack_q4_0_to_q4_0_4_bl(t, 8, data, data_size);
|
|
|
+}
|
|
|
+
|
|
|
+template <> int repack<block_q4_0, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
|
|
|
+ return repack_q4_0_to_q4_0_8_bl(t, 8, data, data_size);
|
|
|
+}
|
|
|
+
|
|
|
+template <> int repack<block_iq4_nl, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
|
|
|
+ return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size);
|
|
|
+}
|
|
|
+
|
|
|
+// TODO: needs to be revisited
|
|
|
+//template <> int repack<block_iq4_nl, 8, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
|
|
|
+// return repack_iq4_nl_to_iq4_nl_4_bl(t, 8, data, data_size);
|
|
|
+//}
|
|
|
+
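These repack<BLOC_TYPE, INTER_SIZE, NB_COLS> specializations replace the old enum-driven ggml_aarch64_repack_tensor() switch with compile-time dispatch: each supported (block type, interleave size, column count) combination forwards to the matching repack_*_bl() helper. A usage sketch (illustrative, not part of the patch):

```cpp
// With the specializations above, the packed layout is chosen at compile time:
//
//   ggml::cpu::aarch64::repack<block_q4_0, 4, 4>(t, data, size);  // -> repack_q4_0_to_q4_0_4_bl(t, 4, ...)
//   ggml::cpu::aarch64::repack<block_q4_0, 8, 8>(t, data, size);  // -> repack_q4_0_to_q4_0_8_bl(t, 8, ...)
//
// A combination without a specialization (e.g. repack<block_iq4_nl, 8, 4>) has no definition,
// so using it fails at link time instead of reaching the old GGML_ABORT("Unsupported type") at run time.
```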
|
|
|
+// gemv
|
|
|
+template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
|
|
|
+void gemv(int, float *, size_t, const void *, const void *, int, int);
|
|
|
+
|
|
|
+template <> void gemv<block_q4_0, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
|
|
+ ggml_gemv_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
|
|
|
+}
|
|
|
+
|
|
|
+template <> void gemv<block_q4_0, 8, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
|
|
+ ggml_gemv_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
|
|
|
+}
|
|
|
+
|
|
|
+template <> void gemv<block_q4_0, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
|
|
+ ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
|
|
|
+}
|
|
|
+
|
|
|
+template <>
|
|
|
+void gemv<block_iq4_nl, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
|
|
+ ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
|
|
|
+}
|
|
|
+
|
|
|
+// gemm
|
|
|
+template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
|
|
|
+void gemm(int, float *, size_t, const void *, const void *, int, int);
|
|
|
+
|
|
|
+template <> void gemm<block_q4_0, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
|
|
+ ggml_gemm_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
|
|
|
+}
|
|
|
+
|
|
|
+template <> void gemm<block_q4_0, 8, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
|
|
+ ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
|
|
|
+}
|
|
|
+
|
|
|
+template <> void gemm<block_q4_0, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
|
|
+ ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
|
|
|
+}
|
|
|
+
|
|
|
+template <>
|
|
|
+void gemm<block_iq4_nl, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
|
|
+ ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
|
|
|
+}
|
|
|
+
|
|
|
+class tensor_traits_base : public ggml::cpu::tensor_traits {
|
|
|
+ public:
|
|
|
+ virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0;
|
|
|
+};
|
|
|
+
|
|
|
+template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_traits : public tensor_traits_base {
|
|
|
+
|
|
|
+ bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
|
|
|
+ // not really a GGML_TYPE_Q8_0 but same size.
|
|
|
+ switch (op->op) {
|
|
|
+ case GGML_OP_MUL_MAT:
|
|
|
+ size = ggml_row_size(GGML_TYPE_Q8_0, ggml_nelements(op->src[1]));
|
|
|
+ return true;
|
|
|
+ case GGML_OP_MUL_MAT_ID:
|
|
|
+ size = ggml_row_size(GGML_TYPE_Q8_0, ggml_nelements(op->src[1]));
|
|
|
+ size = GGML_PAD(size, sizeof(int64_t)); // + padding for next block.
|
|
|
+ size += sizeof(int64_t) * (1+op->src[0]->ne[2]) * op->src[1]->ne[2];
|
|
|
+ return true;
|
|
|
+ default:
|
|
|
+ // GGML_ABORT("fatal error");
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ return false;
|
|
|
}
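work_size() above sizes the per-op scratch buffer that holds src1 re-quantized to Q8_0; for MUL_MAT_ID it additionally pads to int64_t alignment and adds sizeof(int64_t) * (1 + ne02) * ne12 of expert-routing bookkeeping. Since a block_q8_0 packs 32 weights into 34 bytes (a 2-byte half scale plus 32 int8 quants), the MUL_MAT case is just that row size applied to all of src1. A worked example under those assumptions:

```cpp
#include <cstddef>

// Assumptions for the sketch: QK8_0 == 32 and block_q8_0 == { ggml_half d; int8_t qs[32]; } == 34 bytes.
constexpr size_t q8_0_row_bytes(size_t n_elements) {
    return (n_elements / 32) * 34;   // what ggml_row_size(GGML_TYPE_Q8_0, n) works out to
}

// e.g. src1 of shape ne10 x ne11 = 4096 x 8: 32768 floats -> 1024 q8_0 blocks -> 34816 bytes of scratch
static_assert(q8_0_row_bytes(4096 * 8) == 34816, "MUL_MAT scratch size example");
```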
|
|
|
|
|
|
- if (cur->type == GGML_TYPE_Q4_0) {
|
|
|
- switch (repack_type) {
|
|
|
- case GGML_TYPE_Q4_0_8_8:
|
|
|
- repack_q4_0_to_q4_0_8_bl(cur, 8, data, data_size);
|
|
|
- break;
|
|
|
- case GGML_TYPE_Q4_0_4_8:
|
|
|
- repack_q4_0_to_q4_0_4_bl(cur, 8, data, data_size);
|
|
|
- break;
|
|
|
- case GGML_TYPE_Q4_0_4_4:
|
|
|
- repack_q4_0_to_q4_0_4_bl(cur, 4, data, data_size);
|
|
|
- break;
|
|
|
- default:
|
|
|
- GGML_ABORT("Unsupported type");
|
|
|
+ bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override {
|
|
|
+ switch (op->op) {
|
|
|
+ case GGML_OP_MUL_MAT:
|
|
|
+ forward_mul_mat(params, op);
|
|
|
+ return true;
|
|
|
+ case GGML_OP_MUL_MAT_ID:
|
|
|
+ forward_mul_mat_id(params, op);
|
|
|
+ return true;
|
|
|
+ default:
|
|
|
+ // GGML_ABORT("fatal error");
|
|
|
+ break;
|
|
|
}
|
|
|
- } else if (cur->type == GGML_TYPE_IQ4_NL) {
|
|
|
- switch (repack_type) {
|
|
|
- case GGML_TYPE_IQ4_NL_4_4:
|
|
|
- repack_iq4_nl_to_iq4_nl_4_bl(cur, 4, data, data_size);
|
|
|
- break;
|
|
|
- default:
|
|
|
- GGML_ABORT("Unsupported type");
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+
|
|
|
+ void forward_mul_mat(ggml_compute_params * params, ggml_tensor * op) {
|
|
|
+ const ggml_tensor * src0 = op->src[0];
|
|
|
+ const ggml_tensor * src1 = op->src[1];
|
|
|
+ ggml_tensor * dst = op;
|
|
|
+
|
|
|
+ GGML_TENSOR_BINARY_OP_LOCALS
|
|
|
+
|
|
|
+ const int ith = params->ith;
|
|
|
+ const int nth = params->nth;
|
|
|
+
|
|
|
+ GGML_ASSERT(ne0 == ne01);
|
|
|
+ GGML_ASSERT(ne1 == ne11);
|
|
|
+ GGML_ASSERT(ne2 == ne12);
|
|
|
+ GGML_ASSERT(ne3 == ne13);
|
|
|
+
|
|
|
+ // dst cannot be transposed or permuted
|
|
|
+ GGML_ASSERT(nb0 == sizeof(float));
|
|
|
+ GGML_ASSERT(nb0 <= nb1);
|
|
|
+ GGML_ASSERT(nb1 <= nb2);
|
|
|
+ GGML_ASSERT(nb2 <= nb3);
|
|
|
+
|
|
|
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
|
|
+
|
|
|
+ GGML_ASSERT(ggml_n_dims(op->src[0]) == 2);
|
|
|
+ // GGML_ASSERT(ggml_n_dims(op->src[1]) == 2);
|
|
|
+
|
|
|
+ char * wdata = static_cast<char *>(params->wdata);
|
|
|
+ const size_t nbw1 = ggml_row_size(GGML_TYPE_Q8_0, ne10);
|
|
|
+
|
|
|
+ assert(params->wsize >= nbw1 * ne11);
|
|
|
+
|
|
|
+ const ggml_from_float_t from_float = ggml_get_type_traits_cpu(GGML_TYPE_Q8_0)->from_float;
|
|
|
+
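+ // Quantize src1: full groups of 4 rows go through quantize_mat_q8_0() (which also
+ // interleaves them for the kernels); leftover rows are quantized one by one.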
|
|
|
+ int64_t i11_processed = 0;
|
|
|
+ for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
|
|
|
+ quantize_mat_q8_0((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10,
|
|
|
+ INTER_SIZE);
|
|
|
+ }
|
|
|
+ i11_processed = ne11 - ne11 % 4;
|
|
|
+ for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
|
|
|
+ from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
|
|
|
+ }
|
|
|
+
|
|
|
+ ggml_barrier(params->threadpool);
|
|
|
+
|
|
|
+ const void * src1_wdata = params->wdata;
|
|
|
+ const size_t src1_col_stride = ggml_row_size(GGML_TYPE_Q8_0, ne10);
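+ // Each thread takes a slice of src0 rows, rounded up to a multiple of NB_COLS so
+ // slice boundaries match the interleaved packing; a thread whose rounded slice is
+ // empty bails out early.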
|
|
|
+ int64_t src0_start = (ith * ne01) / nth;
|
|
|
+ int64_t src0_end = ((ith + 1) * ne01) / nth;
|
|
|
+ src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
|
|
|
+ src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
|
|
|
+ if (src0_start >= src0_end) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+ // If there are more than three rows in src1, use gemm; otherwise, use gemv.
|
|
|
+ if (ne11 > 3) {
|
|
|
+ gemm<BLOC_TYPE, INTER_SIZE, NB_COLS>(ne00, (float *) ((char *) dst->data) + src0_start, ne01,
|
|
|
+ (const char *) src0->data + src0_start * nb01,
|
|
|
+ (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
|
|
|
+ }
|
|
|
+ for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
|
|
|
+ gemv<BLOC_TYPE, INTER_SIZE, NB_COLS>(ne00, (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
|
|
|
+ (const char *) src0->data + src0_start * nb01,
|
|
|
+ (const char *) src1_wdata + (src1_col_stride * iter), 1,
|
|
|
+ src0_end - src0_start);
|
|
|
}
|
|
|
- } else {
|
|
|
- GGML_ABORT("Unsupported type");
|
|
|
}
|
|
|
-}
|
|
|
|
|
|
-enum ggml_type ggml_aarch64_get_optimal_repack_type(const struct ggml_tensor * cur) {
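+ // forward_mul_mat_id (expert routing), roughly: quantize every src1 row to q8_0,
+ // group the routed (token, expert-slot) pairs by expert, then run gemv per routed
+ // row against that expert's repacked slice of src0.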
|
|
|
+ void forward_mul_mat_id(ggml_compute_params * params, ggml_tensor * op) {
|
|
|
+ const ggml_tensor * src0 = op->src[0];
|
|
|
+ const ggml_tensor * src1 = op->src[1];
|
|
|
+ const ggml_tensor * ids = op->src[2];
|
|
|
+ ggml_tensor * dst = op;
|
|
|
+
|
|
|
+ GGML_TENSOR_BINARY_OP_LOCALS
|
|
|
+
|
|
|
+ const int ith = params->ith;
|
|
|
+ const int nth = params->nth;
|
|
|
+
|
|
|
+ const ggml_from_float_t from_float = ggml_get_type_traits_cpu(GGML_TYPE_Q8_0)->from_float;
|
|
|
+
|
|
|
+ // we don't support permuted src0 or src1
|
|
|
+ GGML_ASSERT(nb00 == ggml_type_size(src0->type));
|
|
|
+ GGML_ASSERT(nb10 == ggml_type_size(src1->type));
|
|
|
+
|
|
|
+ // dst cannot be transposed or permuted
|
|
|
+ GGML_ASSERT(nb0 == sizeof(float));
|
|
|
+ GGML_ASSERT(nb0 <= nb1);
|
|
|
+ GGML_ASSERT(nb1 <= nb2);
|
|
|
+ GGML_ASSERT(nb2 <= nb3);
|
|
|
+
|
|
|
+ GGML_ASSERT(ne03 == 1);
|
|
|
+ GGML_ASSERT(ne13 == 1);
|
|
|
+ GGML_ASSERT(ne3 == 1);
|
|
|
+
|
|
|
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
|
|
+
|
|
|
+ // row groups
|
|
|
+ const int n_ids = ids->ne[0]; // n_expert_used
|
|
|
+ const int n_as = ne02; // n_expert
|
|
|
+
|
|
|
+ const size_t nbw1 = ggml_row_size(GGML_TYPE_Q8_0, ne10);
|
|
|
+ const size_t nbw2 = nbw1*ne11;
|
|
|
+ const size_t nbw3 = nbw2*ne12;
|
|
|
+
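+ // One mapping per routed row: i1 = index of the selected expert within the ids row
+ // (0..n_expert_used-1), i2 = token index, i.e. the row of src1/ids it came from.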
|
|
|
+ struct mmid_row_mapping {
|
|
|
+ int32_t i1;
|
|
|
+ int32_t i2;
|
|
|
+ };
|
|
|
+
|
|
|
+ GGML_ASSERT(params->wsize >= (GGML_PAD(nbw3, sizeof(int64_t)) + n_as * sizeof(int64_t) +
|
|
|
+ n_as * ne12 * sizeof(mmid_row_mapping)));
|
|
|
+
|
|
|
+ auto wdata = (char *) params->wdata;
|
|
|
+ auto wdata_src1_end = (char *) wdata + GGML_PAD(nbw3, sizeof(int64_t));
|
|
|
+ int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
|
|
|
+ struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]
|
|
|
+
|
|
|
+ // src1: float32 => block_q8_0
|
|
|
+ for (int64_t i12 = 0; i12 < ne12; ++i12) {
|
|
|
+ for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
|
|
|
+ from_float((float *)((char *) src1->data + i12 * nb12 + i11 * nb11),
|
|
|
+ (void *) (wdata + i12 * nbw2 + i11 * nbw1),
|
|
|
+ ne10);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id) * ne12 + (i1)]
|
|
|
+
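+ // Single-threaded pass: bucket every (token, expert-slot) pair by the expert it was
+ // routed to, so the loop below can process each expert's rows together.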
|
|
|
+ if (ith == 0) {
|
|
|
+ // initialize matrix_row_counts
|
|
|
+ memset(matrix_row_counts, 0, n_as * sizeof(int64_t));
|
|
|
+
|
|
|
+ // group rows by src0 matrix
|
|
|
+ for (int32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
|
|
|
+ for (int32_t id = 0; id < n_ids; ++id) {
|
|
|
+ const int32_t i02 =
|
|
|
+ *(const int32_t *) ((const char *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]);
|
|
|
+
|
|
|
+ GGML_ASSERT(i02 >= 0 && i02 < n_as);
|
|
|
+
|
|
|
+ MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = { id, iid1 };
|
|
|
+ matrix_row_counts[i02] += 1;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ ggml_barrier(params->threadpool);
|
|
|
+
|
|
|
+ // compute each matrix multiplication in sequence
|
|
|
+ for (int cur_a = 0; cur_a < n_as; ++cur_a) {
|
|
|
+ const int64_t cne1 = matrix_row_counts[cur_a];
|
|
|
+
|
|
|
+ if (cne1 == 0) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+
|
|
|
+ auto src0_cur = (const char *) src0->data + cur_a*nb02;
|
|
|
+
|
|
|
+ //const int64_t nr0 = ne01; // src0 rows
|
|
|
+ const int64_t nr1 = cne1; // src1 rows
|
|
|
+
|
|
|
+ int64_t src0_cur_start = (ith * ne01) / nth;
|
|
|
+ int64_t src0_cur_end = ((ith + 1) * ne01) / nth;
|
|
|
+ src0_cur_start =
|
|
|
+ (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
|
|
|
+ src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
|
|
|
+
|
|
|
+ if (src0_cur_start >= src0_cur_end) return;
|
|
|
+
|
|
|
+ for (int ir1 = 0; ir1 < nr1; ir1++) {
|
|
|
+ struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
|
|
|
+ const int id = row_mapping.i1; // selected expert index
|
|
|
+
|
|
|
+ const int64_t i11 = id % ne11;
|
|
|
+ const int64_t i12 = row_mapping.i2; // row index in src1
|
|
|
+
|
|
|
+ const int64_t i1 = id; // selected expert index
|
|
|
+ const int64_t i2 = i12; // row
|
|
|
+
|
|
|
+ auto src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2);
|
|
|
+
|
|
|
+ gemv<BLOC_TYPE, INTER_SIZE, NB_COLS>(
|
|
|
+ ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start,
|
|
|
+ ne01, src0_cur + src0_cur_start * nb01,
|
|
|
+ src1_col, 1, src0_cur_end - src0_cur_start);
|
|
|
+ }
|
|
|
+ }
|
|
|
+#undef MMID_MATRIX_ROW
|
|
|
+ }
|
|
|
+
|
|
|
+ int repack(struct ggml_tensor * t, const void * data, size_t data_size) override {
|
|
|
+ GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, ggml_type_name(t->type),
|
|
|
+ (int) NB_COLS, (int) INTER_SIZE);
|
|
|
+ return ggml::cpu::aarch64::repack<BLOC_TYPE, INTER_SIZE, NB_COLS>(t, data, data_size);
|
|
|
+ }
|
|
|
+};
|
|
|
+
|
|
|
+// instance for Q4
|
|
|
+static const tensor_traits<block_q4_0, 4, 4> q4_0_4x4_q8_0;
|
|
|
+static const tensor_traits<block_q4_0, 8, 4> q4_0_4x8_q8_0;
|
|
|
+static const tensor_traits<block_q4_0, 8, 8> q4_0_8x8_q8_0;
|
|
|
+
|
|
|
+// instance for IQ4
|
|
|
+static const tensor_traits<block_iq4_nl, 4, 4> iq4_nl_4x4_q8_0;
|
|
|
+
|
|
|
+} // namespace ggml::cpu::aarch64
|
|
|
+
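+// Pick the repacked layout the current CPU can actually exploit (best first):
+// Q4_0 -> 8x8 on AVX2 or 256-bit SVE with int8 matmul, 4x8 on NEON+i8mm, 4x4 on
+// NEON+dotprod; IQ4_NL -> 4x4 on NEON+dotprod. Returns nullptr (no repacking) if
+// ne[1] is not a multiple of the interleave width or no suitable kernel is available.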
|
|
|
+static const ggml::cpu::tensor_traits * ggml_aarch64_get_optimal_repack_type(const struct ggml_tensor * cur) {
|
|
|
if (cur->type == GGML_TYPE_Q4_0) {
|
|
|
- // TODO: enable for AVX2 - currently disabled due to bad gemv performance
|
|
|
- if (/* ggml_cpu_has_avx2() || */ (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)) {
|
|
|
- return GGML_TYPE_Q4_0_8_8;
|
|
|
+ if (ggml_cpu_has_avx2() || (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)) {
|
|
|
+ if (cur->ne[1] % 8 == 0) {
|
|
|
+ return &ggml::cpu::aarch64::q4_0_8x8_q8_0;
|
|
|
+ }
|
|
|
}
|
|
|
if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
|
|
|
- return GGML_TYPE_Q4_0_4_8;
|
|
|
+ if (cur->ne[1] % 4 == 0) {
|
|
|
+ return &ggml::cpu::aarch64::q4_0_4x8_q8_0;
|
|
|
+ }
|
|
|
}
|
|
|
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
|
|
|
- return GGML_TYPE_Q4_0_4_4;
|
|
|
+ if (cur->ne[1] % 4 == 0) {
|
|
|
+ return &ggml::cpu::aarch64::q4_0_4x4_q8_0;
|
|
|
+ }
|
|
|
}
|
|
|
} else if (cur->type == GGML_TYPE_IQ4_NL) {
|
|
|
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
|
|
|
- return GGML_TYPE_IQ4_NL_4_4;
|
|
|
+ if (cur->ne[1] % 4 == 0) {
|
|
|
+ return &ggml::cpu::aarch64::iq4_nl_4x4_q8_0;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return nullptr;
|
|
|
+}
|
|
|
+
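+// Buffer hooks: init_tensor records the chosen tensor_traits in tensor->extra;
+// set_tensor then repacks the incoming data instead of doing a plain memcpy. Only
+// whole-tensor writes are supported (offset 0, full size).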
|
|
|
+static void ggml_backend_cpu_aarch64_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
|
|
|
+ tensor->extra = (void *) const_cast<ggml::cpu::tensor_traits *>(ggml_aarch64_get_optimal_repack_type(tensor));
|
|
|
+
|
|
|
+ GGML_UNUSED(buffer);
|
|
|
+}
|
|
|
+
|
|
|
+static void ggml_backend_cpu_aarch64_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
|
|
|
+ const void * data, size_t offset, size_t size) {
|
|
|
+ GGML_ASSERT(offset == 0);
|
|
|
+ GGML_ASSERT(size == ggml_nbytes(tensor));
|
|
|
+
|
|
|
+ auto tensor_traits = (ggml::cpu::aarch64::tensor_traits_base *) tensor->extra;
|
|
|
+ auto OK = tensor_traits->repack(tensor, data, size);
|
|
|
+
|
|
|
+ GGML_ASSERT(OK == 0);
|
|
|
+ GGML_UNUSED(buffer);
|
|
|
+}
|
|
|
+
|
|
|
+static const char * ggml_backend_cpu_aarch64_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
|
|
|
+ return "CPU_AARCH64";
|
|
|
+
|
|
|
+ GGML_UNUSED(buft);
|
|
|
+}
|
|
|
+
|
|
|
+static ggml_backend_buffer_t ggml_backend_cpu_aarch64_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
|
|
+ ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
|
|
|
+
|
|
|
+ if (buffer == nullptr) {
|
|
|
+ return nullptr;
|
|
|
+ }
|
|
|
+
|
|
|
+ buffer->buft = buft;
|
|
|
+ buffer->iface.init_tensor = ggml_backend_cpu_aarch64_buffer_init_tensor;
|
|
|
+ buffer->iface.set_tensor = ggml_backend_cpu_aarch64_buffer_set_tensor;
|
|
|
+ return buffer;
|
|
|
+}
|
|
|
+
|
|
|
+static size_t ggml_backend_cpu_aarch64_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
|
|
|
+ return TENSOR_ALIGNMENT;
|
|
|
+
|
|
|
+ GGML_UNUSED(buft);
|
|
|
+}
|
|
|
+
|
|
|
+namespace ggml::cpu::aarch64 {
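+// extra_buffer_type advertises which ops the repacked weights can accelerate:
+// 2D MUL_MAT or 3D MUL_MAT_ID whose src0 lives in the CPU_AARCH64 buffer type and
+// whose src1 is f32 in host memory; get_tensor_traits() then hands back the
+// tensor_traits stored in src0->extra so the CPU backend can route the op here.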
|
|
|
+class extra_buffer_type : ggml::cpu::extra_buffer_type {
|
|
|
+ bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
|
|
|
+ if ( op->op == GGML_OP_MUL_MAT &&
|
|
|
+ op->src[0]->buffer &&
|
|
|
+ (ggml_n_dims(op->src[0]) == 2) &&
|
|
|
+ op->src[0]->buffer->buft == ggml_backend_cpu_aarch64_buffer_type() &&
|
|
|
+ ggml_aarch64_get_optimal_repack_type(op->src[0])
|
|
|
+ ) {
|
|
|
+ if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+ if (op->src[1]->type == GGML_TYPE_F32) {
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ //if (op->src[1]->type == GGML_TYPE_Q8_0) {
|
|
|
+ // return true;
|
|
|
+ //}
|
|
|
+ // may be possible if Q8_0 packed...
|
|
|
+ } else if (op->op == GGML_OP_MUL_MAT_ID
|
|
|
+ && op->src[0]->buffer
|
|
|
+ && (ggml_n_dims(op->src[0]) == 3)
|
|
|
+ && op->src[0]->buffer->buft == ggml_backend_cpu_aarch64_buffer_type()
|
|
|
+ && ggml_aarch64_get_optimal_repack_type(op->src[0])
|
|
|
+ ) {
|
|
|
+ if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+ if (op->src[1]->type == GGML_TYPE_F32) {
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ //if (op->src[1]->type == GGML_TYPE_Q8_0) {
|
|
|
+ // return true;
|
|
|
+ //}
|
|
|
}
|
|
|
+ return false;
|
|
|
}
|
|
|
|
|
|
- return cur->type;
|
|
|
+ ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
|
|
|
+ if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_MUL_MAT_ID) {
|
|
|
+ if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_aarch64_buffer_type()) {
|
|
|
+ return (ggml::cpu::tensor_traits *) op->src[0]->extra;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return nullptr;
|
|
|
+ }
|
|
|
+};
|
|
|
+} // namespace ggml::cpu::aarch64
|
|
|
+
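+// Public entry point. A minimal usage sketch (helper functions are the generic ones
+// from ggml-backend.h, not part of this file):
+//
+//   ggml_backend_buffer_type_t buft = ggml_backend_cpu_aarch64_buffer_type();
+//   ggml_backend_buffer_t      buf  = ggml_backend_buft_alloc_buffer(buft, nbytes);
+//   // tensors allocated in `buf` get their traits via init_tensor, and
+//   // ggml_backend_tensor_set() on a Q4_0/IQ4_NL weight triggers repack().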
|
|
|
+ggml_backend_buffer_type_t ggml_backend_cpu_aarch64_buffer_type(void) {
|
|
|
+ static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_aarch64 = {
|
|
|
+ /* .iface = */ {
|
|
|
+ /* .get_name = */ ggml_backend_cpu_aarch64_buffer_type_get_name,
|
|
|
+ /* .alloc_buffer = */ ggml_backend_cpu_aarch64_buffer_type_alloc_buffer,
|
|
|
+ /* .get_alignment = */ ggml_backend_cpu_aarch64_buffer_type_get_alignment,
|
|
|
+ /* .get_max_size = */ nullptr, // defaults to SIZE_MAX
|
|
|
+ /* .get_alloc_size = */ nullptr, // defaults to ggml_nbytes
|
|
|
+ /* .is_host = */ nullptr,
|
|
|
+ },
|
|
|
+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
|
|
|
+ /* .context = */ new ggml::cpu::aarch64::extra_buffer_type(),
|
|
|
+ };
|
|
|
+
|
|
|
+ return &ggml_backend_cpu_buffer_type_aarch64;
|
|
|
}
|