fattn-tile-f16.cu 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. #include "common.cuh"
  2. #include "fattn-common.cuh"
  3. #include "fattn-tile-f16.cuh"
  4. #define FATTN_KQ_STRIDE_TILE_F16 64
  // Flash attention kernel for GPUs without tensor cores; all intermediate math is done in FP16.
  //
  // Template parameters:
  //   D               head size; must be divisible by 2*WARP_SIZE.
  //   ncols           number of Q columns processed per block; must be divisible by nwarps.
  //   nwarps          warps per block. The indexing below uses threadIdx.x as the lane within a warp
  //                   and threadIdx.y as the warp index, i.e. it expects blockDim == (WARP_SIZE, nwarps).
  //   parallel_blocks number of blocks that work on the same Q columns; block ip handles every
  //                   parallel_blocks-th KV tile, starting at tile ip.
  //
  // Block indexing: blockIdx.x encodes (Q column tile, ip), blockIdx.y selects the Q matrix (head).
  // Static shared memory: KQ (ncols x FATTN_KQ_STRIDE_TILE_F16 half), KV_tmp (padded K/V tile), Q_h2 (scaled Q tile).
  //
  // Output:
  //   parallel_blocks == 1: dst receives the softmax-normalized result.
  //   parallel_blocks != 1: dst receives the unnormalized partial result of this block and dst_meta the
  //                         (running max, sum) pair per column -- presumably combined by a later kernel.
  //
  // Q columns at or beyond ne01 are loaded as zeros and never written to dst/dst_meta, so the number of
  // Q columns does not have to be a multiple of ncols.
  // NOTE(review): this assumes ne01 is the number of Q columns (Q->ne[1]) -- confirm against launch_fattn.
  //
  // Preconditions that are NOT checked here:
  //   - K and V are read in full tiles of FATTN_KQ_STRIDE_TILE_F16 rows, so ne11 must be a multiple of it.
  //   - mask (if non-null) is read for all ncols columns of the tile, including columns >= ne01;
  //     NOTE(review): presumably the caller pads the mask accordingly -- TODO confirm.
  template<int D, int ncols, int nwarps, int parallel_blocks> // D == head size
  #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
  __launch_bounds__(nwarps*WARP_SIZE, 1)
  #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
  static __global__ void flash_attn_tile_ext_f16(
          const char * __restrict__ Q,
          const char * __restrict__ K,
          const char * __restrict__ V,
          const char * __restrict__ mask,
          float      * __restrict__ dst,
          float2     * __restrict__ dst_meta,
          const float scale,
          const float max_bias,
          const float m0,
          const float m1,
          const uint32_t n_head_log2,
          const int ne00,
          const int ne01,
          const int ne02,
          const int ne03,
          const int ne10,
          const int ne11,
          const int ne12,
          const int ne13,
          const int ne31,
          const int nb31,
          const int nb01,
          const int nb02,
          const int nb03,
          const int nb11,
          const int nb12,
          const int nb13,
          const int ne0,
          const int ne1,
          const int ne2,
          const int ne3) {
  #if FP16_AVAILABLE
      // In this kernel Q, K, V are matrices while i, j, k are matrix indices.

      const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the first Q/QKV column to work on.
      const int ip  =  blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.

      const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
      const float2 * Q_f2  = (const float2 *) (Q    + nb02* blockIdx.y              + nb01*ic0);
      const half2  * K_h2  = (const half2  *) (K    + nb12*(blockIdx.y / gqa_ratio));
      const half2  * V_h2  = (const half2  *) (V    + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
      const half   * maskh = (const half   *)  mask + ne11*ic0;

      const int stride_KV2 = nb11 / sizeof(half2);

      const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
      const half  slopeh = __float2half(slopef);

      static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
      // The per-thread arrays below hold ncols/nwarps entries, one per Q column owned by this warp.
      static_assert(ncols % nwarps == 0, "ncols not divisible by nwarps.");

      __shared__ half KQ[ncols*FATTN_KQ_STRIDE_TILE_F16];
      half2 * KQ2 = (half2 *) KQ;

      __shared__ half2 KV_tmp[FATTN_KQ_STRIDE_TILE_F16][D/2 + 1]; // Pad D to avoid memory bank conflicts.

      // Running softmax maximum per Q column owned by this warp.
      half kqmax[ncols/nwarps];
  #pragma unroll
      for (int j0 = 0; j0 < ncols; j0 += nwarps) {
          kqmax[j0/nwarps] = -HALF_MAX_HALF;
      }
      // Running softmax denominator; each thread keeps a partial sum that is reduced across the warp at the end.
      half2 kqsum[ncols/nwarps] = {{0.0f, 0.0f}};

      // Output accumulator: each thread owns (D/2)/WARP_SIZE half2 values per Q column.
      half2 VKQ[ncols/nwarps][(D/2)/WARP_SIZE] = {{{0.0f, 0.0f}}};

      // Convert Q to half2, apply the scale and store the tile in shared memory:
      __shared__ half2 Q_h2[ncols][D/2];
  #pragma unroll
      for (int j0 = 0; j0 < ncols; j0 += nwarps) {
          const int j = j0 + threadIdx.y;

  #pragma unroll
          for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
              const int i = i0 + threadIdx.x;

              // Columns past the end of Q are replaced by zeros instead of being read out of bounds.
              const float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i] : make_float2(0.0f, 0.0f);
              Q_h2[j][i] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
          }
      }

      __syncthreads();

      const int k_start = parallel_blocks == 1 ? 0 : ip*FATTN_KQ_STRIDE_TILE_F16;
      for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F16) {
          // Calculate KQ tile and keep track of new maximum KQ values:

          half kqmax_new[ncols/nwarps];
  #pragma unroll
          for (int j = 0; j < ncols/nwarps; ++j) {
              kqmax_new[j] = kqmax[j];
          }

          // Stage the current K tile in shared memory.
  #pragma unroll
          for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += nwarps) {
              const int i_KQ = i_KQ_0 + threadIdx.y;

  #pragma unroll
              for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
                  const int k_KQ = k_KQ_0 + threadIdx.x;

                  KV_tmp[i_KQ][k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
              }
          }

          __syncthreads();

          // Partial dot products; the two half2 lanes are added together after the k loop.
          half2 sum2[FATTN_KQ_STRIDE_TILE_F16/WARP_SIZE][ncols/nwarps] = {{{0.0f, 0.0f}}};

  #pragma unroll
          for (int k_KQ = 0; k_KQ < D/2; ++k_KQ) {
              half2 K_k[FATTN_KQ_STRIDE_TILE_F16/WARP_SIZE];
              half2 Q_k[ncols/nwarps];

  #pragma unroll
              for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) {
                  const int i_KQ = i_KQ_0 + threadIdx.x;

                  K_k[i_KQ_0/WARP_SIZE] = KV_tmp[i_KQ][k_KQ];
              }
  #pragma unroll
              for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
                  const int j_KQ = j_KQ_0 + threadIdx.y;

                  Q_k[j_KQ_0/nwarps] = Q_h2[j_KQ][k_KQ];
              }

  #pragma unroll
              for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) {
  #pragma unroll
                  for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
                      sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += K_k[i_KQ_0/WARP_SIZE]*Q_k[j_KQ_0/nwarps];
                  }
              }
          }

          // Finish the dot products, add the (slope-scaled) mask, track the new maximum and publish KQ to shared memory.
  #pragma unroll
          for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) {
              const int i_KQ = i_KQ_0 + threadIdx.x;

  #pragma unroll
              for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
                  const int j_KQ = j_KQ_0 + threadIdx.y;

                  half sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
                  sum += mask ? slopeh*maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);

                  kqmax_new[j_KQ_0/nwarps] = ggml_cuda_hmax(kqmax_new[j_KQ_0/nwarps], sum);

                  KQ[j_KQ*FATTN_KQ_STRIDE_TILE_F16 + i_KQ] = sum;
              }
          }

          __syncthreads();

          // Online softmax update: rescale the running sum and the accumulator to the new maximum,
          // then replace the KQ values in shared memory by their exponentials.
  #pragma unroll
          for (int j0 = 0; j0 < ncols; j0 += nwarps) {
              const int j = j0 + threadIdx.y;

              kqmax_new[j0/nwarps] = warp_reduce_max(kqmax_new[j0/nwarps]);
              const half2 KQ_max_scale = __half2half2(hexp(kqmax[j0/nwarps] - kqmax_new[j0/nwarps]));
              kqmax[j0/nwarps] = kqmax_new[j0/nwarps];

  #pragma unroll
              for (int i0 = 0; i0 < FATTN_KQ_STRIDE_TILE_F16/2; i0 += WARP_SIZE) {
                  const int i = i0 + threadIdx.x;

                  const half2 diff = KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + i] - __half2half2(kqmax[j0/nwarps]);
                  const half2 val = h2exp(diff);
                  kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + val;
                  KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + i] = val;
              }

  #pragma unroll
              for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
                  VKQ[j0/nwarps][i0/WARP_SIZE] *= KQ_max_scale;
              }
          }

          __syncthreads();

          // Reuse the shared tile buffer for the V tile.
  #pragma unroll
          for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F16; k0 += nwarps) {
              const int k = k0 + threadIdx.y;

  #pragma unroll
              for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
                  const int i = i0 + threadIdx.x;

                  KV_tmp[k][i] = V_h2[(k_VKQ_0 + k)*stride_KV2 + i];
              }
          }

          __syncthreads();

          // Accumulate V*softmax(KQ); two V rows are consumed per step because KQ is read as half2.
  #pragma unroll
          for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F16; k0 += 2) {
              half2  V_k[(D/2)/WARP_SIZE][2];
              half2 KQ_k[ncols/nwarps];

  #pragma unroll
              for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
                  const int i = i0 + threadIdx.x;

                  V_k[i0/WARP_SIZE][0] = KV_tmp[k0 + 0][i];
                  V_k[i0/WARP_SIZE][1] = KV_tmp[k0 + 1][i];
              }
  #pragma unroll
              for (int j0 = 0; j0 < ncols; j0 += nwarps) {
                  const int j = j0 + threadIdx.y;

                  KQ_k[j0/nwarps] = KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + k0/2];
              }

  #pragma unroll
              for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
  #pragma unroll
                  for (int j0 = 0; j0 < ncols; j0 += nwarps) {
                      VKQ[j0/nwarps][i0/WARP_SIZE] += V_k[i0/WARP_SIZE][0]* __low2half2(KQ_k[j0/nwarps]);
                      VKQ[j0/nwarps][i0/WARP_SIZE] += V_k[i0/WARP_SIZE][1]*__high2half2(KQ_k[j0/nwarps]);
                  }
              }
          }

          __syncthreads();
      }

      // Write back the results for the Q columns owned by this warp.
  #pragma unroll
      for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
          const int j_VKQ = j_VKQ_0 + threadIdx.y;

          // Columns past the end of Q have no storage in dst/dst_meta.
          // The condition only depends on threadIdx.y, so a warp leaves as a whole (before the warp reduction),
          // and no __syncthreads() follows. Later iterations only have larger column indices, hence return.
          if (ic0 + j_VKQ >= ne01) {
              return;
          }

          half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
          kqsum_j = warp_reduce_sum(kqsum_j);

  #pragma unroll
          for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
              const int i0 = i00 + 2*threadIdx.x;

              half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
              if (parallel_blocks == 1) {
                  // Only a single block per column: normalize here.
                  dst_val /= __half2half2(kqsum_j);
              }
              const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
              dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] =  __low2float(dst_val);
              dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = __high2float(dst_val);
          }

          if (parallel_blocks != 1 && threadIdx.x == 0) {
              // Partial result: store (max, sum) so the partial results can be combined afterwards.
              dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
          }
      }
  #else
      NO_DEVICE_CODE;
  #endif // FP16_AVAILABLE
  }
  // Instantiates the FP16 tile kernel for the head size of Q (dst->src[0]->ne[0]) and hands it to launch_fattn.
  // Only head sizes 64 and 128 are instantiated; any other head size triggers the assertion at the end.
  // Both supported head sizes run with 8 warps per block.
  template <int cols_per_block, int parallel_blocks>
  void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
      constexpr int nwarps = 8;

      const auto head_size = dst->src[0]->ne[0];

      if (head_size == 64) {
          fattn_kernel_t kernel_d64 = flash_attn_tile_ext_f16<64, cols_per_block, nwarps, parallel_blocks>;
          launch_fattn<64, parallel_blocks>(ctx, dst, kernel_d64, nwarps, cols_per_block);
          return;
      }

      if (head_size == 128) {
          fattn_kernel_t kernel_d128 = flash_attn_tile_ext_f16<128, cols_per_block, nwarps, parallel_blocks>;
          launch_fattn<128, parallel_blocks>(ctx, dst, kernel_d128, nwarps, cols_per_block);
          return;
      }

      // Unsupported head size.
      GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
  }
  // Host entry point for the FP16 tile flash attention kernel.
  // Requires default precision (dst->op_params[2]) and picks the launch configuration from the number of
  // Q columns (dst->src[0]->ne[1]):
  //   <= 16 columns: 16 columns per block, 4 parallel blocks
  //   <= 32 columns: 32 columns per block, 4 parallel blocks
  //   otherwise:     32 columns per block, 1 parallel block
  void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
      const int32_t precision = dst->op_params[2];
      GGML_ASSERT(precision == GGML_PREC_DEFAULT);

      const auto n_q_cols = dst->src[0]->ne[1];

      if (n_q_cols <= 16) {
          launch_fattn_tile_f16_64_128</*cols_per_block =*/ 16, /*parallel_blocks =*/ 4>(ctx, dst);
      } else if (n_q_cols <= 32) {
          launch_fattn_tile_f16_64_128</*cols_per_block =*/ 32, /*parallel_blocks =*/ 4>(ctx, dst);
      } else {
          launch_fattn_tile_f16_64_128</*cols_per_block =*/ 32, /*parallel_blocks =*/ 1>(ctx, dst);
      }
  }