/**
 * llama.cpp - commit ba1cb19cdd0d92e012e0f6e009e0620f854b6afd - do not edit this file
 *
 * MIT License
 *
 * Copyright (c) 2023-2024 The ggml authors
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "common.cuh"
#include "mmv.cuh"
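
// F16 x F32 matrix-vector multiplication kernel.
// Grid mapping: blockIdx.x selects the output row, blockIdx.z selects the channel
// (batch dim 2). Each thread strides over the row in half2/float2 pairs, so the
// kernel works on ncols2 = ncols/2 column pairs. type_acc selects whether the
// per-thread dot product is accumulated in float or in half2.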
template <typename type_acc, int block_size>
static __global__ void mul_mat_vec(
        const half * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
        const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) {
    const int64_t row     = blockIdx.x;
    const int64_t channel = blockIdx.z;
    const int     tid     = threadIdx.x;

    x   += (channel/channel_ratio)*stride_channel_x + row*stride_row;
    y   +=  channel               *stride_channel_y;
    dst +=  channel               *stride_channel_dst;

    const half2  * x2 = (const half2  *) x;
    const float2 * y2 = (const float2 *) y;

    extern __shared__ char data_mmv[];
    float * buf_iw = (float *) data_mmv;

    if (block_size > WARP_SIZE) {
        if (tid < WARP_SIZE) {
            buf_iw[tid] = 0.0f;
        }
        __syncthreads();
    }
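
    // Per-thread partial dot product. With float accumulation each half2 of x is
    // converted and multiplied in fp32; with half accumulation the products stay in
    // a half2 and are only summed to fp32 at the end (requires FP16_AVAILABLE).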
    float sumf;

    if (std::is_same<type_acc, float>::value) {
        sumf = 0.0f;

        for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
            const float2 tmpx = __half22float2(x2[col2]);
            const float2 tmpy = y2[col2];
            sumf += tmpx.x * tmpy.x;
            sumf += tmpx.y * tmpy.y;
        }
    } else {
#ifdef FP16_AVAILABLE
        half2 sumh2 = make_half2(0.0f, 0.0f);

        for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
            const float2 tmp = y2[col2];
            sumh2 += x2[col2] * make_half2(tmp.x, tmp.y);
        }

        sumf = __low2float(sumh2) + __high2float(sumh2);
#else
        NO_DEVICE_CODE;
#endif // FP16_AVAILABLE
    }
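
    // Reduce the per-thread sums: first within each warp, then (for block sizes
    // larger than one warp) across warps via the shared-memory buffer buf_iw.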
    sumf = warp_reduce_sum(sumf);

    if (block_size > WARP_SIZE) {
        buf_iw[tid/WARP_SIZE] = sumf;
        __syncthreads();
        if (tid >= WARP_SIZE) {
            return;
        }
        sumf = buf_iw[tid];
        sumf = warp_reduce_sum(sumf);
    }

    if (tid != 0) {
        return;
    }

    dst[row] = sumf;
}
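
// Host-side launcher: validates the layout, picks a block size, and launches the
// kernel instantiation that has that block size as a compile-time parameter.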
template <typename type_acc>
static void launch_mul_mat_vec_cuda(
        const half * x, const float * y, float * dst,
        const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
        cudaStream_t stream) {
    GGML_ASSERT(ncols      % 2 == 0);
    GGML_ASSERT(stride_row % 2 == 0);
    GGML_ASSERT(nchannels_y % nchannels_x == 0);
    const int64_t channel_ratio = nchannels_y / nchannels_x;
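
    // Pick the block size (a multiple of WARP_SIZE, at most 256) that minimizes the
    // number of loop iterations per thread; each thread handles 2 columns per
    // iteration. For example, with WARP_SIZE == 32 and ncols == 4096, block size 32
    // needs 64 iterations while block size 256 needs only 8.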
    int64_t block_size_best = WARP_SIZE;
    int64_t niter_best      = (ncols + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
    for (int64_t block_size = 2*WARP_SIZE; block_size <= 256; block_size += WARP_SIZE) {
        const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size);
        if (niter < niter_best) {
            niter_best      = niter;
            block_size_best = block_size;
        }
    }

    const int smem = WARP_SIZE*sizeof(float);
    const dim3 block_nums(nrows, 1, nchannels_y);
    const dim3 block_dims(block_size_best, 1, 1);
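
    // block_size must be known at compile time, so dispatch to one of the
    // pre-instantiated kernels (block sizes 32 to 256 in steps of 32).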
    switch (block_size_best) {
        case   32: {
            mul_mat_vec<type_acc,  32><<<block_nums, block_dims, smem, stream>>>
                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
        } break;
        case   64: {
            mul_mat_vec<type_acc,  64><<<block_nums, block_dims, smem, stream>>>
                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
        } break;
        case   96: {
            mul_mat_vec<type_acc,  96><<<block_nums, block_dims, smem, stream>>>
                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
        } break;
        case  128: {
            mul_mat_vec<type_acc, 128><<<block_nums, block_dims, smem, stream>>>
                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
        } break;
        case  160: {
            mul_mat_vec<type_acc, 160><<<block_nums, block_dims, smem, stream>>>
                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
        } break;
        case  192: {
            mul_mat_vec<type_acc, 192><<<block_nums, block_dims, smem, stream>>>
                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
        } break;
        case  224: {
            mul_mat_vec<type_acc, 224><<<block_nums, block_dims, smem, stream>>>
                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
        } break;
        case  256: {
            mul_mat_vec<type_acc, 256><<<block_nums, block_dims, smem, stream>>>
                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
        } break;
        default: {
            GGML_ABORT("fatal error");
        } break;
    }
}
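
// Dispatch on the requested precision: GGML_PREC_DEFAULT accumulates in half,
// GGML_PREC_F32 forces float accumulation.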
static void mul_mat_vec_cuda(
        const half * x, const float * y, float * dst,
        const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
        enum ggml_prec prec, cudaStream_t stream) {
    switch (prec) {
        case GGML_PREC_DEFAULT: {
            launch_mul_mat_vec_cuda<half>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
                stride_channel_x, stride_channel_y, stride_channel_dst, stream);
        } break;
        case GGML_PREC_F32: {
            launch_mul_mat_vec_cuda<float>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
                stride_channel_x, stride_channel_y, stride_channel_dst, stream);
        } break;
    }
}
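
// Entry point for whole tensors: src0 is F16, src1 is an F32 vector (ne[1] == 1),
// optionally batched along dim 2 with src0 channels broadcast across src1 channels.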
void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    GGML_ASSERT(src0->type == GGML_TYPE_F16);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type  == GGML_TYPE_F32);

    const int64_t ne00 = src0->ne[0];
    const int64_t ne01 = src0->ne[1];

    GGML_ASSERT(src1->ne[1] == 1);

    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
    const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;

    const half  * src0_d = (const half  *) src0->data;
    const float * src1_d = (const float *) src1->data;
    float       *  dst_d = (float       *)  dst->data;

    const int64_t ne02 = src0->ne[2];
    const int64_t ne12 = src1->ne[2];
    GGML_ASSERT(dst->ne[2] == ne12);

    GGML_ASSERT(src0->ne[3] == 1);
    GGML_ASSERT(src1->ne[3] == 1);
    GGML_ASSERT( dst->ne[3] == 1);

    const int64_t stride_row         = src0->nb[1] / ggml_type_size(src0->type);
    const int64_t channel_stride_x   = src0->nb[2] / ggml_type_size(src0->type);
    const int64_t channel_stride_y   = src1->nb[2] / ggml_type_size(src1->type);
    const int64_t channel_stride_dst =  dst->nb[2] / ggml_type_size( dst->type);

    mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12, channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
}
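
// Variant called through the ggml_cuda_op path, which provides a single contiguous
// slice of rows [row_low, row_high); channels are not batched here, so the channel
// strides are passed as 0.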
void ggml_cuda_op_mul_mat_vec(
    ggml_backend_cuda_context & ctx,
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
    const int64_t src1_padded_row_size, cudaStream_t stream) {

    GGML_ASSERT(src0->type == GGML_TYPE_F16);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type  == GGML_TYPE_F32);

    const int64_t ne00     = src0->ne[0];
    const int64_t row_diff = row_high - row_low;

    GGML_ASSERT(src1_ncols == 1);

    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
    const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;

    // ggml_cuda_op provides single, contiguous matrices
    const int64_t stride_row         = ne00;
    const int64_t nchannels_x        = 1;
    const int64_t nchannels_y        = 1;
    const int64_t channel_stride_x   = 0;
    const int64_t channel_stride_y   = 0;
    const int64_t channel_stride_dst = 0;

    mul_mat_vec_cuda((const half *) src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
        nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);

    GGML_UNUSED(ctx);
    GGML_UNUSED(src1);
    GGML_UNUSED(dst);
    GGML_UNUSED(src1_ddq_i);
    GGML_UNUSED(src1_ncols);
    GGML_UNUSED(src1_padded_row_size);
}