// musa.h (7.2 KB)
/**
 * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
 *
 * MIT License
 *
 * Copyright (c) 2023-2024 The ggml authors
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
  26. #pragma once
  27. #include <musa_runtime.h>
  28. #include <musa.h>
  29. #include <mublas.h>
  30. #include <musa_bf16.h>
  31. #include <musa_fp16.h>
  // cuBLAS -> muBLAS name mappings.
  //
  // Each macro renames a cuBLAS identifier, or a CUDA data-type enumerator used
  // with cuBLAS, to the muBLAS/MUSA identifier it stands for in this build.
  // Some macros expand to other macros (e.g. CUBLAS_COMPUTE_16F -> CUDA_R_16F
  // -> MUSA_R_16F). The preprocessor rescans each replacement at the point of
  // use, so definition order here does not matter as long as every mapping is
  // defined before it is used.
  //
  // The 16F/32F compute enumerators resolve to data-type enumerators, and
  // cublasComputeType_t itself resolves to cudaDataType_t (see below), so
  // compute types are expressed as data types in this mapping.
  #define CUBLAS_COMPUTE_16F CUDA_R_16F
  #define CUBLAS_COMPUTE_32F CUDA_R_32F
  #define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F
  // Both GEMM algorithm selectors map to MUBLAS_GEMM_DEFAULT.
  #define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT
  #define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT
  #define CUBLAS_OP_N MUBLAS_OP_N
  #define CUBLAS_OP_T MUBLAS_OP_T
  #define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
  // A TF32 tensor-op math request becomes muBLAS's default math mode.
  #define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT
  #define CUDA_R_16F MUSA_R_16F
  #define CUDA_R_32F MUSA_R_32F
  #define cublasComputeType_t cudaDataType_t
  #define cublasCreate mublasCreate
  #define cublasDestroy mublasDestroy
  #define cublasGemmEx mublasGemmEx
  #define cublasGemmBatchedEx mublasGemmBatchedEx
  #define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx
  #define cublasHandle_t mublasHandle_t
  #define cublasSetMathMode mublasSetMathMode
  #define cublasSetStream mublasSetStream
  #define cublasSgemm mublasSgemm
  #define cublasStatus_t mublasStatus_t
  #define cublasOperation_t mublasOperation_t
  // muBLAS spells its status-to-string helper differently.
  #define cublasGetStatusString mublasStatus_to_string
  // CUDA runtime -> MUSA runtime name mappings: types, error codes, device and
  // peer-access control, events, memory, occupancy, and streams. Each cuda*
  // name maps to the musa* name with the same suffix.
  #define cudaDataType_t musaDataType_t
  #define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
  #define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
  #define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess
  #define cudaDeviceProp musaDeviceProp
  #define cudaDeviceSynchronize musaDeviceSynchronize
  #define cudaError_t musaError_t
  #define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled
  #define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled
  #define cudaEventCreateWithFlags musaEventCreateWithFlags
  #define cudaEventDisableTiming musaEventDisableTiming
  #define cudaEventRecord musaEventRecord
  #define cudaEventSynchronize musaEventSynchronize
  #define cudaEvent_t musaEvent_t
  #define cudaEventDestroy musaEventDestroy
  #define cudaFree musaFree
  #define cudaFreeHost musaFreeHost
  #define cudaGetDevice musaGetDevice
  #define cudaGetDeviceCount musaGetDeviceCount
  #define cudaGetDeviceProperties musaGetDeviceProperties
  #define cudaGetErrorString musaGetErrorString
  #define cudaGetLastError musaGetLastError
  #define cudaHostRegister musaHostRegister
  #define cudaHostRegisterPortable musaHostRegisterPortable
  #define cudaHostRegisterReadOnly musaHostRegisterReadOnly
  #define cudaHostUnregister musaHostUnregister
  #define cudaLaunchHostFunc musaLaunchHostFunc
  #define cudaMalloc musaMalloc
  #define cudaMallocHost musaMallocHost
  #define cudaMallocManaged musaMallocManaged
  #define cudaMemcpy musaMemcpy
  #define cudaMemcpyAsync musaMemcpyAsync
  #define cudaMemcpyPeerAsync musaMemcpyPeerAsync
  #define cudaMemcpy2DAsync musaMemcpy2DAsync
  #define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice
  #define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost
  #define cudaMemcpyHostToDevice musaMemcpyHostToDevice
  #define cudaMemcpyKind musaMemcpyKind
  #define cudaMemset musaMemset
  #define cudaMemsetAsync musaMemsetAsync
  #define cudaMemGetInfo musaMemGetInfo
  #define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize
  #define cudaSetDevice musaSetDevice
  #define cudaStreamCreateWithFlags musaStreamCreateWithFlags
  #define cudaStreamDestroy musaStreamDestroy
  #define cudaStreamFireAndForget musaStreamFireAndForget
  #define cudaStreamNonBlocking musaStreamNonBlocking
  #define cudaStreamPerThread musaStreamPerThread
  #define cudaStreamSynchronize musaStreamSynchronize
  #define cudaStreamWaitEvent musaStreamWaitEvent
  #define cudaStream_t musaStream_t
  #define cudaSuccess musaSuccess
  // Additional mappings for MUSA virtual memory pool: CUDA driver-API
  // (cu*/CU*) virtual-memory names -> MUSA (mu*/MU*) names.
  // The device-attribute name differs on the MUSA side: VIRTUAL_ADDRESS_MANAGEMENT
  // rather than VIRTUAL_MEMORY_MANAGEMENT.
  #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
  #define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE
  #define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED
  #define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED
  #define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE
  #define CUdevice MUdevice
  #define CUdeviceptr MUdeviceptr
  #define CUmemAccessDesc MUmemAccessDesc
  #define CUmemAllocationProp MUmemAllocationProp
  #define CUmemGenericAllocationHandle MUmemGenericAllocationHandle
  #define cuDeviceGet muDeviceGet
  #define cuDeviceGetAttribute muDeviceGetAttribute
  #define cuMemAddressFree muMemAddressFree
  #define cuMemAddressReserve muMemAddressReserve
  #define cuMemCreate muMemCreate
  #define cuMemGetAllocationGranularity muMemGetAllocationGranularity
  #define cuMemMap muMemMap
  #define cuMemRelease muMemRelease
  #define cuMemSetAccess muMemSetAccess
  #define cuMemUnmap muMemUnmap
  // Kernel-attribute and 3D peer-copy mappings. These are runtime-API names and
  // are not part of the virtual memory pool API above.
  #define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize
  #define cudaFuncSetAttribute musaFuncSetAttribute
  #define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms
  #define make_cudaExtent make_musaExtent
  #define make_cudaPitchedPtr make_musaPitchedPtr
  // Additional mappings for MUSA graphs: CUDA graph, kernel-node and
  // stream-capture names -> musa* names, plus the driver-API result type
  // (CUresult / CUDA_SUCCESS / cuGetErrorString) -> MU*/mu* names.
  #define CUDA_SUCCESS MUSA_SUCCESS
  #define CUresult MUresult
  #define cuGetErrorString muGetErrorString
  #define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure
  #define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction
  #define cudaGraphDestroy musaGraphDestroy
  #define cudaGraphExecDestroy musaGraphExecDestroy
  #define cudaGraphExec_t musaGraphExec_t
  #define cudaGraphExecUpdate musaGraphExecUpdate
  // NOTE(review): in CUDA 12 cudaGraphExecUpdateResultInfo is a struct, while the
  // MUSA name it maps to reads like the older result-enum name; confirm that
  // callers use the matching cudaGraphExecUpdate signature when built for MUSA.
  #define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
  #define cudaGraphGetNodes musaGraphGetNodes
  #define cudaGraphInstantiate musaGraphInstantiate
  #define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
  #define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams
  #define cudaGraphLaunch musaGraphLaunch
  #define cudaGraphNodeGetType musaGraphNodeGetType
  #define cudaGraphNode_t musaGraphNode_t
  #define cudaGraphNodeType musaGraphNodeType
  #define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel
  #define cudaGraph_t musaGraph_t
  #define cudaKernelNodeParams musaKernelNodeParams
  // NOTE(review): cudaStreamEndCapture is mapped here, but no
  // cudaStreamBeginCapture mapping appears in this section; confirm whether
  // callers need one.
  #define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
  #define cudaStreamEndCapture musaStreamEndCapture
  158. typedef mt_bfloat16 nv_bfloat16;