common.cuh 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683
  1. #pragma once
  #include "ggml.h"
  #include "ggml-cuda.h"

  #if defined(GGML_USE_HIPBLAS)
  #define GGML_COMMON_DECL_HIP
  #define GGML_COMMON_IMPL_HIP
  #else
  #define GGML_COMMON_DECL_CUDA
  #define GGML_COMMON_IMPL_CUDA
  #endif
  #include "ggml-common.h"

  #include <array>
  #include <cassert>
  #include <cfloat>
  #include <cstdint>
  #include <cstdio>
  #include <limits>
  #include <memory>
  #include <string>
  #include <vector>
  #if defined(GGML_USE_HIPBLAS)
  #include <hip/hip_runtime.h>
  #include <hipblas/hipblas.h>
  #include <hip/hip_fp16.h>
  #ifdef __HIP_PLATFORM_AMD__
  // for rocblas_initialize()
  #include "rocblas/rocblas.h"
  #endif // __HIP_PLATFORM_AMD__

  // Map the cuBLAS / CUDA runtime names used by the backend onto their hipBLAS / HIP equivalents.

  // cuBLAS compute types, GEMM algorithms and operations
  #define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
  #define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
  #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
  #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
  #define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
  #define CUBLAS_OP_N HIPBLAS_OP_N
  #define CUBLAS_OP_T HIPBLAS_OP_T
  #define CUBLAS_TF32_TENSOR_OP_MATH 0
  #define CUDA_R_16F HIPBLAS_R_16F
  #define CUDA_R_32F HIPBLAS_R_32F

  // cuBLAS status codes (each defined exactly once; CUBLAS_STATUS_SUCCESS used to be defined twice)
  #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
  #define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
  #define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
  #define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE
  #define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH
  #define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR
  #define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
  #define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
  #define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED

  // warp intrinsics: the HIP shuffle takes no participation mask
  #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)

  // cuBLAS types and functions
  #define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
  #define cublasCreate hipblasCreate
  #define cublasDestroy hipblasDestroy
  #define cublasGemmEx hipblasGemmEx
  #define cublasGemmBatchedEx hipblasGemmBatchedEx
  #define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
  #define cublasHandle_t hipblasHandle_t
  #define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
  #define cublasSetStream hipblasSetStream
  #define cublasSgemm hipblasSgemm
  #define cublasStatus_t hipblasStatus_t
  #define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6

  // CUDA runtime: devices, errors and events
  #define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
  #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
  #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
  #define cudaDeviceProp hipDeviceProp_t
  #define cudaDeviceSynchronize hipDeviceSynchronize
  #define cudaError_t hipError_t
  #define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
  #define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
  #define cudaEventCreateWithFlags hipEventCreateWithFlags
  #define cudaEventDisableTiming hipEventDisableTiming
  #define cudaEventRecord hipEventRecord
  #define cudaEventSynchronize hipEventSynchronize
  #define cudaEvent_t hipEvent_t
  #define cudaEventDestroy hipEventDestroy

  // CUDA runtime: memory management
  #define cudaFree hipFree
  #define cudaFreeHost hipHostFree
  #define cudaGetDevice hipGetDevice
  #define cudaGetDeviceCount hipGetDeviceCount
  #define cudaGetDeviceProperties hipGetDeviceProperties
  #define cudaGetErrorString hipGetErrorString
  #define cudaGetLastError hipGetLastError
  #define cudaHostRegister hipHostRegister
  #define cudaHostRegisterPortable hipHostRegisterPortable
  #define cudaHostRegisterReadOnly hipHostRegisterReadOnly
  #define cudaHostUnregister hipHostUnregister
  #define cudaLaunchHostFunc hipLaunchHostFunc
  #ifdef GGML_HIP_UMA
  // unified memory architecture: device allocations come from managed memory
  #define cudaMalloc hipMallocManaged
  #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size)
  #else
  #define cudaMalloc hipMalloc
  #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
  #endif
  #define cudaMemcpy hipMemcpy
  #define cudaMemcpyAsync hipMemcpyAsync
  #define cudaMemcpyPeerAsync hipMemcpyPeerAsync
  #define cudaMemcpy2DAsync hipMemcpy2DAsync
  #define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
  #define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
  #define cudaMemcpyHostToDevice hipMemcpyHostToDevice
  #define cudaMemcpyKind hipMemcpyKind
  #define cudaMemset hipMemset
  #define cudaMemsetAsync hipMemsetAsync
  #define cudaMemGetInfo hipMemGetInfo
  #define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize

  // CUDA runtime: device selection and streams
  #define cudaSetDevice hipSetDevice
  #define cudaStreamCreateWithFlags hipStreamCreateWithFlags
  #define cudaStreamDestroy hipStreamDestroy
  #define cudaStreamFireAndForget hipStreamFireAndForget
  #define cudaStreamNonBlocking hipStreamNonBlocking
  #define cudaStreamPerThread hipStreamPerThread
  #define cudaStreamSynchronize hipStreamSynchronize
  #define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
  #define cudaStream_t hipStream_t
  #define cudaSuccess hipSuccess

  // device-side trap
  #define __trap abort
  #else
  #include <cuda_runtime.h>
  #include <cuda.h>
  #include <cublas_v2.h>
  #include <cuda_fp16.h>

  #if CUDART_VERSION < 11020
  // names introduced in CUDA 11.2, mapped to their older equivalents
  #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
  #define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
  #define CUBLAS_COMPUTE_16F CUDA_R_16F
  #define CUBLAS_COMPUTE_32F CUDA_R_32F
  #define cublasComputeType_t cudaDataType_t
  #endif // CUDART_VERSION < 11020
  #endif // defined(GGML_USE_HIPBLAS)
  // Two-step stringification: STRINGIZE macro-expands its argument before STRINGIZE_IMPL turns it into a literal.
  #define STRINGIZE_IMPL(...) #__VA_ARGS__
  #define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__)

  #define WARP_SIZE 32

  // Minimum CUDA runtime versions for individual half-precision features.
  // CUDA 11.7: min. version for which __hmax and __hmax2 are known to work (may be higher than needed).
  #define CUDART_HMAX 11070
  // CUDA 12.0: min. version for half2 -> uint mask comparisons.
  #define CUDART_HMASK 12000

  // NVIDIA compute capabilities.
  #define CC_PASCAL 600
  // Minimum compute capability for __dp4a, an intrinsic for byte-wise dot products.
  #define MIN_CC_DP4A 610
  #define CC_VOLTA 700
  #define CC_AMPERE 800

  // AMD compute capabilities are offset by CC_OFFSET_AMD so that they sit above every NVIDIA value.
  #define CC_OFFSET_AMD 1000000
  #define CC_RDNA1 (CC_OFFSET_AMD + 1010)
  #define CC_RDNA2 (CC_OFFSET_AMD + 1030)
  #define CC_RDNA3 (CC_OFFSET_AMD + 1100)

  // Define GGML_CUDA_FORCE_MMQ to always fall back to the MMQ kernels instead of using cuBLAS for matrix
  // multiplication. On modern hardware cuBLAS is recommended, because its F16 tensor-core kernels are very
  // fast for large workloads. The drawback is some extra VRAM:
  //   - 7B quantized model:  +100-200 MB
  //   - 13B quantized model: +200-400 MB
  //
  //#define GGML_CUDA_FORCE_MMQ

  // TODO: improve this to be correct for more hardware. For example, it is currently wrong for the
  // GeForce GTX 1660, which is Turing (newer than Volta) but has no tensor cores.
  #ifndef GGML_CUDA_FORCE_MMQ
  #define CUDA_USE_TENSOR_CORES
  #endif

  // max batch size to use MMVQ kernels
  #define MMVQ_MAX_BATCH_SIZE 8
  // max batch size to use MMQ kernels when tensor cores are available
  #define MMQ_MAX_BATCH_SIZE 32
  // the last row of quantized matrices is padded to a multiple of this to avoid out-of-bounds memory accesses
  #define MATRIX_ROW_PADDING 512

  #ifdef _MSC_VER
  // possible loss of data
  #pragma warning(disable: 4244 4267)
  #endif

  #define GGML_CUDA_MAX_STREAMS 8
  // Reports a failed API call: stringified statement, enclosing function, file, line and error text.
  // Only declared here (the definition lives in another source file); it never returns.
  [[noreturn]]
  void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg);

  // Generic status check. Evaluates `err` exactly once. If the result differs from `success`, it passes
  // the stringified expression, the call site and error_fn(result) to ggml_cuda_error.
  #define CUDA_CHECK_GEN(err, success, error_fn)                                              \
      do {                                                                                    \
          auto ggml_cuda_check_status_ = (err);                                               \
          if (ggml_cuda_check_status_ != (success)) {                                         \
              ggml_cuda_error(#err, __func__, __FILE__, __LINE__,                             \
                              error_fn(ggml_cuda_check_status_));                             \
          }                                                                                   \
      } while (0)

  // CUDA runtime API check against cudaSuccess.
  #define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)
  // Printable name of a cuBLAS status (a hipBLAS status on ROCm builds, via the cublasStatus_t mapping).
  // CUDA 12+ provides cublasGetStatusString. Otherwise a fixed table is used, and unknown codes map to
  // "unknown error".
  static const char * cublas_get_error_str(const cublasStatus_t err) {
  #if CUDART_VERSION >= 12000
      return cublasGetStatusString(err);
  #else
      switch (err) {
          case CUBLAS_STATUS_SUCCESS:          return "CUBLAS_STATUS_SUCCESS";
          case CUBLAS_STATUS_NOT_INITIALIZED:  return "CUBLAS_STATUS_NOT_INITIALIZED";
          case CUBLAS_STATUS_ALLOC_FAILED:     return "CUBLAS_STATUS_ALLOC_FAILED";
          case CUBLAS_STATUS_INVALID_VALUE:    return "CUBLAS_STATUS_INVALID_VALUE";
          case CUBLAS_STATUS_ARCH_MISMATCH:    return "CUBLAS_STATUS_ARCH_MISMATCH";
          case CUBLAS_STATUS_MAPPING_ERROR:    return "CUBLAS_STATUS_MAPPING_ERROR";
          case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED";
          case CUBLAS_STATUS_INTERNAL_ERROR:   return "CUBLAS_STATUS_INTERNAL_ERROR";
          case CUBLAS_STATUS_NOT_SUPPORTED:    return "CUBLAS_STATUS_NOT_SUPPORTED";
          default:                             return "unknown error";
      }
  #endif // CUDART_VERSION >= 12000
  }

  // cuBLAS API check against CUBLAS_STATUS_SUCCESS.
  #define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublas_get_error_str)
  #if !defined(GGML_USE_HIPBLAS)
  // Printable description of a CUDA driver API result code.
  // For an unrecognized code, cuGetErrorString() returns an error and leaves the output string NULL.
  // In that case (or on any failure) a fixed message is returned instead, so callers such as
  // ggml_cuda_error always receive a valid C string.
  static const char * cu_get_error_str(CUresult err) {
      const char * err_str = nullptr;
      if (cuGetErrorString(err, &err_str) != CUDA_SUCCESS || err_str == nullptr) {
          return "unknown CUDA driver error";
      }
      return err_str;
  }

  // CUDA driver API check against CUDA_SUCCESS.
  #define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
  #endif
  // Optimizer hint that `x` holds. On toolkits below CUDA 11.1 (CUDART_VERSION < 11100) it expands to
  // nothing, and `x` is not evaluated.
  #if CUDART_VERSION < 11100
  #define GGML_CUDA_ASSUME(x)
  #else
  #define GGML_CUDA_ASSUME(x) __builtin_assume(x)
  #endif // CUDART_VERSION < 11100
  // Element types produced by dequantization: half precision when GGML_CUDA_F16 is defined, float otherwise.
  #ifndef GGML_CUDA_F16
  using dfloat  = float;  // dequantize float
  using dfloat2 = float2;
  #else
  using dfloat  = half;   // dequantize float
  using dfloat2 = half2;
  #endif // GGML_CUDA_F16
  #if defined(GGML_USE_HIPBLAS)
  // HIP builds define __CUDA_ARCH__ to a high value, presumably so that `__CUDA_ARCH__ >= CC_*` gates in the
  // shared sources take their full-feature paths (TODO confirm intent).
  #define __CUDA_ARCH__ 1300

  // RDNA3 targets (gfx11xx)
  #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
      defined(__gfx1150__) || defined(__gfx1151__)
  #define RDNA3
  #endif

  // RDNA2 targets (gfx103x)
  #if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
      defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
  #define RDNA2
  #endif

  #ifndef __has_builtin
  #define __has_builtin(x) 0
  #endif

  // four packed 8-bit lanes, used to emulate CUDA's byte-wise SIMD intrinsics
  typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
  typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));

  // Emulates CUDA's __vsubss4: per-byte signed subtraction, each lane saturated to [-128, 127].
  static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
      const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
      const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
  #if __has_builtin(__builtin_elementwise_sub_sat)
      const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
      return reinterpret_cast<const int &>(c);
  #else
      int8x4_t c;
      int16_t tmp;
  #pragma unroll
      for (int i = 0; i < 4; i++) {
          tmp = va[i] - vb[i];
          if(tmp > std::numeric_limits<int8_t>::max()) tmp = std::numeric_limits<int8_t>::max();
          if(tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
          c[i] = tmp;
      }
      return reinterpret_cast<int &>(c);
  #endif // __has_builtin(__builtin_elementwise_sub_sat)
  }

  // Emulates CUDA's __vsub4: per-byte subtraction with wrap-around (modulo 256), no saturation.
  // It must not forward to __vsubss4. That version clamps instead of wrapping, which diverges from CUDA
  // whenever a byte difference leaves [-128, 127].
  static __device__ __forceinline__ int __vsub4(const int a, const int b) {
      // unsigned lanes so that the per-byte wrap-around is well defined
      const uint8x4_t va = reinterpret_cast<const uint8x4_t&>(a);
      const uint8x4_t vb = reinterpret_cast<const uint8x4_t&>(b);
      const uint8x4_t c = va - vb;
      return reinterpret_cast<const int &>(c);
  }

  // Emulates CUDA's __vcmpeq4: each result byte is 0xff where the bytes of a and b are equal, 0x00 otherwise.
  static __device__ __forceinline__ unsigned int __vcmpeq4(unsigned int a, unsigned int b) {
      const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
      const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
      unsigned int c;
      uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
  #pragma unroll
      for (int i = 0; i < 4; ++i) {
          vc[i] = va[i] == vb[i] ? 0xff : 0x00;
      }
      return c;
  }

  // Emulates CUDA's __dp4a: c + sum of the four signed 8-bit lane products of a and b.
  // Uses the hardware dot-product builtin where the target has one, hand-written SDWA assembly on
  // gfx900/gfx1010, and plain C++ otherwise.
  static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
  #if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__)
      c = __builtin_amdgcn_sdot4(a, b, c, false);
  #elif defined(RDNA3)
      c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
  #elif defined(__gfx1010__) || defined(__gfx900__)
      int tmp1;
      int tmp2;
      asm("\n \
          v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0 \n \
          v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1 \n \
          v_add3_u32 %0, %1, %2, %0 \n \
          v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_2 src1_sel:BYTE_2 \n \
          v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_3 src1_sel:BYTE_3 \n \
          v_add3_u32 %0, %1, %2, %0 \n \
          "
          : "+v"(c), "=&v"(tmp1), "=&v"(tmp2)
          : "v"(a), "v"(b)
      );
  #else
      const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
      const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
      c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3];
  #endif
      return c;
  }

  #if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
  // __shfl_xor() for half2 was added in ROCm 5.6: shuffle the 32-bit pattern as an int instead.
  static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int width) {
      typedef union half2_b32 {
          half2 val;
          int   b32;
      } half2_b32_t;
      half2_b32_t tmp;
      tmp.val = var;
      tmp.b32 = __shfl_xor(tmp.b32, laneMask, width);
      return tmp.val;
  }
  #endif // defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
  #endif // defined(GGML_USE_HIPBLAS)
  // Feature flags for `#if` tests in device code, defined to a plain 0 or 1.
  // Previously these expanded to expressions containing `defined(...)`. That is undefined behavior when
  // produced by macro expansion inside #if (clang: -Wexpansion-to-defined). Being unparenthesized, they
  // also mis-associated when combined with other operators, e.g. `#if !FP16_MMA_AVAILABLE` or
  // `#if FP16_AVAILABLE && X`.
  // __CUDA_ARCH__ and the GGML_USE_HIPBLAS / __HIP_PLATFORM_AMD__ macros are fixed for a compilation pass,
  // so evaluating the condition here gives the same result as evaluating it at each use.

  // FP16 arithmetic: AMD HIP, or CUDA device code for Pascal (CC_PASCAL) and newer.
  #if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
  #define FP16_AVAILABLE 1
  #else
  #define FP16_AVAILABLE 0
  #endif

  // FP16 matrix-multiply-accumulate: CUDA (not AMD HIP) device code for Volta (CC_VOLTA) and newer.
  #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
  #define FP16_MMA_AVAILABLE 1
  #else
  #define FP16_MMA_AVAILABLE 0
  #endif
  // Host-side check of whether a device with compute capability `cc` has fast FP16: Pascal (CC_PASCAL)
  // or newer, excluding cc 610 (presumably the Pascal consumer parts with low FP16 throughput; confirm).
  static bool fast_fp16_available(const int cc) {
      if (cc < CC_PASCAL) {
          return false;
      }
      return cc != 610;
  }
  // Host-side check for FP16 matrix-multiply-accumulate support: `cc` must lie below the AMD range
  // (CC_OFFSET_AMD) and be Volta (CC_VOLTA) or newer.
  static bool fp16_mma_available(const int cc) {
      const bool outside_amd_range = cc < CC_OFFSET_AMD;
      return outside_amd_range && cc >= CC_VOLTA;
  }
  // Fatal device-side error for code paths that have no implementation on the architecture being compiled.
  // Prints the call site and architecture (plus the compiled arch list on CUDA), then calls __trap().
  [[noreturn]]
  static __device__ void no_device_code(
      const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
  #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
      printf("%s:%d: ERROR: CUDA kernel %s has no device code compatible with CUDA arch %d. ggml-cuda.cu was compiled for: %s\n",
             file_name, line, function_name, arch, arch_list);
  #else
      GGML_UNUSED(arch_list);
      printf("%s:%d: ERROR: HIP kernel %s has no device code compatible with HIP arch %d.\n",
             file_name, line, function_name, arch);
  #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
      __trap();
      GGML_UNUSED(no_device_code); // suppress unused function warning
  }

  // Call-site wrapper: in device compilation (__CUDA_ARCH__ defined) it reports the current location.
  // In host compilation it expands to nothing.
  #if defined(__CUDA_ARCH__)
  #define NO_DEVICE_CODE no_device_code(__FILE__, __LINE__, __FUNCTION__, __CUDA_ARCH__, STRINGIZE(__CUDA_ARCH_LIST__))
  #else
  #define NO_DEVICE_CODE //GGML_ASSERT(false && "NO_DEVICE_CODE not valid in host code.")
  #endif // defined(__CUDA_ARCH__)
  // Sum of `x` over the 32 lanes of a full warp, via xor-shuffles with halving offsets (16, 8, 4, 2, 1).
  // Every lane ends up holding the total.
  static __device__ __forceinline__ float warp_reduce_sum(float x) {
  #pragma unroll
      for (int offset = 16; offset >= 1; offset /= 2) {
          const float partner = __shfl_xor_sync(0xffffffff, x, offset, 32);
          x += partner;
      }
      return x;
  }
  // Component-wise warp sum of a float2 (x and y reduced independently). Every lane receives both totals.
  static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
  #pragma unroll
      for (int offset = 16; offset >= 1; offset /= 2) {
          const float partner_x = __shfl_xor_sync(0xffffffff, a.x, offset, 32);
          const float partner_y = __shfl_xor_sync(0xffffffff, a.y, offset, 32);
          a.x += partner_x;
          a.y += partner_y;
      }
      return a;
  }
  // Component-wise warp sum of a half2. Requires FP16_AVAILABLE; otherwise it reports NO_DEVICE_CODE.
  static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
  #if FP16_AVAILABLE
  #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
      // AMD HIP: the low and high halves are accumulated separately rather than via __hadd2
  #pragma unroll
      for (int offset = 16; offset >= 1; offset /= 2) {
          const half2 partner = __shfl_xor_sync(0xffffffff, a, offset, 32);
          reinterpret_cast<half&>(a.x) += __low2half(partner);
          reinterpret_cast<half&>(a.y) += __high2half(partner);
      }
  #else
  #pragma unroll
      for (int offset = 16; offset >= 1; offset /= 2) {
          a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, 32));
      }
  #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
  #else
      NO_DEVICE_CODE;
  #endif // FP16_AVAILABLE
      return a;
  }
  // Maximum of `x` over the 32 lanes of a full warp (fmaxf at each xor-shuffle step). Every lane receives it.
  static __device__ __forceinline__ float warp_reduce_max(float x) {
  #pragma unroll
      for (int offset = 16; offset >= 1; offset /= 2) {
          const float partner = __shfl_xor_sync(0xffffffff, x, offset, 32);
          x = fmaxf(x, partner);
      }
      return x;
  }
  // Maximum of two halves. On CUDA runtimes older than CUDART_HMAX, where __hmax is not trusted, the
  // comparison is done in float and the result rounded back to half. Requires FP16_AVAILABLE.
  static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
  #if FP16_AVAILABLE
  #if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || CUDART_VERSION >= CUDART_HMAX
      return __hmax(a, b);
  #else
      const float fa = __half2float(a);
      const float fb = __half2float(b);
      return __float2half(fmaxf(fa, fb));
  #endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || CUDART_VERSION >= CUDART_HMAX
  #else
      GGML_UNUSED(b);
      NO_DEVICE_CODE;
      return a;
  #endif // FP16_AVAILABLE
  }
  // Component-wise maximum of two half2 values. CUDA only: AMD HIP builds report NO_DEVICE_CODE.
  // Runtimes older than CUDART_HMAX compare each half in float instead of calling __hmax2.
  static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
  #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
      GGML_UNUSED(a);
      GGML_UNUSED(b);
      NO_DEVICE_CODE;
  #elif CUDART_VERSION >= CUDART_HMAX
      return __hmax2(a, b);
  #else
      half2 ret;
      reinterpret_cast<half&>(ret.x) = __float2half(fmaxf( __low2float(a),  __low2float(b)));
      reinterpret_cast<half&>(ret.y) = __float2half(fmaxf(__high2float(a), __high2float(b)));
      return ret;
  #endif
  }
  // Component-wise maximum of a half2 across the 32 lanes of a warp; every lane receives the result.
  // Only implemented for CUDA (not AMD HIP) with __CUDA_ARCH__ >= CC_PASCAL. Otherwise NO_DEVICE_CODE is reported.
  static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
  #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
  #pragma unroll
      for (int mask = 16; mask > 0; mask >>= 1) {
          x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
      }
      return x;
  #else
      GGML_UNUSED(x);
      NO_DEVICE_CODE;
      // NO_DEVICE_CODE expands to nothing in the host compilation pass (__CUDA_ARCH__ undefined), so
      // return explicitly instead of letting this non-void function fall off its end.
      return x;
  #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
  }
  #if CUDART_VERSION < CUDART_HMASK
  // Fallback for runtimes without __hgt2_mask. Compares the low and high halves of a and b (in float) and
  // sets the corresponding 16 bits of the result to all ones where a > b.
  static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) {
      const bool low_greater  = float( __low2half(a)) > float( __low2half(b));
      const bool high_greater = float(__high2half(a)) > float(__high2half(b));
      return (low_greater ? 0x0000FFFFu : 0u) | (high_greater ? 0xFFFF0000u : 0u);
  }
  #endif // CUDART_VERSION < CUDART_HMASK
  // 16-entry table of signed 8-bit values, presumably the IQ4_NL codebook indexed by a 4-bit quant (confirm).
  // TODO: move to ggml-common.h
  static const __device__ int8_t kvalues_iq4nl[16] = {
      -127, -104, -83, -65, -49, -35, -22, -10,
         1,   13,  25,  38,  53,  69,  89, 113,
  };
  // Pointer to a dequantization helper. It takes the quantized data `vx`, a block index `ib` and an
  // intra-block index `iqs`, and writes its result pair through `v`.
  using dequantize_kernel_t = void (*)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
  // ALiBi slope for head `h`. Returns 1.0f when max_bias <= 0 (presumably "ALiBi disabled").
  // Heads with h < n_head_log2 get m0^(h + 1); the remaining heads get m1^(2*(h - n_head_log2) + 1).
  static __device__ __forceinline__ float get_alibi_slope(
      const float max_bias, const uint32_t h, const uint32_t n_head_log2, const float m0, const float m1
  ) {
      if (max_bias <= 0.0f) {
          return 1.0f;
      }

      float base;
      int   exph;
      if (h < n_head_log2) {
          base = m0;
          exph = h + 1;
      } else {
          base = m1;
          exph = 2*(h - n_head_log2) + 1;
      }
      return powf(base, exph);
  }
  432. //////////////////////
  // Properties of up to GGML_CUDA_MAX_DEVICES devices. Every member starts zero-initialized; previously
  // device_count was the only one left indeterminate in a default-constructed object.
  struct ggml_cuda_device_info {
      int device_count = 0; // presumably the number of populated entries in `devices`

      struct cuda_device_info {
          int     cc;              // compute capability (AMD values are offset by CC_OFFSET_AMD)
          int     nsm;             // number of streaming multiprocessors
          size_t  smpb;            // max. shared memory per block
          bool    vmm;             // virtual memory support
          size_t  vmm_granularity; // granularity of virtual memory
          size_t  total_vram;      // presumably in bytes; confirm against the code that fills it
      };

      cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {};

      std::array<float, GGML_CUDA_MAX_DEVICES> default_tensor_split = {};
  };

  // Declared here, defined in the backend sources.
  const ggml_cuda_device_info & ggml_cuda_info();
  void ggml_cuda_set_device(int device);
  int ggml_cuda_get_device();
  // Abstract allocator interface; instances are obtained per device through
  // ggml_backend_cuda_context::new_pool_for_device and used via ggml_cuda_pool_alloc.
  struct ggml_cuda_pool {
      virtual ~ggml_cuda_pool() = default;

      // Returns a buffer for a request of `size` bytes and writes the size actually reserved to *actual_size.
      virtual void * alloc(size_t size, size_t * actual_size) = 0;

      // Gives `ptr` back to the pool; ggml_cuda_pool_alloc passes the actual_size reported by alloc().
      virtual void free(void * ptr, size_t size) = 0;
  };
  // Scoped holder for one typed allocation from a ggml_cuda_pool. The buffer obtained by alloc() is
  // handed back to the same pool (with the size the pool reported) when the holder is destroyed.
  // Neither copyable nor movable.
  template<typename T>
  struct ggml_cuda_pool_alloc {
      ggml_cuda_pool * pool = nullptr;
      T * ptr = nullptr;
      size_t actual_size = 0; // bytes reported by pool->alloc(), passed back to pool->free()

      ggml_cuda_pool_alloc() = default;

      // Binds to `pool` without allocating yet.
      explicit ggml_cuda_pool_alloc(ggml_cuda_pool & pool) : pool(&pool) {
      }

      // Binds to `pool` and immediately allocates `size` elements.
      ggml_cuda_pool_alloc(ggml_cuda_pool & pool, size_t size) : ggml_cuda_pool_alloc(pool) {
          alloc(size);
      }

      ~ggml_cuda_pool_alloc() {
          if (ptr == nullptr) {
              return;
          }
          pool->free(ptr, actual_size);
      }

      // Allocates `size` elements of T (size * sizeof(T) bytes). Asserts that a pool is bound and that
      // this holder does not already own a buffer.
      T * alloc(size_t size) {
          GGML_ASSERT(pool != nullptr);
          GGML_ASSERT(ptr == nullptr);
          void * raw = pool->alloc(size * sizeof(T), &actual_size);
          ptr = static_cast<T *>(raw);
          return ptr;
      }

      // Binds to `pool`, then allocates as above.
      T * alloc(ggml_cuda_pool & pool, size_t size) {
          this->pool = &pool;
          return alloc(size);
      }

      T * get() {
          return ptr;
      }

      ggml_cuda_pool_alloc(const ggml_cuda_pool_alloc &) = delete;
      ggml_cuda_pool_alloc(ggml_cuda_pool_alloc &&) = delete;
      ggml_cuda_pool_alloc & operator=(const ggml_cuda_pool_alloc &) = delete;
      ggml_cuda_pool_alloc & operator=(ggml_cuda_pool_alloc &&) = delete;
  };
  // backend interface

  // Per-tensor GPU data attached to tensors that may be split across devices.
  struct ggml_tensor_extra_gpu {
      // one data pointer for each device (split tensors)
      void * data_device[GGML_CUDA_MAX_DEVICES];
      // events for synchronizing multiple GPUs, one per (device, stream) pair
      cudaEvent_t events[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS];
  };
  // CUDA graph support is compiled in only when GGML_CUDA_USE_GRAPHS is requested and the runtime is CUDA 12.0+.
  #if defined(GGML_CUDA_USE_GRAPHS) && CUDART_VERSION >= 12000
  #define USE_CUDA_GRAPH
  #endif

  // Recorded properties of one ggml graph node: its address, op, shape (ne), strides (nb) and source addresses.
  struct ggml_graph_node_properties {
      void *  node_address;
      ggml_op node_op;
      int64_t ne[GGML_MAX_DIMS];
      size_t  nb[GGML_MAX_DIMS];
      void *  src_address[GGML_MAX_SRC];
  };
  // Cached CUDA graph state: the captured graph, its executable instance, per-node data and the flags
  // that disable graph use. Empty unless USE_CUDA_GRAPH is defined.
  struct ggml_cuda_graph {
  #ifdef USE_CUDA_GRAPH
      ggml_cuda_graph() = default;

      // `graph` and `instance` are owned handles released by the destructor. An implicit copy would
      // release them twice, so copying is disabled.
      ggml_cuda_graph(const ggml_cuda_graph &) = delete;
      ggml_cuda_graph & operator=(const ggml_cuda_graph &) = delete;

      ~ggml_cuda_graph() {
          if (instance != nullptr) {
              CUDA_CHECK(cudaGraphExecDestroy(instance));
          }
          if (graph != nullptr) {
              CUDA_CHECK(cudaGraphDestroy(graph));
          }
      }

      cudaGraph_t graph = nullptr;
      cudaGraphExec_t instance = nullptr;
      size_t num_nodes = 0;
      std::vector<cudaGraphNode_t> nodes;
      std::vector<cudaKernelNodeParams> params;
      bool disable_due_to_gpu_arch = false;
      bool disable_due_to_too_many_updates = false;
      bool disable_due_to_failed_graph_capture = false;
      int number_consecutive_updates = 0;
      std::vector<ggml_graph_node_properties> ggml_graph_properties;
      std::vector<char **> updated_kernel_arg;
  #endif
  };
  // Per-backend CUDA state. Streams, cuBLAS handles and memory pools are created lazily for each device
  // on first use; the optional CUDA graph is also held here. The destructor releases the copy event,
  // every created stream and every cuBLAS handle, and the pools and graph are freed by their unique_ptrs.
  struct ggml_backend_cuda_context {
      int device;
      std::string name;
      cudaEvent_t copy_event = nullptr;

      cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };
      cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};

      std::unique_ptr<ggml_cuda_graph> cuda_graph;

      explicit ggml_backend_cuda_context(int device) :
          device(device),
          name(GGML_CUDA_NAME + std::to_string(device)) {
      }

      // Owns raw CUDA/cuBLAS handles that the destructor releases. Copying is already implicitly
      // disabled by the unique_ptr members; it is spelled out here for clarity.
      ggml_backend_cuda_context(const ggml_backend_cuda_context &) = delete;
      ggml_backend_cuda_context & operator=(const ggml_backend_cuda_context &) = delete;

      ~ggml_backend_cuda_context() {
          if (copy_event != nullptr) {
              CUDA_CHECK(cudaEventDestroy(copy_event));
          }
          for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
              for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
                  if (streams[i][j] != nullptr) {
                      CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
                  }
              }
              if (cublas_handles[i] != nullptr) {
                  CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
              }
          }
      }

      // Stream `stream` of `device`, created as a non-blocking stream on first use.
      // Both indices are checked because they address fixed-size arrays.
      cudaStream_t stream(int device, int stream) {
          GGML_ASSERT(device >= 0 && device < GGML_CUDA_MAX_DEVICES);
          GGML_ASSERT(stream >= 0 && stream < GGML_CUDA_MAX_STREAMS);
          if (streams[device][stream] == nullptr) {
              ggml_cuda_set_device(device);
              CUDA_CHECK(cudaStreamCreateWithFlags(&streams[device][stream], cudaStreamNonBlocking));
          }
          return streams[device][stream];
      }

      // Stream 0 of this context's own device.
      cudaStream_t stream() {
          return stream(device, 0);
      }

      // cuBLAS handle for `device`. It is created on first use with CUBLAS_TF32_TENSOR_OP_MATH enabled.
      cublasHandle_t cublas_handle(int device) {
          GGML_ASSERT(device >= 0 && device < GGML_CUDA_MAX_DEVICES);
          if (cublas_handles[device] == nullptr) {
              ggml_cuda_set_device(device);
              CUBLAS_CHECK(cublasCreate(&cublas_handles[device]));
              CUBLAS_CHECK(cublasSetMathMode(cublas_handles[device], CUBLAS_TF32_TENSOR_OP_MATH));
          }
          return cublas_handles[device];
      }

      cublasHandle_t cublas_handle() {
          return cublas_handle(device);
      }

      // pool
      std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES];

      // Defined elsewhere; creates the pool implementation for `device`.
      static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device);

      // Memory pool for `device`, created via new_pool_for_device on first use.
      ggml_cuda_pool & pool(int device) {
          GGML_ASSERT(device >= 0 && device < GGML_CUDA_MAX_DEVICES);
          if (pools[device] == nullptr) {
              pools[device] = new_pool_for_device(device);
          }
          return *pools[device];
      }

      ggml_cuda_pool & pool() {
          return pool(device);
      }
  };
  584. return pool(device);
  585. }
  586. };