123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287 |
- /**
- * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
- *
- * MIT License
- *
- * Copyright (c) 2023-2024 The ggml authors
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to deal
- * in the Software without restriction, including without limitation the rights
- * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- * copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
- #include "common.cuh"
- #include "mmv.cuh"
// Matrix-vector product kernel: each thread block produces a single output value dst[row]
// for one channel.
//
// Launch layout, as used by the body below:
//   - blockIdx.x selects the row of x; blockIdx.z selects the channel of y and dst.
//   - threadIdx.x walks the columns with a step of block_size, so block_size is expected to
//     match blockDim.x.
//   - If block_size > WARP_SIZE the first WARP_SIZE floats of dynamic shared memory are used as
//     scratch space for the per-warp partial sums.
//
// Template parameters:
//   T          - element type of x; half and nv_bfloat16 are handled, anything else trips the
//                static_assert.
//   type_acc   - only consulted when T is half: float selects float accumulation, any other type
//                selects half2 accumulation (needs FP16_AVAILABLE, otherwise NO_DEVICE_CODE).
//   block_size - number of threads cooperating on one row.
//
// Arguments:
//   x, y, dst          - base pointers of the matrix, the input vector and the output vector.
//   ncols2             - number of column *pairs* per row; x and y are read two values at a time.
//   stride_row         - distance between consecutive rows of x, in elements of T.
//   channel_ratio      - the x channel used for y/dst channel c is c / channel_ratio.
//   stride_channel_*   - distance between consecutive channels of x, y and dst, in elements.
template <typename T, typename type_acc, int block_size>
static __global__ void mul_mat_vec(
        const T * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
        const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) {
    const int     i_thread  = threadIdx.x;
    const int64_t i_row     = blockIdx.x;
    const int64_t i_channel = blockIdx.z;

    // Advance the base pointers to the row / channel handled by this block.
    x   += (i_channel/channel_ratio)*stride_channel_x + i_row*stride_row;
    y   += i_channel               *stride_channel_y;
    dst += i_channel               *stride_channel_dst;

    // y is consumed as pairs of floats.
    const float2 * y_pairs = (const float2 *) y;

    extern __shared__ char data_mmv[];
    float * warp_partials = (float *) data_mmv;

    if (block_size > WARP_SIZE) {
        // Zero all WARP_SIZE slots: the second reduction stage reads every one of them, but only
        // block_size/WARP_SIZE of them are written with a partial sum.
        if (i_thread < WARP_SIZE) {
            warp_partials[i_thread] = 0.0f;
        }
        __syncthreads();
    }

    float acc;

    if constexpr (std::is_same<T, half>::value) {
        const half2 * x_pairs = (const half2 *) x;

        if (std::is_same<type_acc, float>::value) {
            // half data, float accumulator: widen each pair before multiplying.
            acc = 0.0f;

            for (int64_t ip = i_thread; ip < ncols2; ip += block_size) {
                const float2 xv = __half22float2(x_pairs[ip]);
                const float2 yv = y_pairs[ip];
                acc += xv.x * yv.x;
                acc += xv.y * yv.y;
            }
        } else {
#ifdef FP16_AVAILABLE
            // half data, half2 accumulator: multiply-accumulate both lanes in half precision
            // and only combine them into a float at the end.
            half2 acc_h2 = make_half2(0.0f, 0.0f);

            for (int64_t ip = i_thread; ip < ncols2; ip += block_size) {
                const float2 yv = y_pairs[ip];
                acc_h2 += x_pairs[ip] * make_half2(yv.x, yv.y);
            }

            acc = __low2float(acc_h2) + __high2float(acc_h2);
#else
            NO_DEVICE_CODE;
#endif // FP16_AVAILABLE
        }
    } else if constexpr (std::is_same<T, nv_bfloat16>::value) {
        // bf16 data: fetch two values with one 32 bit load, then convert each to float.
        const int * x_packed = (const int *) x;
        acc = 0.0f;

        for (int64_t ip = i_thread; ip < ncols2; ip += block_size) {
            const int    xv = x_packed[ip];
            const float2 yv = y_pairs[ip];
            acc += float(reinterpret_cast<const nv_bfloat16 *>(&xv)[0]) * yv.x;
            acc += float(reinterpret_cast<const nv_bfloat16 *>(&xv)[1]) * yv.y;
        }
    } else {
        static_assert(std::is_same<T, void>::value, "unsupported type");
    }

    // Stage 1: combine the per-thread sums within each warp.
    acc = warp_reduce_sum(acc);

    if (block_size > WARP_SIZE) {
        // Stage 2: publish one partial per warp, then let the first warp combine them.
        warp_partials[i_thread/WARP_SIZE] = acc;
        __syncthreads();
        if (i_thread >= WARP_SIZE) {
            return;
        }
        acc = warp_partials[i_thread];
        acc = warp_reduce_sum(acc);
    }

    // A single thread writes the result for this row.
    if (i_thread == 0) {
        dst[i_row] = acc;
    }
}
// Chooses a block size and launches mul_mat_vec<T, type_acc, block_size> on `stream`.
//
// Arguments:
//   x, y, dst                  - device pointers: matrix, input vector(s), output vector(s).
//   ncols, nrows               - dimensions of one matrix in x; ncols must be even.
//   stride_row                 - distance between rows of x in elements; must be even.
//   nchannels_x, nchannels_y   - channel counts of x and y; nchannels_y must be a multiple of
//                                nchannels_x (several y channels may share one x channel).
//   stride_channel_{x,y,dst}   - distance between channels in elements.
//   stream                     - CUDA stream the kernel is enqueued on.
//
// The grid is (nrows, 1, nchannels_y): one block per output value.
template <typename T, typename type_acc>
static void launch_mul_mat_vec_cuda(
        const T * x, const float * y, float * dst,
        const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
        cudaStream_t stream) {
    // The kernel reads x and y two elements at a time, so every offset applied to those pointers
    // has to be a multiple of two elements or the vector loads become misaligned.
    GGML_ASSERT(ncols      % 2 == 0);
    GGML_ASSERT(stride_row % 2 == 0);
    GGML_ASSERT(nchannels_y % nchannels_x == 0);
    // With a single channel the channel index is always 0 and the stride is never applied,
    // so the stride only needs to be checked when there is more than one channel.
    GGML_ASSERT(nchannels_x == 1 || stride_channel_x % 2 == 0);
    GGML_ASSERT(nchannels_y == 1 || stride_channel_y % 2 == 0);
    const int64_t channel_ratio = nchannels_y / nchannels_x;

    // Pick the block size (a multiple of WARP_SIZE, at most 256) that minimizes the number of
    // loop iterations per thread. The strict comparison keeps the smallest such block size.
    int64_t block_size_best = WARP_SIZE;
    int64_t niter_best      = (ncols + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
    for (int64_t block_size = 2*WARP_SIZE; block_size <= 256; block_size += WARP_SIZE) {
        const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size);
        if (niter < niter_best) {
            niter_best      = niter;
            block_size_best = block_size;
        }
    }

    // Scratch space for the per-warp partial sums used by the kernel's block-level reduction.
    const int smem = WARP_SIZE*sizeof(float);
    const dim3 block_nums(nrows, 1, nchannels_y);
    const dim3 block_dims(block_size_best, 1, 1);

    // The block size is a template parameter of the kernel, hence the explicit dispatch.
    switch (block_size_best) {
        case   32: {
            mul_mat_vec<T, type_acc,  32><<<block_nums, block_dims, smem, stream>>>
                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
        } break;
        case   64: {
            mul_mat_vec<T, type_acc,  64><<<block_nums, block_dims, smem, stream>>>
                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
        } break;
        case   96: {
            mul_mat_vec<T, type_acc,  96><<<block_nums, block_dims, smem, stream>>>
                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
        } break;
        case  128: {
            mul_mat_vec<T, type_acc, 128><<<block_nums, block_dims, smem, stream>>>
                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
        } break;
        case  160: {
            mul_mat_vec<T, type_acc, 160><<<block_nums, block_dims, smem, stream>>>
                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
        } break;
        case  192: {
            mul_mat_vec<T, type_acc, 192><<<block_nums, block_dims, smem, stream>>>
                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
        } break;
        case  224: {
            mul_mat_vec<T, type_acc, 224><<<block_nums, block_dims, smem, stream>>>
                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
        } break;
        case  256: {
            mul_mat_vec<T, type_acc, 256><<<block_nums, block_dims, smem, stream>>>
                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
        } break;
        default: {
            GGML_ABORT("fatal error");
        } break;
    }
}
// Maps the requested precision to the accumulator type and forwards to the launcher:
//   GGML_PREC_DEFAULT -> launch_mul_mat_vec_cuda<T, half>
//   GGML_PREC_F32     -> launch_mul_mat_vec_cuda<T, float>
// All other arguments are passed through unchanged.
// Any other precision value aborts instead of silently leaving dst unwritten.
template<typename T>
static void mul_mat_vec_cuda(
        const T * x, const float * y, float * dst,
        const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
        enum ggml_prec prec, cudaStream_t stream) {
    switch (prec) {
        case GGML_PREC_DEFAULT: {
            launch_mul_mat_vec_cuda<T, half>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
                stride_channel_x, stride_channel_y, stride_channel_dst, stream);
        } break;
        case GGML_PREC_F32: {
            launch_mul_mat_vec_cuda<T, float>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
                stride_channel_x, stride_channel_y, stride_channel_dst, stream);
        } break;
        default: {
            // prec can originate from a raw integer cast; without this branch an unexpected
            // value would skip the launch entirely.
            GGML_ABORT("fatal error");
        } break;
    }
}
// Matrix-vector multiplication dst = src0 * src1 working directly on the ggml tensors.
//
// Requirements enforced below: src1 and dst are F32, src1 has a single column (ne[1] == 1),
// dst has as many channels (ne[2]) as src1, and all three tensors have ne[3] == 1.
// src0 must be F16 or BF16; any other type aborts.
//
// The accumulation precision is read from dst->op_params[0] when the device has fast FP16
// support and is forced to GGML_PREC_F32 otherwise.
void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    // Shape / type preconditions.
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
    GGML_ASSERT(src1->ne[1] == 1);
    GGML_ASSERT(dst->ne[2] == src1->ne[2]);
    GGML_ASSERT(src0->ne[3] == 1);
    GGML_ASSERT(src1->ne[3] == 1);
    GGML_ASSERT( dst->ne[3] == 1);

    const int device_cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
    const enum ggml_prec prec = fast_fp16_available(device_cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;

    const int64_t ncols       = src0->ne[0];
    const int64_t nrows       = src0->ne[1];
    const int64_t nchannels_x = src0->ne[2];
    const int64_t nchannels_y = src1->ne[2];

    const float * y_d = (const float *) src1->data;
    float       * d_d = (float       *) dst->data;

    // Convert the byte strides stored in the tensors into element strides.
    const int64_t stride_row         = src0->nb[1] / ggml_type_size(src0->type);
    const int64_t stride_channel_x   = src0->nb[2] / ggml_type_size(src0->type);
    const int64_t stride_channel_y   = src1->nb[2] / ggml_type_size(src1->type);
    const int64_t stride_channel_dst =  dst->nb[2] / ggml_type_size( dst->type);

    // Dispatch on the element type of the matrix.
    if (src0->type == GGML_TYPE_F16) {
        const half * x_d = (const half *) src0->data;
        mul_mat_vec_cuda(x_d, y_d, d_d, ncols, nrows, stride_row, nchannels_x, nchannels_y,
            stride_channel_x, stride_channel_y, stride_channel_dst, prec, ctx.stream());
    } else if (src0->type == GGML_TYPE_BF16) {
        const nv_bfloat16 * x_d = (const nv_bfloat16 *) src0->data;
        mul_mat_vec_cuda(x_d, y_d, d_d, ncols, nrows, stride_row, nchannels_x, nchannels_y,
            stride_channel_x, stride_channel_y, stride_channel_dst, prec, ctx.stream());
    } else {
        GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
    }
}
// Matrix-vector multiplication on the raw buffers handed in by the caller, for the row slice
// [row_low, row_high) of src0.
//
// src1 and dst must be F32 and src1_ncols must be 1. src0 must be F16 or BF16; any other type
// aborts. The data is treated as a single contiguous matrix: the row stride equals
// src0->ne[0], there is one channel, and all channel strides are 0.
//
// The accumulation precision is read from dst->op_params[0] when the device has fast FP16
// support and is forced to GGML_PREC_F32 otherwise.
void ggml_cuda_op_mul_mat_vec(
    ggml_backend_cuda_context & ctx,
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
    const int64_t src1_padded_row_size, cudaStream_t stream) {

    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
    GGML_ASSERT(src1_ncols == 1);

    const int device_cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
    const enum ggml_prec prec = fast_fp16_available(device_cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;

    const int64_t ncols = src0->ne[0];
    const int64_t nrows = row_high - row_low;

    // ggml_cuda_op provides single, contiguous matrices
    const int64_t row_stride     = ncols;
    const int64_t single_channel = 1;
    const int64_t no_stride      = 0;

    // Dispatch on the element type of the matrix.
    if (src0->type == GGML_TYPE_F16) {
        const half * x_d = (const half *) src0_dd_i;
        mul_mat_vec_cuda(x_d, src1_ddf_i, dst_dd_i, ncols, nrows, row_stride,
            single_channel, single_channel, no_stride, no_stride, no_stride, prec, stream);
    } else if (src0->type == GGML_TYPE_BF16) {
        const nv_bfloat16 * x_d = (const nv_bfloat16 *) src0_dd_i;
        mul_mat_vec_cuda(x_d, src1_ddf_i, dst_dd_i, ncols, nrows, row_stride,
            single_channel, single_channel, no_stride, no_stride, no_stride, prec, stream);
    } else {
        GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
    }

    GGML_UNUSED(ctx);
    GGML_UNUSED(src1);
    GGML_UNUSED(dst);
    GGML_UNUSED(src1_ddq_i);
    GGML_UNUSED(src1_ncols);
    GGML_UNUSED(src1_padded_row_size);
}
|