|
@@ -1,5 +1,5 @@
|
|
|
/**
|
|
|
- * llama.cpp - commit ba1cb19cdd0d92e012e0f6e009e0620f854b6afd - do not edit this file
|
|
|
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
|
|
|
*
|
|
|
* MIT License
|
|
|
*
|
|
@@ -27,9 +27,9 @@
|
|
|
#include "common.cuh"
|
|
|
#include "mmv.cuh"
|
|
|
|
|
|
-template <typename type_acc, int block_size>
|
|
|
+template <typename T, typename type_acc, int block_size>
|
|
|
static __global__ void mul_mat_vec(
|
|
|
- const half * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
|
|
|
+ const T * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
|
|
|
const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) {
|
|
|
const int64_t row = blockIdx.x;
|
|
|
const int64_t channel = blockIdx.z;
|
|
@@ -39,7 +39,6 @@ static __global__ void mul_mat_vec(
|
|
|
y += channel *stride_channel_y;
|
|
|
dst += channel *stride_channel_dst;
|
|
|
|
|
|
- const half2 * x2 = (const half2 *) x;
|
|
|
const float2 * y2 = (const float2 *) y;
|
|
|
|
|
|
extern __shared__ char data_mmv[];
|
|
@@ -54,28 +53,44 @@ static __global__ void mul_mat_vec(
|
|
|
|
|
|
float sumf;
|
|
|
|
|
|
- if (std::is_same<type_acc, float>::value) {
|
|
|
- sumf = 0.0f;
|
|
|
+ if constexpr (std::is_same<T, half>::value) {
|
|
|
+ const half2 * x2 = (const half2 *) x;
|
|
|
|
|
|
- for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
|
|
|
- const float2 tmpx = __half22float2(x2[col2]);
|
|
|
- const float2 tmpy = y2[col2];
|
|
|
- sumf += tmpx.x * tmpy.x;
|
|
|
- sumf += tmpx.y * tmpy.y;
|
|
|
- }
|
|
|
- } else {
|
|
|
+ if (std::is_same<type_acc, float>::value) {
|
|
|
+ sumf = 0.0f;
|
|
|
+
|
|
|
+ for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
|
|
|
+ const float2 tmpx = __half22float2(x2[col2]);
|
|
|
+ const float2 tmpy = y2[col2];
|
|
|
+ sumf += tmpx.x * tmpy.x;
|
|
|
+ sumf += tmpx.y * tmpy.y;
|
|
|
+ }
|
|
|
+ } else {
|
|
|
#ifdef FP16_AVAILABLE
|
|
|
- half2 sumh2 = make_half2(0.0f, 0.0f);
|
|
|
+ half2 sumh2 = make_half2(0.0f, 0.0f);
|
|
|
|
|
|
- for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
|
|
|
- const float2 tmp = y2[col2];
|
|
|
- sumh2 += x2[col2] * make_half2(tmp.x, tmp.y);
|
|
|
- }
|
|
|
+ for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
|
|
|
+ const float2 tmp = y2[col2];
|
|
|
+ sumh2 += x2[col2] * make_half2(tmp.x, tmp.y);
|
|
|
+ }
|
|
|
|
|
|
- sumf = __low2float(sumh2) + __high2float(sumh2);
|
|
|
+ sumf = __low2float(sumh2) + __high2float(sumh2);
|
|
|
#else
|
|
|
- NO_DEVICE_CODE;
|
|
|
+ NO_DEVICE_CODE;
|
|
|
#endif // FP16_AVAILABLE
|
|
|
+ }
|
|
|
+ } else if constexpr (std::is_same<T, nv_bfloat16>::value) {
|
|
|
+ const int * x2 = (const int *) x;
|
|
|
+ sumf = 0.0f;
|
|
|
+
|
|
|
+ for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
|
|
|
+ const int tmpx = x2[col2];
|
|
|
+ const float2 tmpy = y2[col2];
|
|
|
+ sumf += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
|
|
|
+ sumf += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ static_assert(std::is_same<T, void>::value, "unsupported type");
|
|
|
}
|
|
|
|
|
|
sumf = warp_reduce_sum(sumf);
|
|
@@ -97,9 +112,9 @@ static __global__ void mul_mat_vec(
|
|
|
dst[row] = sumf;
|
|
|
}
|
|
|
|
|
|
-template <typename type_acc>
|
|
|
+template <typename T, typename type_acc>
|
|
|
static void launch_mul_mat_vec_cuda(
|
|
|
- const half * x, const float * y, float * dst,
|
|
|
+ const T * x, const float * y, float * dst,
|
|
|
const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
|
|
|
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
|
|
|
cudaStream_t stream) {
|
|
@@ -123,35 +138,35 @@ static void launch_mul_mat_vec_cuda(
|
|
|
const dim3 block_dims(block_size_best, 1, 1);
|
|
|
switch (block_size_best) {
|
|
|
case 32: {
|
|
|
- mul_mat_vec<type_acc, 32><<<block_nums, block_dims, smem, stream>>>
|
|
|
+ mul_mat_vec<T, type_acc, 32><<<block_nums, block_dims, smem, stream>>>
|
|
|
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
|
|
|
} break;
|
|
|
case 64: {
|
|
|
- mul_mat_vec<type_acc, 64><<<block_nums, block_dims, smem, stream>>>
|
|
|
+ mul_mat_vec<T, type_acc, 64><<<block_nums, block_dims, smem, stream>>>
|
|
|
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
|
|
|
} break;
|
|
|
case 96: {
|
|
|
- mul_mat_vec<type_acc, 96><<<block_nums, block_dims, smem, stream>>>
|
|
|
+ mul_mat_vec<T, type_acc, 96><<<block_nums, block_dims, smem, stream>>>
|
|
|
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
|
|
|
} break;
|
|
|
case 128: {
|
|
|
- mul_mat_vec<type_acc, 128><<<block_nums, block_dims, smem, stream>>>
|
|
|
+ mul_mat_vec<T, type_acc, 128><<<block_nums, block_dims, smem, stream>>>
|
|
|
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
|
|
|
} break;
|
|
|
case 160: {
|
|
|
- mul_mat_vec<type_acc, 160><<<block_nums, block_dims, smem, stream>>>
|
|
|
+ mul_mat_vec<T, type_acc, 160><<<block_nums, block_dims, smem, stream>>>
|
|
|
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
|
|
|
} break;
|
|
|
case 192: {
|
|
|
- mul_mat_vec<type_acc, 192><<<block_nums, block_dims, smem, stream>>>
|
|
|
+ mul_mat_vec<T, type_acc, 192><<<block_nums, block_dims, smem, stream>>>
|
|
|
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
|
|
|
} break;
|
|
|
case 224: {
|
|
|
- mul_mat_vec<type_acc, 224><<<block_nums, block_dims, smem, stream>>>
|
|
|
+ mul_mat_vec<T, type_acc, 224><<<block_nums, block_dims, smem, stream>>>
|
|
|
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
|
|
|
} break;
|
|
|
case 256: {
|
|
|
- mul_mat_vec<type_acc, 256><<<block_nums, block_dims, smem, stream>>>
|
|
|
+ mul_mat_vec<T, type_acc, 256><<<block_nums, block_dims, smem, stream>>>
|
|
|
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
|
|
|
} break;
|
|
|
default: {
|
|
@@ -160,25 +175,25 @@ static void launch_mul_mat_vec_cuda(
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+template<typename T>
|
|
|
static void mul_mat_vec_cuda(
|
|
|
- const half * x, const float * y, float * dst,
|
|
|
+ const T * x, const float * y, float * dst,
|
|
|
const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
|
|
|
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
|
|
|
enum ggml_prec prec, cudaStream_t stream) {
|
|
|
switch (prec) {
|
|
|
case GGML_PREC_DEFAULT: {
|
|
|
- launch_mul_mat_vec_cuda<half>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
|
|
|
+ launch_mul_mat_vec_cuda<T, half>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
|
|
|
stride_channel_x, stride_channel_y, stride_channel_dst, stream);
|
|
|
} break;
|
|
|
case GGML_PREC_F32: {
|
|
|
- launch_mul_mat_vec_cuda<float>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
|
|
|
+ launch_mul_mat_vec_cuda<T, float>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
|
|
|
stride_channel_x, stride_channel_y, stride_channel_dst, stream);
|
|
|
} break;
|
|
|
}
|
|
|
}
|
|
|
|
|
|
void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
|
|
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
|
|
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
|
|
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
|
|
|
|
@@ -190,7 +205,6 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
|
|
|
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
|
|
const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
|
|
|
|
|
|
- const half * src0_d = (const half *) src0->data;
|
|
|
const float * src1_d = (const float *) src1->data;
|
|
|
float * dst_d = (float *) dst->data;
|
|
|
|
|
@@ -207,7 +221,20 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
|
|
|
const int64_t channel_stride_y = src1->nb[2] / ggml_type_size(src1->type);
|
|
|
const int64_t channel_stride_dst = dst->nb[2] / ggml_type_size( dst->type);
|
|
|
|
|
|
- mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12, channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
|
|
|
+ switch (src0->type) {
|
|
|
+ case GGML_TYPE_F16: {
|
|
|
+ const half * src0_d = (const half *) src0->data;
|
|
|
+ mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12,
|
|
|
+ channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
|
|
|
+ } break;
|
|
|
+ case GGML_TYPE_BF16: {
|
|
|
+ const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
|
|
|
+ mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12,
|
|
|
+ channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
|
|
|
+ } break;
|
|
|
+ default:
|
|
|
+ GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
void ggml_cuda_op_mul_mat_vec(
|
|
@@ -216,7 +243,6 @@ void ggml_cuda_op_mul_mat_vec(
|
|
|
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
|
|
|
const int64_t src1_padded_row_size, cudaStream_t stream) {
|
|
|
|
|
|
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
|
|
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
|
|
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
|
|
|
|
@@ -237,8 +263,20 @@ void ggml_cuda_op_mul_mat_vec(
|
|
|
const int64_t channel_stride_y = 0;
|
|
|
const int64_t channel_stride_dst = 0;
|
|
|
|
|
|
- mul_mat_vec_cuda((const half *) src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
|
|
|
- nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
|
|
|
+ switch (src0->type) {
|
|
|
+ case GGML_TYPE_F16: {
|
|
|
+ const half * src0_d = (const half *) src0_dd_i;
|
|
|
+ mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
|
|
|
+ nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
|
|
|
+ } break;
|
|
|
+ case GGML_TYPE_BF16: {
|
|
|
+ const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
|
|
|
+ mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
|
|
|
+ nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
|
|
|
+ } break;
|
|
|
+ default:
|
|
|
+ GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
|
|
|
+ }
|
|
|
|
|
|
GGML_UNUSED(ctx);
|
|
|
GGML_UNUSED(src1);
|