/**
 * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
 *
 * MIT License
 *
 * Copyright (c) 2023-2024 The ggml authors
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include "common.cuh"
#include "mmv.cuh"

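// Kernel: one thread block per (row, channel) pair computes a single dot product
// dst[row] = x[row, :] . y[:]. blockIdx.x selects the row, blockIdx.z the channel;
// channel_ratio broadcasts x across batched y channels. Columns are consumed two
// at a time (ncols2 == ncols/2) through half2/float2 loads; partial sums are
// combined with warp_reduce_sum, spilling through shared memory when
// block_size > WARP_SIZE.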
template <typename T, typename type_acc, int block_size>
static __global__ void mul_mat_vec(
        const T * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
        const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) {
    const int64_t row     = blockIdx.x;
    const int64_t channel = blockIdx.z;
    const int     tid     = threadIdx.x;

    x   += (channel/channel_ratio)*stride_channel_x + row*stride_row;
    y   +=  channel               *stride_channel_y;
    dst +=  channel               *stride_channel_dst;

    const float2 * y2 = (const float2 *) y;

    extern __shared__ char data_mmv[];
    float * buf_iw = (float *) data_mmv;

    if (block_size > WARP_SIZE) {
        if (tid < WARP_SIZE) {
            buf_iw[tid] = 0.0f;
        }
        __syncthreads();
    }

    float sumf;

    if constexpr (std::is_same<T, half>::value) {
        const half2 * x2 = (const half2 *) x;

        if (std::is_same<type_acc, float>::value) {
            sumf = 0.0f;

            for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
                const float2 tmpx = __half22float2(x2[col2]);
                const float2 tmpy = y2[col2];
                sumf += tmpx.x * tmpy.x;
                sumf += tmpx.y * tmpy.y;
            }
        } else {
#ifdef FP16_AVAILABLE
            half2 sumh2 = make_half2(0.0f, 0.0f);

            for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
                const float2 tmp = y2[col2];
                sumh2 += x2[col2] * make_half2(tmp.x, tmp.y);
            }

            sumf = __low2float(sumh2) + __high2float(sumh2);
#else
            NO_DEVICE_CODE;
#endif // FP16_AVAILABLE
        }
    } else if constexpr (std::is_same<T, nv_bfloat16>::value) {
        const int * x2 = (const int *) x;
        sumf = 0.0f;

        for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
            const int    tmpx = x2[col2];
            const float2 tmpy = y2[col2];
            sumf += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
            sumf += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
        }
    } else {
        static_assert(std::is_same<T, void>::value, "unsupported type");
    }

    sumf = warp_reduce_sum(sumf);

    if (block_size > WARP_SIZE) {
        buf_iw[tid/WARP_SIZE] = sumf;
        __syncthreads();
        if (tid >= WARP_SIZE) {
            return;
        }
        sumf = buf_iw[tid];
        sumf = warp_reduce_sum(sumf);
    }

    if (tid != 0) {
        return;
    }

    dst[row] = sumf;
}

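// Chooses the block size (a multiple of WARP_SIZE, at most 256) that minimizes
// the number of loop iterations per thread for the given ncols, then launches
// mul_mat_vec with that block size as a template parameter. ncols and stride_row
// must be even because the kernel reads x and y as 2-element vectors.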
template <typename T, typename type_acc>
static void launch_mul_mat_vec_cuda(
        const T * x, const float * y, float * dst,
        const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
        cudaStream_t stream) {
    GGML_ASSERT(ncols      % 2 == 0);
    GGML_ASSERT(stride_row % 2 == 0);
    GGML_ASSERT(nchannels_y % nchannels_x == 0);
    const int64_t channel_ratio = nchannels_y / nchannels_x;

    int64_t block_size_best = WARP_SIZE;
    int64_t niter_best      = (ncols + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
    for (int64_t block_size = 2*WARP_SIZE; block_size <= 256; block_size += WARP_SIZE) {
        const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size);
        if (niter < niter_best) {
            niter_best      = niter;
            block_size_best = block_size;
        }
    }

    const int  smem = WARP_SIZE*sizeof(float);
    const dim3 block_nums(nrows, 1, nchannels_y);
    const dim3 block_dims(block_size_best, 1, 1);
    switch (block_size_best) {
        case   32: {
            mul_mat_vec<T, type_acc,  32><<<block_nums, block_dims, smem, stream>>>
                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
        } break;
        case   64: {
            mul_mat_vec<T, type_acc,  64><<<block_nums, block_dims, smem, stream>>>
                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
        } break;
        case   96: {
            mul_mat_vec<T, type_acc,  96><<<block_nums, block_dims, smem, stream>>>
                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
        } break;
        case  128: {
            mul_mat_vec<T, type_acc, 128><<<block_nums, block_dims, smem, stream>>>
                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
        } break;
        case  160: {
            mul_mat_vec<T, type_acc, 160><<<block_nums, block_dims, smem, stream>>>
                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
        } break;
        case  192: {
            mul_mat_vec<T, type_acc, 192><<<block_nums, block_dims, smem, stream>>>
                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
        } break;
        case  224: {
            mul_mat_vec<T, type_acc, 224><<<block_nums, block_dims, smem, stream>>>
                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
        } break;
        case  256: {
            mul_mat_vec<T, type_acc, 256><<<block_nums, block_dims, smem, stream>>>
                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
        } break;
        default: {
            GGML_ABORT("fatal error");
        } break;
    }
}

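// Dispatches on the requested precision: half accumulators for GGML_PREC_DEFAULT,
// float accumulators for GGML_PREC_F32.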
template<typename T>
static void mul_mat_vec_cuda(
        const T * x, const float * y, float * dst,
        const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
        enum ggml_prec prec, cudaStream_t stream) {
    switch (prec) {
        case GGML_PREC_DEFAULT: {
            launch_mul_mat_vec_cuda<T, half>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
                stride_channel_x, stride_channel_y, stride_channel_dst, stream);
        } break;
        case GGML_PREC_F32: {
            launch_mul_mat_vec_cuda<T, float>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
                stride_channel_x, stride_channel_y, stride_channel_dst, stream);
        } break;
    }
}

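// Full-tensor entry point: src0 is an F16/BF16 matrix (optionally batched along
// dim 2), src1 an F32 vector per channel. Strides are converted from byte strides
// (nb) to element counts. The precision requested in dst->op_params is honored
// only when fast FP16 arithmetic is available on the device; otherwise F32
// accumulation is forced.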
void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type  == GGML_TYPE_F32);

    const int64_t ne00 = src0->ne[0];
    const int64_t ne01 = src0->ne[1];

    GGML_ASSERT(src1->ne[1] == 1);

    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
    const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;

    const float * src1_d = (const float *) src1->data;
    float       *  dst_d = (float       *)  dst->data;

    const int64_t ne02 = src0->ne[2];
    const int64_t ne12 = src1->ne[2];
    GGML_ASSERT(dst->ne[2] == ne12);

    GGML_ASSERT(src0->ne[3] == 1);
    GGML_ASSERT(src1->ne[3] == 1);
    GGML_ASSERT( dst->ne[3] == 1);

    const int64_t stride_row         = src0->nb[1] / ggml_type_size(src0->type);
    const int64_t channel_stride_x   = src0->nb[2] / ggml_type_size(src0->type);
    const int64_t channel_stride_y   = src1->nb[2] / ggml_type_size(src1->type);
    const int64_t channel_stride_dst =  dst->nb[2] / ggml_type_size( dst->type);

    switch (src0->type) {
        case GGML_TYPE_F16: {
            const half * src0_d = (const half *) src0->data;
            mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12,
                channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
        } break;
        case GGML_TYPE_BF16: {
            const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
            mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12,
                channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
        } break;
        default:
            GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
    }
}

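// ggml_cuda_op variant: receives a single contiguous slice of rows
// [row_low, row_high) of one matrix, so the row stride is simply ne00 and all
// channel strides are zero (no batching). The unused parameters are kept to
// match the common ggml_cuda_op callback signature.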
void ggml_cuda_op_mul_mat_vec(
    ggml_backend_cuda_context & ctx,
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
    const int64_t src1_padded_row_size, cudaStream_t stream) {

    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type  == GGML_TYPE_F32);

    const int64_t ne00     = src0->ne[0];
    const int64_t row_diff = row_high - row_low;

    GGML_ASSERT(src1_ncols == 1);

    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
    const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;

    // ggml_cuda_op provides single, contiguous matrices
    const int64_t stride_row         = ne00;
    const int64_t nchannels_x        = 1;
    const int64_t nchannels_y        = 1;
    const int64_t channel_stride_x   = 0;
    const int64_t channel_stride_y   = 0;
    const int64_t channel_stride_dst = 0;

    switch (src0->type) {
        case GGML_TYPE_F16: {
            const half * src0_d = (const half *) src0_dd_i;
            mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
                nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
        } break;
        case GGML_TYPE_BF16: {
            const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
            mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
                nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
        } break;
        default:
            GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
    }

    GGML_UNUSED(ctx);
    GGML_UNUSED(src1);
    GGML_UNUSED(dst);
    GGML_UNUSED(src1_ddq_i);
    GGML_UNUSED(src1_ncols);
    GGML_UNUSED(src1_padded_row_size);
}