mma.cuh

/**
 * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
 *
 * MIT License
 *
 * Copyright (c) 2023-2024 The ggml authors
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include "common.cuh"
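
// Warp-level fragments for int8 tensor core matrix multiplication via inline PTX mma.sync.
// Each struct describes one operand tile of an m16n8k16 or m16n8k32 integer MMA as held by
// a single warp of 32 threads:
//   - I and J are the number of rows/columns of the A/B/C tiles; K is the reduction
//     dimension, counted in 32-bit values (one int packs 4 consecutive int8 values).
//   - ne is the number of 32-bit registers each thread contributes to the tile.
//   - get_i/get_j/get_k return the tile coordinates owned by the calling thread for its
//     l-th register, i.e. the thread <-> data mapping required by mma.sync.
//   - load() fills the fragment from a tile in memory with the given row stride (in ints);
//     the ldmatrix path requires the source to reside in shared memory.
//
// mma_int_A_I16K4: A operand for m16n8k16: 16 rows x 4 ints (16 int8 per row), 2 registers per thread.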
struct mma_int_A_I16K4 {
    static constexpr int I  = 16;
    static constexpr int K  = 4;
    static constexpr int ne = 2;

    int x[ne] = {0};

    static __device__ __forceinline__ int get_i(const int l) {
        const int ret = (l%2) * (I/2) + threadIdx.x / K;
        GGML_CUDA_ASSUME(ret >= 0);
        GGML_CUDA_ASSUME(ret < I);
        return ret;
    }

    static __device__ __forceinline__ int get_k(const int /* l */) {
        const int ret = threadIdx.x % K;
        GGML_CUDA_ASSUME(ret >= 0);
        GGML_CUDA_ASSUME(ret < K);
        return ret;
    }

    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
#if defined(INT8_MMA_AVAILABLE)
        const int * xs = xs0 + (threadIdx.x%I)*stride;
        asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
            : "+r"(x[0]), "+r"(x[1])
            : "l"(xs));
#else
#pragma unroll
        for (int l = 0; l < ne; ++l) {
            x[l] = xs0[get_i(l)*stride + get_k(l)];
        }
#endif // defined(INT8_MMA_AVAILABLE)
    }
};
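
// A operand for the m16n8k32 variant: 16 rows x 8 ints (32 int8 per row), 4 registers per
// thread. load_low() reinterprets x[0]/x[1] as an I16K4 fragment and fills only the low
// half of K (the first 4 ints of each row).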
struct mma_int_A_I16K8 {
    static constexpr int I  = 16;
    static constexpr int K  = 8;
    static constexpr int ne = 4;

    int x[ne] = {0};

    static __device__ __forceinline__ int get_i(const int l) {
        const int ret = (l%2) * (I/2) + threadIdx.x / (K/2);
        GGML_CUDA_ASSUME(ret >= 0);
        GGML_CUDA_ASSUME(ret < I);
        return ret;
    }

    static __device__ __forceinline__ int get_k(const int l) {
        const int ret = (l/2) * (K/2) + threadIdx.x % (K/2);
        GGML_CUDA_ASSUME(ret >= 0);
        GGML_CUDA_ASSUME(ret < K);
        return ret;
    }

    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
#if defined(INT8_MMA_AVAILABLE)
        const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
        asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
            : "l"(xs));
#else
#pragma unroll
        for (int l = 0; l < ne; ++l) {
            x[l] = xs0[get_i(l)*stride + get_k(l)];
        }
#endif // defined(INT8_MMA_AVAILABLE)
    }

    __device__ __forceinline__ void load_low(const int * __restrict__ xs0, const int & stride) {
        ((mma_int_A_I16K4 *) x)[0].load(xs0, stride);
    }
};
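
// B operand for m16n8k16: 8 columns x 4 ints, 1 register per thread. B is stored transposed
// (the mma instructions use the "row.col" layout), so get_j indexes the 8 output columns and
// get_k the reduction dimension. The ldmatrix path below is compiled out ("&& false") because
// plain 4-byte loads are faster for this small fragment.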
struct mma_int_B_J8K4 {
    static constexpr int J  = 8;
    static constexpr int K  = 4;
    static constexpr int ne = 1;

    int x[ne] = {0};

    static __device__ __forceinline__ int get_j(const int /* l */) {
        const int ret = threadIdx.x / K;
        GGML_CUDA_ASSUME(ret >= 0);
        GGML_CUDA_ASSUME(ret < J);
        return ret;
    }

    static __device__ __forceinline__ int get_k(const int /* l */) {
        const int ret = threadIdx.x % K;
        GGML_CUDA_ASSUME(ret >= 0);
        GGML_CUDA_ASSUME(ret < K);
        return ret;
    }

    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
        const int * xs = xs0 + (threadIdx.x%J)*stride;
        asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];"
            : "+r"(x[0])
            : "l"(xs));
#else
#pragma unroll
        for (int l = 0; l < ne; ++l) {
            x[l] = xs0[get_j(l)*stride + get_k(l)];
        }
#endif // defined(INT8_MMA_AVAILABLE)
    }
};
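
// B operand for the m16n8k32 variant: 8 columns x 8 ints, 2 registers per thread. As above,
// the ldmatrix path is disabled in favor of plain 4-byte loads.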
struct mma_int_B_J8K8 {
    static constexpr int J  = 8;
    static constexpr int K  = 8;
    static constexpr int ne = 2;

    int x[ne] = {0};

    static __device__ __forceinline__ int get_j(const int /* l */) {
        const int ret = threadIdx.x / (K/2);
        GGML_CUDA_ASSUME(ret >= 0);
        GGML_CUDA_ASSUME(ret < J);
        return ret;
    }

    static __device__ __forceinline__ int get_k(const int l) {
        const int ret = l * (K/2) + threadIdx.x % (K/2);
        GGML_CUDA_ASSUME(ret >= 0);
        GGML_CUDA_ASSUME(ret < K);
        return ret;
    }

    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
        const int * xs = xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K;
        asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
            : "+r"(x[0]), "+r"(x[1])
            : "l"(xs));
#else
#pragma unroll
        for (int l = 0; l < ne; ++l) {
            x[l] = xs0[get_j(l)*stride + get_k(l)];
        }
#endif // defined(INT8_MMA_AVAILABLE)
    }
};
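
// Accumulator/result fragment C: 16x8 int32 values, 4 registers per thread. mma_K4 performs
// C += A*B over 16 int8 values of K (m16n8k16), mma_K8 over 32 (m16n8k32). On Turing, where
// only m8n8k16 integer mma is available, the wider shapes are emulated with 2 or 4 m8n8k16
// instructions; without INT8_MMA_AVAILABLE the functions reduce to NO_DEVICE_CODE.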
struct mma_int_C_I16J8 {
    static constexpr int I  = 16;
    static constexpr int J  = 8;
    static constexpr int ne = 4;

    int x[ne] = {0};

    static __device__ __forceinline__ int get_i(const int l) {
        const int ret = (l/2) * (I/2) + threadIdx.x / (J/2);
        GGML_CUDA_ASSUME(ret >= 0);
        GGML_CUDA_ASSUME(ret < I);
        return ret;
    }

    static __device__ __forceinline__ int get_j(const int l) {
        const int ret = 2 * (threadIdx.x % (J/2)) + l%2;
        GGML_CUDA_ASSUME(ret >= 0);
        GGML_CUDA_ASSUME(ret < J);
        return ret;
    }

    __device__ __forceinline__ void mma_K4(const mma_int_A_I16K4 & mma_A, const mma_int_B_J8K4 & mma_B) {
#ifdef INT8_MMA_AVAILABLE
#if __CUDA_ARCH__ >= CC_AMPERE
        asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
            : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_B.x[0]));
#else
        // On Turing m16n8k16 mma is not available, use 2x m8n8k16 mma instead:
        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
            : "+r"(x[0]), "+r"(x[1])
            : "r"(mma_A.x[0]), "r"(mma_B.x[0]));
        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
            : "+r"(x[2]), "+r"(x[3])
            : "r"(mma_A.x[1]), "r"(mma_B.x[0]));
#endif // __CUDA_ARCH__ >= CC_AMPERE
#else
        GGML_UNUSED(mma_A);
        GGML_UNUSED(mma_B);
        NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
    }

    __device__ __forceinline__ void mma_K8(const mma_int_A_I16K8 & mma_A, const mma_int_B_J8K8 & mma_B) {
#ifdef INT8_MMA_AVAILABLE
#if __CUDA_ARCH__ >= CC_AMPERE
        asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
            : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_A.x[2]), "r"(mma_A.x[3]), "r"(mma_B.x[0]), "r"(mma_B.x[1]));
#else
        // On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead:
        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
            : "+r"(x[0]), "+r"(x[1])
            : "r"(mma_A.x[0]), "r"(mma_B.x[0]));
        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
            : "+r"(x[2]), "+r"(x[3])
            : "r"(mma_A.x[1]), "r"(mma_B.x[0]));
        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
            : "+r"(x[0]), "+r"(x[1])
            : "r"(mma_A.x[2]), "r"(mma_B.x[1]));
        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
            : "+r"(x[2]), "+r"(x[3])
            : "r"(mma_A.x[3]), "r"(mma_B.x[1]));
#endif // __CUDA_ARCH__ >= CC_AMPERE
#else
        GGML_UNUSED(mma_A);
        GGML_UNUSED(mma_B);
        NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
    }
};
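
// Illustrative usage sketch: one warp computes a 16x8 int32 tile C += A*B from int8 data
// packed 4-per-int. The function name, tile pointers and strides here are assumptions chosen
// for the example only; in llama.cpp these fragments are driven from the MMQ kernels in mmq.cuh.
static __device__ __forceinline__ void mma_example_warp_tile(
        const int * __restrict__ tile_A, // 16 rows x stride_A ints (shared memory for the ldmatrix path)
        const int * __restrict__ tile_B, //  8 rows (= output columns) x stride_B ints, i.e. B stored transposed
        int       * __restrict__ tile_C, // 16 rows x stride_C int32 results
        const int stride_A, const int stride_B, const int stride_C) {
    mma_int_A_I16K8 A;
    mma_int_B_J8K8  B;
    mma_int_C_I16J8 C; // zero-initialized accumulator

    A.load(tile_A, stride_A); // ldmatrix on Ampere+, per-element loads otherwise
    B.load(tile_B, stride_B); // always per-element 4-byte loads (ldmatrix path disabled above)
    C.mma_K8(A, B);           // C.x[] += A*B over K = 32 int8 values

    // Write the accumulator back using the per-thread ownership mapping:
#pragma unroll
    for (int l = 0; l < mma_int_C_I16J8::ne; ++l) {
        tile_C[mma_int_C_I16J8::get_i(l)*stride_C + mma_int_C_I16J8::get_j(l)] = C.x[l];
    }
}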