// hip.h — vendored CUDA-to-HIP compatibility shim (8.6 KB in the upstream listing).
  1. /**
  2. * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
  3. *
  4. * MIT License
  5. *
  6. * Copyright (c) 2023-2024 The ggml authors
  7. *
  8. * Permission is hereby granted, free of charge, to any person obtaining a copy
  9. * of this software and associated documentation files (the "Software"), to deal
  10. * in the Software without restriction, including without limitation the rights
  11. * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  12. * copies of the Software, and to permit persons to whom the Software is
  13. * furnished to do so, subject to the following conditions:
  14. *
  15. * The above copyright notice and this permission notice shall be included in all
  16. * copies or substantial portions of the Software.
  17. *
  18. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  19. * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  20. * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  21. * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  22. * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  23. * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  24. * SOFTWARE.
  25. */
  26. #pragma once
  27. #include <hip/hip_runtime.h>
  28. #include <hipblas/hipblas.h>
  29. #include <hip/hip_fp16.h>
  30. #include <hip/hip_bfloat16.h>
  31. #ifdef __HIP_PLATFORM_AMD__
  32. // for rocblas_initialize()
  33. #include "rocblas/rocblas.h"
  34. #endif // __HIP_PLATFORM_AMD__
  35. #define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
  36. #define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
  37. #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
  38. #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
  39. #define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
  40. #define CUBLAS_OP_N HIPBLAS_OP_N
  41. #define CUBLAS_OP_T HIPBLAS_OP_T
  42. #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
  43. #define CUBLAS_TF32_TENSOR_OP_MATH 0
  44. #define CUDA_R_16F HIPBLAS_R_16F
  45. #define CUDA_R_32F HIPBLAS_R_32F
  46. #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
  47. #define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
  48. #define cublasCreate hipblasCreate
  49. #define cublasDestroy hipblasDestroy
  50. #define cublasGemmEx hipblasGemmEx
  51. #define cublasGemmBatchedEx hipblasGemmBatchedEx
  52. #define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
  53. #define cublasHandle_t hipblasHandle_t
  54. #define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
  55. #define cublasSetStream hipblasSetStream
  56. #define cublasSgemm hipblasSgemm
  57. #define cublasStatus_t hipblasStatus_t
  58. #define cublasOperation_t hipblasOperation_t
  59. #define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6
  60. #define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
  61. #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
  62. #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
  63. #define cudaDeviceProp hipDeviceProp_t
  64. #define cudaDeviceSynchronize hipDeviceSynchronize
  65. #define cudaError_t hipError_t
  66. #define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
  67. #define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
  68. #define cudaEventCreateWithFlags hipEventCreateWithFlags
  69. #define cudaEventDisableTiming hipEventDisableTiming
  70. #define cudaEventRecord hipEventRecord
  71. #define cudaEventSynchronize hipEventSynchronize
  72. #define cudaEvent_t hipEvent_t
  73. #define cudaEventDestroy hipEventDestroy
  74. #define cudaFree hipFree
  75. #define cudaFreeHost hipHostFree
  76. #define cudaGetDevice hipGetDevice
  77. #define cudaGetDeviceCount hipGetDeviceCount
  78. #define cudaGetDeviceProperties hipGetDeviceProperties
  79. #define cudaGetErrorString hipGetErrorString
  80. #define cudaGetLastError hipGetLastError
  81. #define cudaHostRegister hipHostRegister
  82. #define cudaHostRegisterPortable hipHostRegisterPortable
  83. #define cudaHostRegisterReadOnly hipHostRegisterReadOnly
  84. #define cudaHostUnregister hipHostUnregister
  85. #define cudaLaunchHostFunc hipLaunchHostFunc
  86. #define cudaMalloc hipMalloc
  87. #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
  88. #define cudaMemcpy hipMemcpy
  89. #define cudaMemcpyAsync hipMemcpyAsync
  90. #define cudaMemcpyPeerAsync hipMemcpyPeerAsync
  91. #define cudaMemcpy2DAsync hipMemcpy2DAsync
  92. #define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
  93. #define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
  94. #define cudaMemcpyHostToDevice hipMemcpyHostToDevice
  95. #define cudaMemcpyKind hipMemcpyKind
  96. #define cudaMemset hipMemset
  97. #define cudaMemsetAsync hipMemsetAsync
  98. #define cudaMemGetInfo hipMemGetInfo
  99. #define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
  100. #define cudaSetDevice hipSetDevice
  101. #define cudaStreamCreateWithFlags hipStreamCreateWithFlags
  102. #define cudaStreamDestroy hipStreamDestroy
  103. #define cudaStreamFireAndForget hipStreamFireAndForget
  104. #define cudaStreamNonBlocking hipStreamNonBlocking
  105. #define cudaStreamPerThread hipStreamPerThread
  106. #define cudaStreamSynchronize hipStreamSynchronize
  107. #define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
  108. #define cudaStream_t hipStream_t
  109. #define cudaSuccess hipSuccess
  110. #define __trap() do { abort(); __builtin_unreachable(); } while(0)
  111. #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
  112. #define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
  113. #define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
  114. #define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE
  115. #define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH
  116. #define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR
  117. #define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
  118. #define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
  119. #define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
  120. #define __CUDA_ARCH__ 1300
  121. #if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)
  122. #define GCN
  123. #endif
  124. #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)
  125. #define CDNA
  126. #endif
  127. #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
  128. defined(__gfx1150__) || defined(__gfx1151__)
  129. #define RDNA3
  130. #endif
  131. #if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
  132. defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
  133. #define RDNA2
  134. #endif
  135. #if defined(__gfx1010__) || defined(__gfx1012__)
  136. #define RDNA1
  137. #endif
  138. #ifndef __has_builtin
  139. #define __has_builtin(x) 0
  140. #endif
  141. typedef hip_bfloat16 nv_bfloat16;
  142. typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
  143. typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
  144. static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
  145. const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
  146. const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
  147. #if __has_builtin(__builtin_elementwise_sub_sat)
  148. const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
  149. return reinterpret_cast<const int &>(c);
  150. #else
  151. int8x4_t c;
  152. int16_t tmp;
  153. #pragma unroll
  154. for (int i = 0; i < 4; i++) {
  155. tmp = va[i] - vb[i];
  156. if(tmp > std::numeric_limits<int8_t>::max()) tmp = std::numeric_limits<int8_t>::max();
  157. if(tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
  158. c[i] = tmp;
  159. }
  160. return reinterpret_cast<int &>(c);
  161. #endif // __has_builtin(__builtin_elementwise_sub_sat)
  162. }
  163. static __device__ __forceinline__ int __vsub4(const int a, const int b) {
  164. return __vsubss4(a, b);
  165. }
  166. static __device__ __forceinline__ unsigned int __vcmpeq4(unsigned int a, unsigned int b) {
  167. const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
  168. const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
  169. unsigned int c;
  170. uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
  171. #pragma unroll
  172. for (int i = 0; i < 4; ++i) {
  173. vc[i] = va[i] == vb[i] ? 0xff : 0x00;
  174. }
  175. return c;
  176. }
  177. static __device__ __forceinline__ unsigned int __vcmpne4(unsigned int a, unsigned int b) {
  178. const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
  179. const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
  180. unsigned int c;
  181. uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
  182. #pragma unroll
  183. for (int i = 0; i < 4; ++i) {
  184. vc[i] = va[i] == vb[i] ? 0x00 : 0xff;
  185. }
  186. return c;
  187. }
  188. #if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
  189. // __shfl_xor() for half2 was added in ROCm 5.6
  190. static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int width) {
  191. typedef union half2_b32 {
  192. half2 val;
  193. int b32;
  194. } half2_b32_t;
  195. half2_b32_t tmp;
  196. tmp.val = var;
  197. tmp.b32 = __shfl_xor(tmp.b32, laneMask, width);
  198. return tmp.val;
  199. }
  200. #endif // defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000