/**
 * llama.cpp - commit ba1cb19cdd0d92e012e0f6e009e0620f854b6afd - do not edit this file
 *
 * MIT License
 *
 * Copyright (c) 2023-2024 The ggml authors
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
  26. #pragma once
  27. #include <hip/hip_runtime.h>
  28. #include <hipblas/hipblas.h>
  29. #include <hip/hip_fp16.h>
  30. #ifdef __HIP_PLATFORM_AMD__
  31. // for rocblas_initialize()
  32. #include "rocblas/rocblas.h"
  33. #endif // __HIP_PLATFORM_AMD__
  // ---------------------------------------------------------------------------
  // CUDA -> HIP name mapping. Each CUDA runtime / cuBLAS identifier below is
  // redirected to its HIP / hipBLAS counterpart so CUDA-spelled code compiles
  // against HIP. Each name is defined exactly once.
  // ---------------------------------------------------------------------------

  // cuBLAS compute and data types. CUBLAS_COMPUTE_32F_FAST_16F is mapped to plain
  // 32-bit compute.
  #define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
  #define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
  #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
  #define CUDA_R_16F HIPBLAS_R_16F
  #define CUDA_R_32F HIPBLAS_R_32F
  #define cublasComputeType_t hipblasDatatype_t // deprecated, new hipblasComputeType_t not in 5.6
  #define cudaDataType_t hipblasDatatype_t      // deprecated, new hipblasDatatype not in 5.6

  // GEMM algorithm selection: the tensor-op variant maps to the default algorithm.
  #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
  #define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
  #define CUBLAS_OP_N HIPBLAS_OP_N
  #define CUBLAS_OP_T HIPBLAS_OP_T

  // Math-mode selection is a no-op: the TF32 flag is 0, and "setting" a math mode
  // expands to the success status without touching the handle.
  #define CUBLAS_TF32_TENSOR_OP_MATH 0
  #define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS

  // cuBLAS status codes.
  #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
  #define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
  #define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
  #define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE
  #define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH
  #define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR
  #define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
  #define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
  #define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED

  // cuBLAS handles, streams and GEMM entry points.
  #define cublasCreate hipblasCreate
  #define cublasDestroy hipblasDestroy
  #define cublasGemmEx hipblasGemmEx
  #define cublasGemmBatchedEx hipblasGemmBatchedEx
  #define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
  #define cublasHandle_t hipblasHandle_t
  #define cublasSetStream hipblasSetStream
  #define cublasSgemm hipblasSgemm
  #define cublasStatus_t hipblasStatus_t
  #define cublasOperation_t hipblasOperation_t

  // Warp shuffle: the participation mask is dropped; the call is forwarded as
  // __shfl_xor(var, laneMask, width).
  #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)

  // Devices and peer access.
  #define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
  #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
  #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
  #define cudaDeviceProp hipDeviceProp_t
  #define cudaDeviceSynchronize hipDeviceSynchronize
  #define cudaGetDevice hipGetDevice
  #define cudaGetDeviceCount hipGetDeviceCount
  #define cudaGetDeviceProperties hipGetDeviceProperties
  #define cudaSetDevice hipSetDevice
  #define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize

  // Errors.
  #define cudaError_t hipError_t
  #define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
  #define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
  #define cudaGetErrorString hipGetErrorString
  #define cudaGetLastError hipGetLastError
  #define cudaSuccess hipSuccess

  // Events.
  #define cudaEventCreateWithFlags hipEventCreateWithFlags
  #define cudaEventDisableTiming hipEventDisableTiming
  #define cudaEventRecord hipEventRecord
  #define cudaEventSynchronize hipEventSynchronize
  #define cudaEvent_t hipEvent_t
  #define cudaEventDestroy hipEventDestroy

  // Memory allocation and host registration. cudaMallocHost gains HIP's explicit
  // flags argument (hipHostMallocDefault).
  #define cudaFree hipFree
  #define cudaFreeHost hipHostFree
  #define cudaHostRegister hipHostRegister
  #define cudaHostRegisterPortable hipHostRegisterPortable
  #define cudaHostRegisterReadOnly hipHostRegisterReadOnly
  #define cudaHostUnregister hipHostUnregister
  #define cudaMalloc hipMalloc
  #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
  #define cudaMemGetInfo hipMemGetInfo

  // Copies and fills.
  #define cudaMemcpy hipMemcpy
  #define cudaMemcpyAsync hipMemcpyAsync
  #define cudaMemcpyPeerAsync hipMemcpyPeerAsync
  #define cudaMemcpy2DAsync hipMemcpy2DAsync
  #define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
  #define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
  #define cudaMemcpyHostToDevice hipMemcpyHostToDevice
  #define cudaMemcpyKind hipMemcpyKind
  #define cudaMemset hipMemset
  #define cudaMemsetAsync hipMemsetAsync

  // Streams and host callbacks.
  #define cudaLaunchHostFunc hipLaunchHostFunc
  #define cudaStreamCreateWithFlags hipStreamCreateWithFlags
  #define cudaStreamDestroy hipStreamDestroy
  #define cudaStreamFireAndForget hipStreamFireAndForget
  #define cudaStreamNonBlocking hipStreamNonBlocking
  #define cudaStreamPerThread hipStreamPerThread
  #define cudaStreamSynchronize hipStreamSynchronize
  #define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
  #define cudaStream_t hipStream_t

  // __trap() aborts; __builtin_unreachable() tells the compiler control never
  // continues past the abort.
  #define __trap() do { abort(); __builtin_unreachable(); } while(0)

  // NOTE(review): presumably set so that code guarded by __CUDA_ARCH__ checks takes
  // its device-side paths under HIP — confirm against the users of __CUDA_ARCH__.
  #define __CUDA_ARCH__ 1300
  // AMD GPU family detection. Each group below tests the __gfxNNNN__ target
  // macros and, if any of them is defined, defines a family macro. The target
  // lists are disjoint, so a single __gfxNNNN__ target selects at most one family.

  // GCN: gfx803, gfx900, gfx906.
  #if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)
  #define GCN
  #endif

  // CDNA: gfx908, gfx90a, gfx942.
  #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)
  #define CDNA
  #endif

  // RDNA1: gfx1010, gfx1012.
  #if defined(__gfx1010__) || defined(__gfx1012__)
  #define RDNA1
  #endif

  // RDNA2: gfx1030 through gfx1037.
  #if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
      defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
  #define RDNA2
  #endif

  // RDNA3: gfx1100 through gfx1103, plus gfx1150 and gfx1151.
  #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
      defined(__gfx1150__) || defined(__gfx1151__)
  #define RDNA3
  #endif
  // If the compiler does not provide __has_builtin, define it to report every
  // builtin as unavailable. This keeps `#if __has_builtin(...)` checks well-formed.
  #if !defined(__has_builtin)
  #define __has_builtin(x) 0
  #endif

  // Four 8-bit lanes packed into one 32-bit word (clang ext_vector_type).
  // The per-byte helpers reinterpret 32-bit ints as these vectors.
  typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
  typedef int8_t  int8x4_t  __attribute__((ext_vector_type(4)));
  142. static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
  143. const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
  144. const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
  145. #if __has_builtin(__builtin_elementwise_sub_sat)
  146. const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
  147. return reinterpret_cast<const int &>(c);
  148. #else
  149. int8x4_t c;
  150. int16_t tmp;
  151. #pragma unroll
  152. for (int i = 0; i < 4; i++) {
  153. tmp = va[i] - vb[i];
  154. if(tmp > std::numeric_limits<int8_t>::max()) tmp = std::numeric_limits<int8_t>::max();
  155. if(tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
  156. c[i] = tmp;
  157. }
  158. return reinterpret_cast<int &>(c);
  159. #endif // __has_builtin(__builtin_elementwise_sub_sat)
  160. }
  161. static __device__ __forceinline__ int __vsub4(const int a, const int b) {
  162. return __vsubss4(a, b);
  163. }
  164. static __device__ __forceinline__ unsigned int __vcmpeq4(unsigned int a, unsigned int b) {
  165. const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
  166. const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
  167. unsigned int c;
  168. uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
  169. #pragma unroll
  170. for (int i = 0; i < 4; ++i) {
  171. vc[i] = va[i] == vb[i] ? 0xff : 0x00;
  172. }
  173. return c;
  174. }
  175. static __device__ __forceinline__ unsigned int __vcmpne4(unsigned int a, unsigned int b) {
  176. const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
  177. const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
  178. unsigned int c;
  179. uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
  180. #pragma unroll
  181. for (int i = 0; i < 4; ++i) {
  182. vc[i] = va[i] == vb[i] ? 0x00 : 0xff;
  183. }
  184. return c;
  185. }
  #if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
  // ROCm releases before 5.6 have no half2 overload of __shfl_xor().
  // This emulates it: the half2 is viewed as a 32-bit int, the int overload shuffles
  // that word, and the result is viewed back as a half2.
  static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int width) {
      union half2_bits {
          half2 h2;
          int   word;
      };
      half2_bits bits;
      bits.h2   = var;
      bits.word = __shfl_xor(bits.word, laneMask, width);
      return bits.h2;
  }
  #endif // defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000