|
@@ -1,5 +1,6 @@
|
|
|
#pragma once
|
|
|
|
|
|
+#define HIP_ENABLE_WARP_SYNC_BUILTINS 1 // defined ahead of the HIP includes below; presumably opts in to HIP's *_sync warp builtins - TODO confirm against the ROCm docs
|
|
|
#include <hip/hip_runtime.h>
|
|
|
#include <hipblas/hipblas.h>
|
|
|
#include <hip/hip_fp16.h>
|
|
@@ -8,6 +9,7 @@
|
|
|
// for rocblas_initialize()
|
|
|
#include "rocblas/rocblas.h"
|
|
|
#endif // __HIP_PLATFORM_AMD__
|
|
|
+
|
|
|
#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
|
|
|
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
|
|
|
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
|
|
@@ -19,6 +21,13 @@
|
|
|
#define CUBLAS_TF32_TENSOR_OP_MATH 0
|
|
|
#define CUDA_R_16F HIPBLAS_R_16F
|
|
|
#define CUDA_R_32F HIPBLAS_R_32F
|
|
|
+#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported
|
|
|
+#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended
|
|
|
+#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned
|
|
|
+#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
|
|
|
+#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
|
|
|
+#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
|
|
|
+#define __shfl_sync(sync_mask, value, src_lane, span) __shfl(value, src_lane, span) // sync_mask is accepted but dropped; forwards to the mask-less __shfl
|
|
|
#define __shfl_xor_sync(sync_mask, value, xor_mask, span) __shfl_xor(value, xor_mask, span) // sync_mask is accepted but dropped; forwards to the mask-less __shfl_xor
|
|
|
#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
|
|
|
#define cublasCreate hipblasCreate
|
|
@@ -74,6 +83,21 @@
|
|
|
#define cudaMemGetInfo hipMemGetInfo
|
|
|
#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
|
|
|
#define cudaSetDevice hipSetDevice
|
|
|
+#define cuDeviceGet hipDeviceGet
|
|
|
+#define CUdevice hipDevice_t
|
|
|
+#define CUdeviceptr hipDeviceptr_t
|
|
|
+#define cuMemUnmap hipMemUnmap
|
|
|
+#define CUmemAccessDesc hipMemAccessDesc
|
|
|
+#define cuMemAddressFree hipMemAddressFree
|
|
|
+#define cuMemRelease hipMemRelease
|
|
|
+#define CUmemGenericAllocationHandle hipMemGenericAllocationHandle_t
|
|
|
+#define cuMemCreate hipMemCreate
|
|
|
+#define cuMemAddressReserve hipMemAddressReserve
|
|
|
+#define cuMemMap hipMemMap
|
|
|
+#define cuMemSetAccess hipMemSetAccess
|
|
|
+#define cuMemGetAllocationGranularity hipMemGetAllocationGranularity
|
|
|
+#define CUmemAllocationProp hipMemAllocationProp
|
|
|
+#define cuDeviceGetAttribute hipDeviceGetAttribute
|
|
|
#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
|
|
|
#define cudaStreamDestroy hipStreamDestroy
|
|
|
#define cudaStreamFireAndForget hipStreamFireAndForget
|
|
@@ -81,6 +105,27 @@
|
|
|
#define cudaStreamPerThread hipStreamPerThread
|
|
|
#define cudaStreamSynchronize hipStreamSynchronize
|
|
|
#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
|
|
|
+#define cudaGraphExec_t hipGraphExec_t
|
|
|
+#define cudaGraphNode_t hipGraphNode_t
|
|
|
+#define cudaKernelNodeParams hipKernelNodeParams
|
|
|
+#define cudaGraphExecDestroy hipGraphExecDestroy
|
|
|
+#define cudaGraphLaunch hipGraphLaunch
|
|
|
+#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
|
|
|
+#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult
|
|
|
+#define cudaGraphNodeType hipGraphNodeType
|
|
|
+#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
|
|
|
+#define cudaGraphInstantiate hipGraphInstantiate
|
|
|
+#define cudaStreamEndCapture hipStreamEndCapture
|
|
|
+#define cudaGraphDestroy hipGraphDestroy
|
|
|
+#define cudaGraphKernelNodeSetParams hipGraphKernelNodeSetParams
|
|
|
+#define cudaErrorInvalidDeviceFunction hipErrorInvalidDeviceFunction
|
|
|
+#define cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams
|
|
|
+#define cudaGraphNodeGetType hipGraphNodeGetType
|
|
|
+#define cudaGraphGetNodes hipGraphGetNodes
|
|
|
+#define cudaGraphExecUpdate hipGraphExecUpdate
|
|
|
+#define cudaStreamCaptureModeRelaxed hipStreamCaptureModeRelaxed
|
|
|
+#define cudaStreamBeginCapture hipStreamBeginCapture
|
|
|
+#define cudaGraph_t hipGraph_t
|
|
|
#define cudaStream_t hipStream_t
|
|
|
#define cudaSuccess hipSuccess
|
|
|
#define __trap() do { abort(); __builtin_unreachable(); } while(0) // calls abort(), then tells the compiler that control never continues past this point
|