|
@@ -1,15 +1,26 @@
|
|
|
|
+#pragma once
|
|
|
|
+
|
|
#include "common.cuh"
|
|
#include "common.cuh"
|
|
#include "vecdotq.cuh"
|
|
#include "vecdotq.cuh"
|
|
|
|
|
|
#include <climits>
|
|
#include <climits>
|
|
#include <cstdint>
|
|
#include <cstdint>
|
|
|
|
|
|
|
|
+#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
|
|
|
|
+
|
|
typedef void (*load_tiles_mmq_t)(
|
|
typedef void (*load_tiles_mmq_t)(
|
|
const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
|
const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
|
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride);
|
|
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride);
|
|
typedef void (*vec_dot_mmq_t)(
|
|
typedef void (*vec_dot_mmq_t)(
|
|
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
|
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
|
- const int * __restrict__ y_qs, const half2 * __restrict__ y_ms, float * __restrict__ sum, const int & k0);
|
|
|
|
|
|
+ const int * __restrict__ y, float * __restrict__ sum, const int & k0);
|
|
|
|
+
|
|
|
|
+struct block_q8_1_mmq {
|
|
|
|
+ half2 ds[4];
|
|
|
|
+ int8_t qs[4*QK8_1];
|
|
|
|
+};
|
|
|
|
+static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
|
|
|
|
+static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size");
|
|
|
|
|
|
struct tile_x_sizes {
|
|
struct tile_x_sizes {
|
|
int ql;
|
|
int ql;
|
|
@@ -132,10 +143,14 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
template <int mmq_x, int mmq_y, int nwarps>
|
|
template <int mmq_x, int mmq_y, int nwarps>
|
|
static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mul_mat(
|
|
static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mul_mat(
|
|
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
|
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
|
- const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
|
|
|
|
|
|
+ const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
|
|
|
|
|
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
|
|
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
|
|
|
|
|
|
|
|
+ const float * x_dmf = (const float *) x_dm;
|
|
|
|
+ const int * y_qs = (const int *) y + 4;
|
|
|
|
+ const half2 * y_ds = (const half2 *) y;
|
|
|
|
+
|
|
#pragma unroll
|
|
#pragma unroll
|
|
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
|
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
|
const int j = j0 + threadIdx.y;
|
|
const int j = j0 + threadIdx.y;
|
|
@@ -145,19 +160,18 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mul_mat(
|
|
const int i = i0 + threadIdx.x;
|
|
const int i = i0 + threadIdx.x;
|
|
|
|
|
|
const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
|
|
const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
|
|
- const float * x_dmf = (const float *) x_dm;
|
|
|
|
|
|
|
|
int u[2*VDR_Q4_0_Q8_1_MMQ];
|
|
int u[2*VDR_Q4_0_Q8_1_MMQ];
|
|
|
|
|
|
#pragma unroll
|
|
#pragma unroll
|
|
for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
|
|
for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
|
|
- u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE];
|
|
|
|
- u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_0) % WARP_SIZE];
|
|
|
|
|
|
+ u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE];
|
|
|
|
+ u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI4_0) % WARP_SIZE];
|
|
}
|
|
}
|
|
|
|
|
|
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
|
|
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
|
|
- (&x_ql[i * (WARP_SIZE + 1) + k0], u, x_dmf[i * (WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0],
|
|
|
|
- y_ds[j * (WARP_SIZE/QI8_1) + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
|
|
|
|
|
|
+ (&x_ql[i*(WARP_SIZE + 1) + k0], u, x_dmf[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0],
|
|
|
|
+ y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
@@ -203,10 +217,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
template <int mmq_x, int mmq_y, int nwarps>
|
|
template <int mmq_x, int mmq_y, int nwarps>
|
|
static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mul_mat(
|
|
static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mul_mat(
|
|
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
|
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
|
- const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
|
|
|
|
|
|
+ const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
|
|
|
|
|
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
|
|
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
|
|
|
|
|
|
|
|
+ const int * y_qs = (const int *) y + 4;
|
|
|
|
+ const half2 * y_ds = (const half2 *) y;
|
|
|
|
+
|
|
#pragma unroll
|
|
#pragma unroll
|
|
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
|
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
|
const int j = j0 + threadIdx.y;
|
|
const int j = j0 + threadIdx.y;
|
|
@@ -221,13 +238,13 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mul_mat(
|
|
|
|
|
|
#pragma unroll
|
|
#pragma unroll
|
|
for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
|
|
for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
|
|
- u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE];
|
|
|
|
- u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_1) % WARP_SIZE];
|
|
|
|
|
|
+ u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE];
|
|
|
|
+ u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI4_1) % WARP_SIZE];
|
|
}
|
|
}
|
|
|
|
|
|
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
|
|
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
|
|
- (&x_ql[i * (WARP_SIZE + 1) + k0], u, x_dm[i * (WARP_SIZE/QI4_1) + i/QI4_1 + k0/QI4_1],
|
|
|
|
- y_ds[j * (WARP_SIZE/QI8_1) + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
|
|
|
|
|
|
+ (&x_ql[i*(WARP_SIZE + 1) + k0], u, x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + k0/QI4_1],
|
|
|
|
+ y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
@@ -293,10 +310,14 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
template <int mmq_x, int mmq_y, int nwarps>
|
|
template <int mmq_x, int mmq_y, int nwarps>
|
|
static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mul_mat(
|
|
static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mul_mat(
|
|
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
|
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
|
- const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
|
|
|
|
|
|
+ const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
|
|
|
|
|
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
|
|
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
|
|
|
|
|
|
|
|
+ const float * x_dmf = (const float *) x_dm;
|
|
|
|
+ const int * y_qs = (const int *) y + 4;
|
|
|
|
+ const float * y_df = (const float *) y;
|
|
|
|
+
|
|
#pragma unroll
|
|
#pragma unroll
|
|
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
|
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
|
const int j = j0 + threadIdx.y;
|
|
const int j = j0 + threadIdx.y;
|
|
@@ -306,20 +327,18 @@ static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mul_mat(
|
|
const int i = i0 + threadIdx.x;
|
|
const int i = i0 + threadIdx.x;
|
|
|
|
|
|
const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
|
|
const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
|
|
- const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0;
|
|
|
|
- const float * x_dmf = (const float *) x_dm;
|
|
|
|
- const float * y_df = (const float *) y_ds;
|
|
|
|
|
|
+ const int index_bx = i*(WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0;
|
|
|
|
|
|
int u[2*VDR_Q5_0_Q8_1_MMQ];
|
|
int u[2*VDR_Q5_0_Q8_1_MMQ];
|
|
|
|
|
|
#pragma unroll
|
|
#pragma unroll
|
|
for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) {
|
|
for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) {
|
|
- u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE];
|
|
|
|
- u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_0) % WARP_SIZE];
|
|
|
|
|
|
+ u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE];
|
|
|
|
+ u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI5_0) % WARP_SIZE];
|
|
}
|
|
}
|
|
|
|
|
|
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, QR5_0*VDR_Q5_0_Q8_1_MMQ>
|
|
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, QR5_0*VDR_Q5_0_Q8_1_MMQ>
|
|
- (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k0], u, x_dmf[index_bx], y_df[j * (WARP_SIZE/QI8_1) + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
|
|
|
|
|
|
+ (&x_ql[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dmf[index_bx], y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
@@ -383,10 +402,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
template <int mmq_x, int mmq_y, int nwarps>
|
|
template <int mmq_x, int mmq_y, int nwarps>
|
|
static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mul_mat(
|
|
static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mul_mat(
|
|
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
|
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
|
- const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
|
|
|
|
|
|
+ const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
|
|
|
|
|
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
|
|
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
|
|
|
|
|
|
|
|
+ const int * y_qs = (const int *) y + 4;
|
|
|
|
+ const half2 * y_ds = (const half2 *) y;
|
|
|
|
+
|
|
#pragma unroll
|
|
#pragma unroll
|
|
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
|
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
|
const int j = j0 + threadIdx.y;
|
|
const int j = j0 + threadIdx.y;
|
|
@@ -396,18 +418,18 @@ static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mul_mat(
|
|
const int i = i0 + threadIdx.x;
|
|
const int i = i0 + threadIdx.x;
|
|
|
|
|
|
const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
|
|
const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
|
|
- const int index_bx = i * (WARP_SIZE/QI5_1) + + i/QI5_1 + k0/QI5_1;
|
|
|
|
|
|
+ const int index_bx = i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI5_1;
|
|
|
|
|
|
int u[2*VDR_Q5_1_Q8_1_MMQ];
|
|
int u[2*VDR_Q5_1_Q8_1_MMQ];
|
|
|
|
|
|
#pragma unroll
|
|
#pragma unroll
|
|
for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) {
|
|
for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) {
|
|
- u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE];
|
|
|
|
- u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_1) % WARP_SIZE];
|
|
|
|
|
|
+ u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE];
|
|
|
|
+ u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI5_1) % WARP_SIZE];
|
|
}
|
|
}
|
|
|
|
|
|
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
|
|
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
|
|
- (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k0], u, x_dm[index_bx], y_ds[j * (WARP_SIZE/QI8_1) + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
|
|
|
|
|
|
+ (&x_ql[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dm[index_bx], y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
@@ -455,10 +477,14 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
template <int mmq_x, int mmq_y, int nwarps>
|
|
template <int mmq_x, int mmq_y, int nwarps>
|
|
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mul_mat(
|
|
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mul_mat(
|
|
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
|
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
|
- const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
|
|
|
|
|
|
+ const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
|
|
|
|
|
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
|
|
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
|
|
|
|
|
|
|
|
+ const float * x_dmf = (const float *) x_dm;
|
|
|
|
+ const int * y_qs = (const int *) y + 4;
|
|
|
|
+ const float * y_df = (const float *) y;
|
|
|
|
+
|
|
#pragma unroll
|
|
#pragma unroll
|
|
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
|
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
|
const int j = j0 + threadIdx.y;
|
|
const int j = j0 + threadIdx.y;
|
|
@@ -467,12 +493,9 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mul_mat(
|
|
for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
|
|
for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
|
|
const int i = i0 + threadIdx.x;
|
|
const int i = i0 + threadIdx.x;
|
|
|
|
|
|
- const float * x_dmf = (const float *) x_dm;
|
|
|
|
- const float * y_df = (const float *) y_ds;
|
|
|
|
-
|
|
|
|
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
|
|
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
|
|
- (&x_ql[i * (WARP_SIZE + 1) + k0], &y_qs[j * WARP_SIZE + k0], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0],
|
|
|
|
- y_df[j * (WARP_SIZE/QI8_1) + k0/QI8_1]);
|
|
|
|
|
|
+ (&x_ql[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0], x_dmf[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0],
|
|
|
|
+ y_df[j*MMQ_TILE_Y_K + k0/QI8_1]);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
@@ -531,10 +554,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
template <int mmq_x, int mmq_y, int nwarps>
|
|
template <int mmq_x, int mmq_y, int nwarps>
|
|
static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mul_mat(
|
|
static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mul_mat(
|
|
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
|
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
|
- const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
|
|
|
|
|
|
+ const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
|
|
|
|
|
GGML_UNUSED(x_qh);
|
|
GGML_UNUSED(x_qh);
|
|
|
|
|
|
|
|
+ const int * y_qs = (const int *) y + 4;
|
|
|
|
+ const float * y_df = (const float *) y;
|
|
|
|
+
|
|
#pragma unroll
|
|
#pragma unroll
|
|
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
|
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
|
const int j = j0 + threadIdx.y;
|
|
const int j = j0 + threadIdx.y;
|
|
@@ -545,11 +571,10 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mul_mat(
|
|
|
|
|
|
const int kbx = k0 / QI2_K;
|
|
const int kbx = k0 / QI2_K;
|
|
const int ky = (k0 % QI2_K) * QR2_K;
|
|
const int ky = (k0 % QI2_K) * QR2_K;
|
|
- const float * y_df = (const float *) y_ds;
|
|
|
|
|
|
|
|
int v[QR2_K*VDR_Q2_K_Q8_1_MMQ];
|
|
int v[QR2_K*VDR_Q2_K_Q8_1_MMQ];
|
|
|
|
|
|
- const int kqsx = i * (WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2);
|
|
|
|
|
|
+ const int kqsx = i*(WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2);
|
|
const int shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2));
|
|
const int shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2));
|
|
|
|
|
|
#pragma unroll
|
|
#pragma unroll
|
|
@@ -557,11 +582,11 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mul_mat(
|
|
v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303;
|
|
v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303;
|
|
}
|
|
}
|
|
|
|
|
|
- const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE/4) + i/4 + kbx*4]) + ky/4;
|
|
|
|
|
|
+ const uint8_t * scales = ((const uint8_t *) &x_sc[i*(WARP_SIZE/4) + i/4 + kbx*4]) + ky/4;
|
|
|
|
|
|
- const int index_y = j * WARP_SIZE + (QR2_K*k0) % WARP_SIZE;
|
|
|
|
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq(
|
|
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq(
|
|
- v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[index_y/QI8_1]);
|
|
|
|
|
|
+ v, &y_qs[j*MMQ_TILE_Y_K + (QR2_K*k0) % WARP_SIZE], scales,
|
|
|
|
+ x_dm[i*(WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[j*MMQ_TILE_Y_K + ((QR2_K*k0) % WARP_SIZE)/QI8_1]);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
@@ -646,7 +671,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
template <int mmq_x, int mmq_y, int nwarps>
|
|
template <int mmq_x, int mmq_y, int nwarps>
|
|
static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mul_mat(
|
|
static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mul_mat(
|
|
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
|
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
|
- const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
|
|
|
|
|
|
+ const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
|
|
|
+
|
|
|
|
+ const float * x_dmf = (const float *) x_dm;
|
|
|
|
+ const int * y_qs = (const int *) y + 4;
|
|
|
|
+ const float * y_df = (const float *) y;
|
|
|
|
|
|
#pragma unroll
|
|
#pragma unroll
|
|
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
|
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
|
@@ -658,8 +687,6 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mul_mat(
|
|
|
|
|
|
const int kbx = k0 / QI3_K;
|
|
const int kbx = k0 / QI3_K;
|
|
const int ky = (k0 % QI3_K) * QR3_K;
|
|
const int ky = (k0 % QI3_K) * QR3_K;
|
|
- const float * x_dmf = (const float *) x_dm;
|
|
|
|
- const float * y_df = (const float *) y_ds;
|
|
|
|
|
|
|
|
const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
|
|
const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
|
|
|
|
|
|
@@ -667,19 +694,19 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mul_mat(
|
|
|
|
|
|
#pragma unroll
|
|
#pragma unroll
|
|
for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) {
|
|
for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) {
|
|
- const int kqsx = i * (WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2);
|
|
|
|
|
|
+ const int kqsx = i*(WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2);
|
|
const int shift = 2 * ((ky % 32) / 8);
|
|
const int shift = 2 * ((ky % 32) / 8);
|
|
const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303;
|
|
const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303;
|
|
|
|
|
|
- const int vh = x_qh[i * (WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8);
|
|
|
|
|
|
+ const int vh = x_qh[i*(WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8);
|
|
const int vlh = (vh << 2) & 0x04040404;
|
|
const int vlh = (vh << 2) & 0x04040404;
|
|
|
|
|
|
v[l] = __vsubss4(vll, vlh);
|
|
v[l] = __vsubss4(vll, vlh);
|
|
}
|
|
}
|
|
|
|
|
|
- const int index_y = j * WARP_SIZE + (k0*QR3_K) % WARP_SIZE;
|
|
|
|
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq(
|
|
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq(
|
|
- v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]);
|
|
|
|
|
|
+ v, &y_qs[j*MMQ_TILE_Y_K + (k0*QR3_K) % WARP_SIZE], scales,
|
|
|
|
+ x_dmf[i*(WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[j*MMQ_TILE_Y_K + ((k0*QR3_K) % WARP_SIZE)/QI8_1]);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
@@ -746,10 +773,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
template <int mmq_x, int mmq_y, int nwarps>
|
|
template <int mmq_x, int mmq_y, int nwarps>
|
|
static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mul_mat(
|
|
static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mul_mat(
|
|
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
|
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
|
- const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
|
|
|
|
|
|
+ const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
|
|
|
|
|
GGML_UNUSED(x_qh);
|
|
GGML_UNUSED(x_qh);
|
|
|
|
|
|
|
|
+ const int * y_qs = (const int *) y + 4;
|
|
|
|
+ const half2 * y_ds = (const half2 *) y;
|
|
|
|
+
|
|
#pragma unroll
|
|
#pragma unroll
|
|
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
|
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
|
const int j = j0 + threadIdx.y;
|
|
const int j = j0 + threadIdx.y;
|
|
@@ -760,9 +790,9 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mul_mat(
|
|
|
|
|
|
const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2*((k0 % 16) / 8);
|
|
const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2*((k0 % 16) / 8);
|
|
|
|
|
|
- const int index_y = j * WARP_SIZE + (QR4_K*k0) % WARP_SIZE;
|
|
|
|
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq(
|
|
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq(
|
|
- &x_ql[i * (WARP_SIZE + 1) + k0], &y_qs[index_y], sc, sc+8, x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]);
|
|
|
|
|
|
+ &x_ql[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + (QR4_K*k0) % WARP_SIZE], sc, sc+8,
|
|
|
|
+ x_dm[i*(WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[j*MMQ_TILE_Y_K + ((QR4_K*k0) % WARP_SIZE)/QI8_1]);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
@@ -842,10 +872,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
template <int mmq_x, int mmq_y, int nwarps>
|
|
template <int mmq_x, int mmq_y, int nwarps>
|
|
static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mul_mat(
|
|
static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mul_mat(
|
|
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
|
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
|
- const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
|
|
|
|
|
|
+ const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
|
|
|
|
|
GGML_UNUSED(x_qh);
|
|
GGML_UNUSED(x_qh);
|
|
|
|
|
|
|
|
+ const int * y_qs = (const int *) y + 4;
|
|
|
|
+ const half2 * y_ds = (const half2 *) y;
|
|
|
|
+
|
|
#pragma unroll
|
|
#pragma unroll
|
|
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
|
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
|
const int j = j0 + threadIdx.y;
|
|
const int j = j0 + threadIdx.y;
|
|
@@ -856,10 +889,9 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mul_mat(
|
|
|
|
|
|
const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8);
|
|
const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8);
|
|
|
|
|
|
- const int index_x = i * (QR5_K*WARP_SIZE + 1) + QR5_K*k0;
|
|
|
|
- const int index_y = j * WARP_SIZE + (QR5_K*k0) % WARP_SIZE;
|
|
|
|
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq(
|
|
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq(
|
|
- &x_ql[index_x], &y_qs[index_y], sc, sc+8, x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]);
|
|
|
|
|
|
+ &x_ql[i*(QR5_K*WARP_SIZE + 1) + QR5_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR5_K*k0) % WARP_SIZE], sc, sc+8,
|
|
|
|
+ x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[j*MMQ_TILE_Y_K + ((QR5_K*k0) % WARP_SIZE)/QI8_1]);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
@@ -932,10 +964,14 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
template <int mmq_x, int mmq_y, int nwarps>
|
|
template <int mmq_x, int mmq_y, int nwarps>
|
|
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mul_mat(
|
|
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mul_mat(
|
|
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
|
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
|
- const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
|
|
|
|
|
|
+ const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
|
|
|
|
|
GGML_UNUSED(x_qh);
|
|
GGML_UNUSED(x_qh);
|
|
|
|
|
|
|
|
+ const float * x_dmf = (const float *) x_dm;
|
|
|
|
+ const int * y_qs = (const int *) y + 4;
|
|
|
|
+ const float * y_df = (const float *) y;
|
|
|
|
+
|
|
#pragma unroll
|
|
#pragma unroll
|
|
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
|
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
|
const int j = j0 + threadIdx.y;
|
|
const int j = j0 + threadIdx.y;
|
|
@@ -944,15 +980,11 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mul_mat(
|
|
for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
|
|
for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
|
|
const int i = i0 + threadIdx.x;
|
|
const int i = i0 + threadIdx.x;
|
|
|
|
|
|
- const float * x_dmf = (const float *) x_dm;
|
|
|
|
- const float * y_df = (const float *) y_ds;
|
|
|
|
-
|
|
|
|
const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/8]);
|
|
const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/8]);
|
|
|
|
|
|
- const int index_x = i * (QR6_K*WARP_SIZE + 1) + QR6_K*k0;
|
|
|
|
- const int index_y = j * WARP_SIZE + (QR6_K*k0) % WARP_SIZE;
|
|
|
|
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq(
|
|
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq(
|
|
- &x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]);
|
|
|
|
|
|
+ &x_ql[i*(QR6_K*WARP_SIZE + 1) + QR6_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR6_K*k0) % WARP_SIZE], sc,
|
|
|
|
+ x_dmf[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + ((QR6_K*k0) % WARP_SIZE)/QI8_1]);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
@@ -964,7 +996,6 @@ struct mmq_type_traits;
|
|
|
|
|
|
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
|
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
|
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
|
|
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
|
|
- static constexpr bool need_sum = true;
|
|
|
|
static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
|
|
static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
|
|
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, nwarps, need_check>;
|
|
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, nwarps, need_check>;
|
|
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
|
|
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
|
|
@@ -972,7 +1003,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
|
|
|
|
|
|
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
|
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
|
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
|
|
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
|
|
- static constexpr bool need_sum = true;
|
|
|
|
static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ;
|
|
static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ;
|
|
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, nwarps, need_check>;
|
|
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, nwarps, need_check>;
|
|
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_1_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
|
|
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_1_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
|
|
@@ -980,7 +1010,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
|
|
|
|
|
|
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
|
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
|
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
|
|
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
|
|
- static constexpr bool need_sum = false;
|
|
|
|
static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ;
|
|
static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ;
|
|
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>;
|
|
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>;
|
|
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
|
|
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
|
|
@@ -988,7 +1017,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
|
|
|
|
|
|
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
|
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
|
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
|
|
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
|
|
- static constexpr bool need_sum = true;
|
|
|
|
static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ;
|
|
static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ;
|
|
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, nwarps, need_check>;
|
|
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, nwarps, need_check>;
|
|
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_1_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
|
|
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_1_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
|
|
@@ -996,7 +1024,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
|
|
|
|
|
|
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
|
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
|
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
|
|
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
|
|
- static constexpr bool need_sum = false;
|
|
|
|
static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
|
|
static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
|
|
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, nwarps, need_check>;
|
|
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, nwarps, need_check>;
|
|
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q8_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
|
|
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q8_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
|
|
@@ -1004,7 +1031,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
|
|
|
|
|
|
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
|
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
|
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q2_K> {
|
|
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q2_K> {
|
|
- static constexpr bool need_sum = false;
|
|
|
|
static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
|
|
static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
|
|
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, nwarps, need_check>;
|
|
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, nwarps, need_check>;
|
|
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q2_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
|
|
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q2_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
|
|
@@ -1012,7 +1038,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q2_K> {
|
|
|
|
|
|
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
|
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
|
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
|
|
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
|
|
- static constexpr bool need_sum = false;
|
|
|
|
static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ;
|
|
static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ;
|
|
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, nwarps, need_check>;
|
|
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, nwarps, need_check>;
|
|
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q3_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
|
|
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q3_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
|
|
@@ -1020,7 +1045,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
|
|
|
|
|
|
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
|
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
|
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
|
|
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
|
|
- static constexpr bool need_sum = true;
|
|
|
|
static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ;
|
|
static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ;
|
|
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, nwarps, need_check>;
|
|
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, nwarps, need_check>;
|
|
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
|
|
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
|
|
@@ -1028,7 +1052,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
|
|
|
|
|
|
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
|
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
|
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
|
|
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
|
|
- static constexpr bool need_sum = true;
|
|
|
|
static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ;
|
|
static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ;
|
|
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, nwarps, need_check>;
|
|
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, nwarps, need_check>;
|
|
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
|
|
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
|
|
@@ -1036,12 +1059,36 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
|
|
|
|
|
|
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
|
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
|
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
|
|
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
|
|
- static constexpr bool need_sum = false;
|
|
|
|
static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ;
|
|
static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ;
|
|
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, nwarps, need_check>;
|
|
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, nwarps, need_check>;
|
|
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q6_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
|
|
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q6_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
|
|
};
|
|
};
|
|
|
|
|
|
|
|
+static int mmq_need_sum(const ggml_type type_x) {
|
|
|
|
+ switch (type_x) {
|
|
|
|
+ case GGML_TYPE_Q4_0:
|
|
|
|
+ case GGML_TYPE_Q4_1:
|
|
|
|
+ return true;
|
|
|
|
+ case GGML_TYPE_Q5_0:
|
|
|
|
+ return false;
|
|
|
|
+ case GGML_TYPE_Q5_1:
|
|
|
|
+ return true;
|
|
|
|
+ case GGML_TYPE_Q8_0:
|
|
|
|
+ case GGML_TYPE_Q2_K:
|
|
|
|
+ case GGML_TYPE_Q3_K:
|
|
|
|
+ return false;
|
|
|
|
+ case GGML_TYPE_Q4_K:
|
|
|
|
+ case GGML_TYPE_Q5_K:
|
|
|
|
+ return true;
|
|
|
|
+ case GGML_TYPE_Q6_K:
|
|
|
|
+ return false;
|
|
|
|
+ default:
|
|
|
|
+ GGML_ASSERT(false);
|
|
|
|
+ break;
|
|
|
|
+ }
|
|
|
|
+ return false;
|
|
|
|
+}
|
|
|
|
+
|
|
template <ggml_type type, int mmq_x, int nwarps, bool need_check>
|
|
template <ggml_type type, int mmq_x, int nwarps, bool need_check>
|
|
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
|
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
|
#if defined(RDNA3) || defined(RDNA2)
|
|
#if defined(RDNA3) || defined(RDNA2)
|
|
@@ -1056,7 +1103,7 @@ template <ggml_type type, int mmq_x, int nwarps, bool need_check>
|
|
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
|
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
|
static __global__ void mul_mat_q(
|
|
static __global__ void mul_mat_q(
|
|
const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst,
|
|
const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst,
|
|
- const int ne00, const int ne01, const int stride00, const int ne10, const int ne11, const int ne0) {
|
|
|
|
|
|
+ const int ne00, const int ne01, const int stride01, const int ne10, const int ne11, const int stride11, const int ne0) {
|
|
|
|
|
|
// Skip unused template specializations for faster compilation:
|
|
// Skip unused template specializations for faster compilation:
|
|
if (mmq_x > get_mmq_x_max_device()) {
|
|
if (mmq_x > get_mmq_x_max_device()) {
|
|
@@ -1068,7 +1115,6 @@ static __global__ void mul_mat_q(
|
|
constexpr int qr = ggml_cuda_type_traits<type>::qr;
|
|
constexpr int qr = ggml_cuda_type_traits<type>::qr;
|
|
constexpr int qi = ggml_cuda_type_traits<type>::qi;
|
|
constexpr int qi = ggml_cuda_type_traits<type>::qi;
|
|
constexpr int mmq_y = get_mmq_y_device(mmq_x);
|
|
constexpr int mmq_y = get_mmq_y_device(mmq_x);
|
|
- constexpr bool need_sum = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::need_sum;
|
|
|
|
constexpr int vdr = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr;
|
|
constexpr int vdr = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr;
|
|
constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
|
|
constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
|
|
constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot;
|
|
constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot;
|
|
@@ -1080,62 +1126,38 @@ static __global__ void mul_mat_q(
|
|
half2 * tile_x_dm = (half2 *) (tile_x_ql + txs.ql);
|
|
half2 * tile_x_dm = (half2 *) (tile_x_ql + txs.ql);
|
|
int * tile_x_qh = (int *) (tile_x_dm + txs.dm);
|
|
int * tile_x_qh = (int *) (tile_x_dm + txs.dm);
|
|
int * tile_x_sc = (int *) (tile_x_qh + txs.qh);
|
|
int * tile_x_sc = (int *) (tile_x_qh + txs.qh);
|
|
- int * tile_y_qs = (int *) (tile_x_sc + txs.sc); // [mmq_x * WARP_SIZE]
|
|
|
|
- half2 * tile_y_ds = (half2 *) (tile_y_qs + mmq_x*WARP_SIZE); // [mmq_x * WARP_SIZE/QI8_1];
|
|
|
|
-
|
|
|
|
- const block_q8_1 * y = (const block_q8_1 *) yc;
|
|
|
|
|
|
+ int * tile_y = (int *) (tile_x_sc + txs.sc); // [mmq_x * (WARP_SIZE + WARP_SIZE/QI8_1)]
|
|
|
|
|
|
const int blocks_per_row_x = ne00 / qk;
|
|
const int blocks_per_row_x = ne00 / qk;
|
|
- const int blocks_per_col_y = ne10 / QK8_1;
|
|
|
|
const int blocks_per_warp = WARP_SIZE / qi;
|
|
const int blocks_per_warp = WARP_SIZE / qi;
|
|
|
|
|
|
const int & ne1 = ne11;
|
|
const int & ne1 = ne11;
|
|
|
|
|
|
const int tile_x_max_i = ne01 - blockIdx.x*mmq_y - 1;
|
|
const int tile_x_max_i = ne01 - blockIdx.x*mmq_y - 1;
|
|
|
|
|
|
|
|
+ const int * y = (const int *) yc + blockIdx.y*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));
|
|
|
|
+
|
|
float sum[(mmq_x/nwarps) * (mmq_y/WARP_SIZE)] = {0.0f};
|
|
float sum[(mmq_x/nwarps) * (mmq_y/WARP_SIZE)] = {0.0f};
|
|
|
|
|
|
for (int kb0 = 0; kb0 < blocks_per_row_x; kb0 += blocks_per_warp) {
|
|
for (int kb0 = 0; kb0 < blocks_per_row_x; kb0 += blocks_per_warp) {
|
|
|
|
|
|
- load_tiles(x, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, stride00*blockIdx.x*mmq_y + kb0, tile_x_max_i, stride00);
|
|
|
|
|
|
+ load_tiles(x, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, stride01*blockIdx.x*mmq_y + kb0, tile_x_max_i, stride01);
|
|
|
|
|
|
#pragma unroll
|
|
#pragma unroll
|
|
for (int kr = 0; kr < qr; ++kr) {
|
|
for (int kr = 0; kr < qr; ++kr) {
|
|
- const int kqs = kr*WARP_SIZE + threadIdx.x;
|
|
|
|
- const int kbxd = kqs / QI8_1;
|
|
|
|
-
|
|
|
|
|
|
+ const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + kr*sizeof(block_q8_1_mmq)/sizeof(int));
|
|
#pragma unroll
|
|
#pragma unroll
|
|
- for (int i0 = 0; i0 < mmq_x; i0 += nwarps) {
|
|
|
|
- const int i = min(blockIdx.y*mmq_x + threadIdx.y + i0, ne11-1); // to prevent out-of-bounds memory accesses
|
|
|
|
-
|
|
|
|
- const block_q8_1 * by0 = &y[i*blocks_per_col_y + kb0 * (qk/QK8_1) + kbxd];
|
|
|
|
|
|
+ for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
|
|
|
|
+ int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
|
|
|
|
|
|
- const int index_y = (i0 + threadIdx.y) * WARP_SIZE + kqs % WARP_SIZE;
|
|
|
|
- tile_y_qs[index_y] = get_int_from_int8_aligned(by0->qs, threadIdx.x % QI8_1);
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
-#pragma unroll
|
|
|
|
- for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) {
|
|
|
|
- const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE/QI8_1)) % mmq_x;
|
|
|
|
- const int kby = threadIdx.x % (WARP_SIZE/QI8_1);
|
|
|
|
- const int i_y_eff = min(blockIdx.y*mmq_x + ids, ne11-1);
|
|
|
|
-
|
|
|
|
- // if the sum is not needed it's faster to transform the scale to f32 ahead of time
|
|
|
|
- const half2 * dsi_src = &y[i_y_eff*blocks_per_col_y + kb0 * (qk/QK8_1) + kr*(WARP_SIZE/QI8_1) + kby].ds;
|
|
|
|
- half2 * dsi_dst = &tile_y_ds[ids * (WARP_SIZE/QI8_1) + kby];
|
|
|
|
- if (need_sum) {
|
|
|
|
- *dsi_dst = *dsi_src;
|
|
|
|
- } else {
|
|
|
|
- float * dfi_dst = (float *) dsi_dst;
|
|
|
|
- *dfi_dst = __low2float(*dsi_src);
|
|
|
|
- }
|
|
|
|
|
|
+ tile_y[l] = by0[l];
|
|
}
|
|
}
|
|
|
|
|
|
__syncthreads();
|
|
__syncthreads();
|
|
|
|
|
|
// #pragma unroll // unrolling this loop causes too much register pressure
|
|
// #pragma unroll // unrolling this loop causes too much register pressure
|
|
for (int k0 = kr*WARP_SIZE/qr; k0 < (kr+1)*WARP_SIZE/qr; k0 += vdr) {
|
|
for (int k0 = kr*WARP_SIZE/qr; k0 < (kr+1)*WARP_SIZE/qr; k0 += vdr) {
|
|
- vec_dot(tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds, sum, k0);
|
|
|
|
|
|
+ vec_dot(tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y, sum, k0);
|
|
}
|
|
}
|
|
|
|
|
|
__syncthreads();
|
|
__syncthreads();
|
|
@@ -1165,8 +1187,8 @@ static __global__ void mul_mat_q(
|
|
|
|
|
|
struct mmq_args {
|
|
struct mmq_args {
|
|
const char * x; const char * y; float * dst;
|
|
const char * x; const char * y; float * dst;
|
|
- int64_t ne00; int64_t ne01; int64_t stride00;
|
|
|
|
- int64_t ne10; int64_t ne11;
|
|
|
|
|
|
+ int64_t ne00; int64_t ne01; int64_t stride01;
|
|
|
|
+ int64_t ne10; int64_t ne11; int64_t stride11;
|
|
int64_t ne0;
|
|
int64_t ne0;
|
|
};
|
|
};
|
|
|
|
|
|
@@ -1184,7 +1206,7 @@ static void launch_mul_mat_q(const mmq_args & args, cudaStream_t stream) {
|
|
const tile_x_sizes txs = get_tile_x_sizes_host(type, mmq_y);
|
|
const tile_x_sizes txs = get_tile_x_sizes_host(type, mmq_y);
|
|
const int shmem_x = txs.ql*sizeof(int) + txs.dm*sizeof(half2) + txs.qh*sizeof(int) + txs.sc*sizeof(int);
|
|
const int shmem_x = txs.ql*sizeof(int) + txs.dm*sizeof(half2) + txs.qh*sizeof(int) + txs.sc*sizeof(int);
|
|
const int shmem_y = mmq_x*WARP_SIZE*sizeof(int) + mmq_x*(WARP_SIZE/QI8_1)*sizeof(half2);
|
|
const int shmem_y = mmq_x*WARP_SIZE*sizeof(int) + mmq_x*(WARP_SIZE/QI8_1)*sizeof(half2);
|
|
- const int shmem = shmem_x + shmem_y;
|
|
|
|
|
|
+ const int shmem = shmem_x + GGML_PAD(shmem_y, nwarps*WARP_SIZE*sizeof(int));
|
|
|
|
|
|
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
|
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
|
static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
|
static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
|
@@ -1198,11 +1220,11 @@ static void launch_mul_mat_q(const mmq_args & args, cudaStream_t stream) {
|
|
if (args.ne01 % mmq_y == 0) {
|
|
if (args.ne01 % mmq_y == 0) {
|
|
const bool need_check = false;
|
|
const bool need_check = false;
|
|
mul_mat_q<type, mmq_x, nwarps, need_check><<<block_nums, block_dims, shmem, stream>>>
|
|
mul_mat_q<type, mmq_x, nwarps, need_check><<<block_nums, block_dims, shmem, stream>>>
|
|
- (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride00, args.ne10, args.ne11, args.ne0);
|
|
|
|
|
|
+ (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
|
|
} else {
|
|
} else {
|
|
const bool need_check = true;
|
|
const bool need_check = true;
|
|
mul_mat_q<type, mmq_x, nwarps, need_check><<<block_nums, block_dims, shmem, stream>>>
|
|
mul_mat_q<type, mmq_x, nwarps, need_check><<<block_nums, block_dims, shmem, stream>>>
|
|
- (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride00, args.ne10, args.ne11, args.ne0);
|
|
|
|
|
|
+ (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|