gpu.go 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653
  1. //go:build linux || windows
  2. package gpu
  3. /*
  4. #cgo linux LDFLAGS: -lrt -lpthread -ldl -lstdc++ -lm
  5. #cgo windows LDFLAGS: -lpthread
  6. #include "gpu_info.h"
  7. */
  8. import "C"
  9. import (
  10. "fmt"
  11. "log/slog"
  12. "os"
  13. "path/filepath"
  14. "runtime"
  15. "strings"
  16. "sync"
  17. "unsafe"
  18. "github.com/ollama/ollama/envconfig"
  19. "github.com/ollama/ollama/format"
  20. )
// cudaHandles aggregates the NVIDIA management-library handles loaded
// during discovery, plus the device count reported by whichever loader
// succeeded. Fields for libraries that were not loaded remain nil.
type cudaHandles struct {
	deviceCount int
	cudart      *C.cudart_handle_t
	nvcuda      *C.nvcuda_handle_t
	nvml        *C.nvml_handle_t
}
// oneapiHandles holds the Intel oneAPI (level-zero) library handle and the
// total device count discovered across all of its drivers. oneapi is nil
// when the library could not be loaded.
type oneapiHandles struct {
	oneapi      *C.oneapi_handle_t
	deviceCount int
}
const (
	// Per-GPU VRAM reserve assigned to GpuInfo.MinimumMemory during
	// discovery (see GetGPUInfo), kept aside from model allocations.
	cudaMinimumMemory = 457 * format.MebiByte
	rocmMinimumMemory = 457 * format.MebiByte
	// TODO OneAPI minimum memory
)
// Package-level discovery cache shared across GetGPUInfo calls. All access
// is serialized by gpuMutex. The *LibPath strings memoize which management
// library loaded successfully so later refreshes skip the filesystem search
// (see the short-circuit paths in initCudaHandles / initOneAPIHandles).
var (
	gpuMutex      sync.Mutex
	bootstrapped  bool
	cpuCapability CPUCapability
	cpus          []CPUInfo
	cudaGPUs      []CudaGPUInfo
	nvcudaLibPath string
	cudartLibPath string
	oneapiLibPath string
	nvmlLibPath   string
	rocmGPUs      []RocmGPUInfo
	oneapiGPUs    []OneapiGPUInfo
)
// With our current CUDA compile flags, older than 5.0 will not work properly
// (major, minor compute capability; compared in GetGPUInfo).
var CudaComputeMin = [2]C.int{5, 0}

// Minimum AMD GFX major version supported by the ROCm backend.
var RocmComputeMin = 9

// TODO find a better way to detect iGPU instead of minimum memory
const IGPUMemLimit = 1 * format.GibiByte // 512G is what they typically report, so anything less than 1G must be iGPU
// Default filesystem glob patterns consumed by FindGPULibs when a GPU
// management library is not located via PATH / LD_LIBRARY_PATH. The
// patterns cover distro, WSL, and vendor-toolkit install locations.

// CudartLinuxGlobs are search locations for the CUDA runtime on Linux.
var CudartLinuxGlobs = []string{
	"/usr/local/cuda/lib64/libcudart.so*",
	"/usr/lib/x86_64-linux-gnu/nvidia/current/libcudart.so*",
	"/usr/lib/x86_64-linux-gnu/libcudart.so*",
	"/usr/lib/wsl/lib/libcudart.so*",
	"/usr/lib/wsl/drivers/*/libcudart.so*",
	"/opt/cuda/lib64/libcudart.so*",
	"/usr/local/cuda*/targets/aarch64-linux/lib/libcudart.so*",
	"/usr/lib/aarch64-linux-gnu/nvidia/current/libcudart.so*",
	"/usr/lib/aarch64-linux-gnu/libcudart.so*",
	"/usr/local/cuda/lib*/libcudart.so*",
	"/usr/lib*/libcudart.so*",
	"/usr/local/lib*/libcudart.so*",
}

// CudartWindowsGlobs are search locations for the CUDA runtime on Windows.
var CudartWindowsGlobs = []string{
	"c:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v*\\bin\\cudart64_*.dll",
}

// NvmlWindowsGlobs locates nvml.dll, used only on Windows to refresh free
// VRAM (see initCudaHandles).
var NvmlWindowsGlobs = []string{
	"c:\\Windows\\System32\\nvml.dll",
}

// NvcudaLinuxGlobs are search locations for the NVIDIA driver library on Linux.
var NvcudaLinuxGlobs = []string{
	"/usr/local/cuda*/targets/*/lib/libcuda.so*",
	"/usr/lib/*-linux-gnu/nvidia/current/libcuda.so*",
	"/usr/lib/*-linux-gnu/libcuda.so*",
	"/usr/lib/wsl/lib/libcuda.so*",
	"/usr/lib/wsl/drivers/*/libcuda.so*",
	"/opt/cuda/lib*/libcuda.so*",
	"/usr/local/cuda/lib*/libcuda.so*",
	"/usr/lib*/libcuda.so*",
	"/usr/local/lib*/libcuda.so*",
}

// NvcudaWindowsGlobs are search locations for the NVIDIA driver library on Windows.
var NvcudaWindowsGlobs = []string{
	"c:\\windows\\system*\\nvcuda.dll",
}

// OneapiWindowsGlobs are search locations for the Intel level-zero library on Windows.
var OneapiWindowsGlobs = []string{
	"c:\\Windows\\System32\\DriverStore\\FileRepository\\*\\ze_intel_gpu64.dll",
}

// OneapiLinuxGlobs are search locations for the Intel level-zero library on Linux.
var OneapiLinuxGlobs = []string{
	"/usr/lib/x86_64-linux-gnu/libze_intel_gpu.so*",
	"/usr/lib*/libze_intel_gpu.so*",
}

// Jetson devices have JETSON_JETPACK="x.y.z" factory set to the Jetpack version installed.
// Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices.
var CudaTegra string = os.Getenv("JETSON_JETPACK")
// Note: gpuMutex must already be held
//
// initCudaHandles loads at most one NVIDIA management library and returns
// the handles plus the reported device count. Search order: a previously
// cached library path (short-circuit), then NVML (Windows only, kept for
// free-memory refresh), then nvcuda (driver), then cudart (runtime). The
// first successful nvcuda or cudart load wins, and its path is cached in
// the package-level *LibPath variables for subsequent calls. The returned
// struct is never nil; on unsupported platforms or total failure all of
// its handle fields are nil and deviceCount is 0.
func initCudaHandles() *cudaHandles {
	// TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing

	cHandles := &cudaHandles{}
	// Short Circuit if we already know which library to use
	if nvmlLibPath != "" {
		cHandles.nvml, _ = LoadNVMLMgmt([]string{nvmlLibPath})
		return cHandles
	}
	if nvcudaLibPath != "" {
		cHandles.deviceCount, cHandles.nvcuda, _ = LoadNVCUDAMgmt([]string{nvcudaLibPath})
		return cHandles
	}
	if cudartLibPath != "" {
		cHandles.deviceCount, cHandles.cudart, _ = LoadCUDARTMgmt([]string{cudartLibPath})
		return cHandles
	}

	slog.Debug("searching for GPU discovery libraries for NVIDIA")
	var cudartMgmtName string
	var cudartMgmtPatterns []string
	var nvcudaMgmtName string
	var nvcudaMgmtPatterns []string
	var nvmlMgmtName string
	var nvmlMgmtPatterns []string

	// Error deliberately ignored: an empty tmpDir simply skips the
	// bundled-payload search pattern below.
	tmpDir, _ := PayloadsDir()
	switch runtime.GOOS {
	case "windows":
		cudartMgmtName = "cudart64_*.dll"
		localAppData := os.Getenv("LOCALAPPDATA")
		cudartMgmtPatterns = []string{filepath.Join(localAppData, "Programs", "Ollama", cudartMgmtName)}
		cudartMgmtPatterns = append(cudartMgmtPatterns, CudartWindowsGlobs...)
		// Aligned with driver, we can't carry as payloads
		nvcudaMgmtName = "nvcuda.dll"
		nvcudaMgmtPatterns = NvcudaWindowsGlobs
		// Use nvml to refresh free memory on windows only
		nvmlMgmtName = "nvml.dll"
		// Copied — presumably to shield the shared package-level glob
		// slice from later mutation; the other pattern slices are not
		// copied (NOTE(review): inconsistent, but harmless as written).
		nvmlMgmtPatterns = make([]string, len(NvmlWindowsGlobs))
		copy(nvmlMgmtPatterns, NvmlWindowsGlobs)
	case "linux":
		cudartMgmtName = "libcudart.so*"
		if tmpDir != "" {
			// TODO - add "payloads" for subprocess
			cudartMgmtPatterns = []string{filepath.Join(tmpDir, "cuda*", cudartMgmtName)}
		}
		cudartMgmtPatterns = append(cudartMgmtPatterns, CudartLinuxGlobs...)
		// Aligned with driver, we can't carry as payloads
		nvcudaMgmtName = "libcuda.so*"
		nvcudaMgmtPatterns = NvcudaLinuxGlobs
		// nvml omitted on linux
	default:
		return cHandles
	}

	// NVML (Windows only): loaded in addition to nvcuda/cudart, does not
	// short-circuit — it only supplies free-memory refresh later.
	if len(nvmlMgmtPatterns) > 0 {
		nvmlLibPaths := FindGPULibs(nvmlMgmtName, nvmlMgmtPatterns)
		if len(nvmlLibPaths) > 0 {
			nvml, libPath := LoadNVMLMgmt(nvmlLibPaths)
			if nvml != nil {
				slog.Debug("nvidia-ml loaded", "library", libPath)
				cHandles.nvml = nvml
				nvmlLibPath = libPath
			}
		}
	}

	// Prefer the driver library (nvcuda) over the runtime (cudart).
	nvcudaLibPaths := FindGPULibs(nvcudaMgmtName, nvcudaMgmtPatterns)
	if len(nvcudaLibPaths) > 0 {
		deviceCount, nvcuda, libPath := LoadNVCUDAMgmt(nvcudaLibPaths)
		if nvcuda != nil {
			slog.Debug("detected GPUs", "count", deviceCount, "library", libPath)
			cHandles.nvcuda = nvcuda
			cHandles.deviceCount = deviceCount
			nvcudaLibPath = libPath
			return cHandles
		}
	}

	cudartLibPaths := FindGPULibs(cudartMgmtName, cudartMgmtPatterns)
	if len(cudartLibPaths) > 0 {
		deviceCount, cudart, libPath := LoadCUDARTMgmt(cudartLibPaths)
		if cudart != nil {
			slog.Debug("detected GPUs", "library", libPath, "count", deviceCount)
			cHandles.cudart = cudart
			cHandles.deviceCount = deviceCount
			cudartLibPath = libPath
			return cHandles
		}
	}

	return cHandles
}
  185. // Note: gpuMutex must already be held
  186. func initOneAPIHandles() *oneapiHandles {
  187. oHandles := &oneapiHandles{}
  188. var oneapiMgmtName string
  189. var oneapiMgmtPatterns []string
  190. // Short Circuit if we already know which library to use
  191. if oneapiLibPath != "" {
  192. oHandles.deviceCount, oHandles.oneapi, _ = LoadOneapiMgmt([]string{oneapiLibPath})
  193. return oHandles
  194. }
  195. switch runtime.GOOS {
  196. case "windows":
  197. oneapiMgmtName = "ze_intel_gpu64.dll"
  198. oneapiMgmtPatterns = OneapiWindowsGlobs
  199. case "linux":
  200. oneapiMgmtName = "libze_intel_gpu.so"
  201. oneapiMgmtPatterns = OneapiLinuxGlobs
  202. default:
  203. return oHandles
  204. }
  205. oneapiLibPaths := FindGPULibs(oneapiMgmtName, oneapiMgmtPatterns)
  206. if len(oneapiLibPaths) > 0 {
  207. oHandles.deviceCount, oHandles.oneapi, oneapiLibPath = LoadOneapiMgmt(oneapiLibPaths)
  208. }
  209. return oHandles
  210. }
// GetGPUInfo returns the list of usable compute devices (CUDA, ROCm,
// oneAPI), falling back to a single CPU entry when none are found. The
// first call performs full discovery ("bootstrap") and caches results in
// package-level state; subsequent calls only refresh free-memory figures.
// Safe for concurrent use — all state is guarded by gpuMutex.
func GetGPUInfo() GpuInfoList {
	// TODO - consider exploring lspci (and equivalent on windows) to check for
	// GPUs so we can report warnings if we see Nvidia/AMD but fail to load the libraries
	gpuMutex.Lock()
	defer gpuMutex.Unlock()
	needRefresh := true
	var cHandles *cudaHandles
	var oHandles *oneapiHandles
	// Release any native library handles opened during this call, whether
	// we took the bootstrap or the refresh path.
	defer func() {
		if cHandles != nil {
			if cHandles.cudart != nil {
				C.cudart_release(*cHandles.cudart)
			}
			if cHandles.nvcuda != nil {
				C.nvcuda_release(*cHandles.nvcuda)
			}
			if cHandles.nvml != nil {
				C.nvml_release(*cHandles.nvml)
			}
		}
		if oHandles != nil {
			if oHandles.oneapi != nil {
				// TODO - is this needed?
				C.oneapi_release(*oHandles.oneapi)
			}
		}
	}()
	if !bootstrapped {
		slog.Debug("Detecting GPUs")
		needRefresh = false
		cpuCapability = getCPUCapability()
		var memInfo C.mem_info_t
		// System RAM lookup; the C layer allocates err, which we must free.
		C.cpu_check_ram(&memInfo)
		if memInfo.err != nil {
			slog.Info("error looking up CPU memory", "error", C.GoString(memInfo.err))
			C.free(unsafe.Pointer(memInfo.err))
			return []GpuInfo{}
		}
		cpuInfo := CPUInfo{
			GpuInfo: GpuInfo{
				Library: "cpu",
				Variant: cpuCapability.ToVariant(),
			},
		}
		cpuInfo.TotalMemory = uint64(memInfo.total)
		cpuInfo.FreeMemory = uint64(memInfo.free)
		cpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
		cpus = []CPUInfo{cpuInfo}

		// Fallback to CPU mode if we're lacking required vector extensions on x86
		if cpuCapability < GPURunnerCPUCapability && runtime.GOARCH == "amd64" {
			slog.Warn("CPU does not have minimum vector extensions, GPU inference disabled", "required", GPURunnerCPUCapability.ToString(), "detected", cpuCapability.ToString())
			bootstrapped = true
			// No need to do any GPU discovery, since we can't run on them
			return GpuInfoList{cpus[0].GpuInfo}
		}

		// On windows we bundle the nvidia library one level above the runner dir
		depPath := ""
		if runtime.GOOS == "windows" && envconfig.RunnersDir != "" {
			depPath = filepath.Dir(envconfig.RunnersDir)
		}

		// Load ALL libraries
		cHandles = initCudaHandles()

		// NVIDIA: bootstrap each device reported by whichever library loaded.
		for i := range cHandles.deviceCount {
			if cHandles.cudart != nil || cHandles.nvcuda != nil {
				gpuInfo := CudaGPUInfo{
					GpuInfo: GpuInfo{
						Library: "cuda",
					},
					index: i,
				}
				var driverMajor int
				var driverMinor int
				if cHandles.cudart != nil {
					C.cudart_bootstrap(*cHandles.cudart, C.int(i), &memInfo)
				} else {
					// Only nvcuda exposes the driver version.
					C.nvcuda_bootstrap(*cHandles.nvcuda, C.int(i), &memInfo)
					driverMajor = int(cHandles.nvcuda.driver_major)
					driverMinor = int(cHandles.nvcuda.driver_minor)
				}
				if memInfo.err != nil {
					slog.Info("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
					C.free(unsafe.Pointer(memInfo.err))
					continue
				}
				// Skip devices below the compiled-in compute capability floor.
				if memInfo.major < CudaComputeMin[0] || (memInfo.major == CudaComputeMin[0] && memInfo.minor < CudaComputeMin[1]) {
					slog.Info(fmt.Sprintf("[%d] CUDA GPU is too old. Compute Capability detected: %d.%d", i, memInfo.major, memInfo.minor))
					continue
				}
				gpuInfo.TotalMemory = uint64(memInfo.total)
				gpuInfo.FreeMemory = uint64(memInfo.free)
				gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
				gpuInfo.Compute = fmt.Sprintf("%d.%d", memInfo.major, memInfo.minor)
				gpuInfo.MinimumMemory = cudaMinimumMemory
				gpuInfo.DependencyPath = depPath
				gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
				gpuInfo.DriverMajor = int(driverMajor)
				gpuInfo.DriverMinor = int(driverMinor)
				// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
				cudaGPUs = append(cudaGPUs, gpuInfo)
			}
		}

		// Intel
		oHandles = initOneAPIHandles()
		for d := 0; oHandles.oneapi != nil && d < int(oHandles.oneapi.num_drivers); d++ {
			// NOTE(review): this branch is unreachable — the loop
			// condition above already requires oneapi != nil.
			if oHandles.oneapi == nil {
				// shouldn't happen
				slog.Warn("nil oneapi handle with driver count", "count", int(oHandles.oneapi.num_drivers))
				continue
			}
			devCount := C.oneapi_get_device_count(*oHandles.oneapi, C.int(d))
			for i := 0; i < int(devCount); i++ {
				gpuInfo := OneapiGPUInfo{
					GpuInfo: GpuInfo{
						Library: "oneapi",
					},
					driverIndex: d,
					gpuIndex:    i,
				}
				// TODO - split bootstrapping from updating free memory
				C.oneapi_check_vram(*oHandles.oneapi, C.int(d), C.int(i), &memInfo)
				// TODO - convert this to MinimumMemory based on testing...
				var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend.
				memInfo.free = C.uint64_t(totalFreeMem)
				gpuInfo.TotalMemory = uint64(memInfo.total)
				gpuInfo.FreeMemory = uint64(memInfo.free)
				gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
				gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
				// TODO dependency path?
				oneapiGPUs = append(oneapiGPUs, gpuInfo)
			}
		}

		rocmGPUs = AMDGetGPUInfo()
		bootstrapped = true
	}

	// For detected GPUs, load library if not loaded
	// Refresh free memory usage
	if needRefresh {
		// TODO - CPU system memory tracking/refresh
		var memInfo C.mem_info_t
		// cHandles is nil here unless bootstrap ran this call; reopen only
		// if we actually have CUDA devices to refresh (initCudaHandles
		// never returns nil, so the derefs below are safe).
		if cHandles == nil && len(cudaGPUs) > 0 {
			cHandles = initCudaHandles()
		}
		for i, gpu := range cudaGPUs {
			if cHandles.nvml != nil {
				C.nvml_get_free(*cHandles.nvml, C.int(gpu.index), &memInfo.free, &memInfo.total, &memInfo.used)
			} else if cHandles.cudart != nil {
				C.cudart_bootstrap(*cHandles.cudart, C.int(gpu.index), &memInfo)
			} else if cHandles.nvcuda != nil {
				C.nvcuda_get_free(*cHandles.nvcuda, C.int(gpu.index), &memInfo.free, &memInfo.total)
				memInfo.used = memInfo.total - memInfo.free
			} else {
				// shouldn't happen
				slog.Warn("no valid cuda library loaded to refresh vram usage")
				break
			}
			if memInfo.err != nil {
				slog.Warn("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
				C.free(unsafe.Pointer(memInfo.err))
				continue
			}
			// Zero free memory is treated as a lookup failure; the cached
			// value from the previous refresh is kept.
			if memInfo.free == 0 {
				slog.Warn("error looking up nvidia GPU memory")
				continue
			}
			slog.Debug("updating cuda memory data",
				"gpu", gpu.ID,
				"name", gpu.Name,
				slog.Group(
					"before",
					"total", format.HumanBytes2(gpu.TotalMemory),
					"free", format.HumanBytes2(gpu.FreeMemory),
				),
				slog.Group(
					"now",
					"total", format.HumanBytes2(uint64(memInfo.total)),
					"free", format.HumanBytes2(uint64(memInfo.free)),
					"used", format.HumanBytes2(uint64(memInfo.used)),
				),
			)
			cudaGPUs[i].FreeMemory = uint64(memInfo.free)
		}

		if oHandles == nil && len(oneapiGPUs) > 0 {
			oHandles = initOneAPIHandles()
		}
		for i, gpu := range oneapiGPUs {
			if oHandles.oneapi == nil {
				// shouldn't happen
				slog.Warn("nil oneapi handle with device count", "count", oHandles.deviceCount)
				continue
			}
			C.oneapi_check_vram(*oHandles.oneapi, C.int(gpu.driverIndex), C.int(gpu.gpuIndex), &memInfo)
			// TODO - convert this to MinimumMemory based on testing...
			var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend.
			memInfo.free = C.uint64_t(totalFreeMem)
			oneapiGPUs[i].FreeMemory = uint64(memInfo.free)
		}

		err := RocmGPUInfoList(rocmGPUs).RefreshFreeMemory()
		if err != nil {
			slog.Debug("problem refreshing ROCm free memory", "error", err)
		}
	}

	// Flatten the cached per-vendor lists into the response; CPU-only as a
	// last resort so callers always get at least one entry.
	resp := []GpuInfo{}
	for _, gpu := range cudaGPUs {
		resp = append(resp, gpu.GpuInfo)
	}
	for _, gpu := range rocmGPUs {
		resp = append(resp, gpu.GpuInfo)
	}
	for _, gpu := range oneapiGPUs {
		resp = append(resp, gpu.GpuInfo)
	}
	if len(resp) == 0 {
		resp = append(resp, cpus[0].GpuInfo)
	}
	return resp
}
  428. func GetCPUMem() (memInfo, error) {
  429. var ret memInfo
  430. var info C.mem_info_t
  431. C.cpu_check_ram(&info)
  432. if info.err != nil {
  433. defer C.free(unsafe.Pointer(info.err))
  434. return ret, fmt.Errorf(C.GoString(info.err))
  435. }
  436. ret.FreeMemory = uint64(info.free)
  437. ret.TotalMemory = uint64(info.total)
  438. return ret, nil
  439. }
  440. func FindGPULibs(baseLibName string, defaultPatterns []string) []string {
  441. // Multiple GPU libraries may exist, and some may not work, so keep trying until we exhaust them
  442. var ldPaths []string
  443. var patterns []string
  444. gpuLibPaths := []string{}
  445. slog.Debug("Searching for GPU library", "name", baseLibName)
  446. switch runtime.GOOS {
  447. case "windows":
  448. ldPaths = strings.Split(os.Getenv("PATH"), ";")
  449. case "linux":
  450. ldPaths = strings.Split(os.Getenv("LD_LIBRARY_PATH"), ":")
  451. default:
  452. return gpuLibPaths
  453. }
  454. // Start with whatever we find in the PATH/LD_LIBRARY_PATH
  455. for _, ldPath := range ldPaths {
  456. d, err := filepath.Abs(ldPath)
  457. if err != nil {
  458. continue
  459. }
  460. patterns = append(patterns, filepath.Join(d, baseLibName+"*"))
  461. }
  462. patterns = append(patterns, defaultPatterns...)
  463. slog.Debug("gpu library search", "globs", patterns)
  464. for _, pattern := range patterns {
  465. // Nvidia PhysX known to return bogus results
  466. if strings.Contains(pattern, "PhysX") {
  467. slog.Debug("skipping PhysX cuda library path", "path", pattern)
  468. continue
  469. }
  470. // Ignore glob discovery errors
  471. matches, _ := filepath.Glob(pattern)
  472. for _, match := range matches {
  473. // Resolve any links so we don't try the same lib multiple times
  474. // and weed out any dups across globs
  475. libPath := match
  476. tmp := match
  477. var err error
  478. for ; err == nil; tmp, err = os.Readlink(libPath) {
  479. if !filepath.IsAbs(tmp) {
  480. tmp = filepath.Join(filepath.Dir(libPath), tmp)
  481. }
  482. libPath = tmp
  483. }
  484. new := true
  485. for _, cmp := range gpuLibPaths {
  486. if cmp == libPath {
  487. new = false
  488. break
  489. }
  490. }
  491. if new {
  492. gpuLibPaths = append(gpuLibPaths, libPath)
  493. }
  494. }
  495. }
  496. slog.Debug("discovered GPU libraries", "paths", gpuLibPaths)
  497. return gpuLibPaths
  498. }
  499. func LoadCUDARTMgmt(cudartLibPaths []string) (int, *C.cudart_handle_t, string) {
  500. var resp C.cudart_init_resp_t
  501. resp.ch.verbose = getVerboseState()
  502. for _, libPath := range cudartLibPaths {
  503. lib := C.CString(libPath)
  504. defer C.free(unsafe.Pointer(lib))
  505. C.cudart_init(lib, &resp)
  506. if resp.err != nil {
  507. slog.Debug("Unable to load cudart", "library", libPath, "error", C.GoString(resp.err))
  508. C.free(unsafe.Pointer(resp.err))
  509. } else {
  510. return int(resp.num_devices), &resp.ch, libPath
  511. }
  512. }
  513. return 0, nil, ""
  514. }
  515. func LoadNVCUDAMgmt(nvcudaLibPaths []string) (int, *C.nvcuda_handle_t, string) {
  516. var resp C.nvcuda_init_resp_t
  517. resp.ch.verbose = getVerboseState()
  518. for _, libPath := range nvcudaLibPaths {
  519. lib := C.CString(libPath)
  520. defer C.free(unsafe.Pointer(lib))
  521. C.nvcuda_init(lib, &resp)
  522. if resp.err != nil {
  523. slog.Debug("Unable to load nvcuda", "library", libPath, "error", C.GoString(resp.err))
  524. C.free(unsafe.Pointer(resp.err))
  525. } else {
  526. return int(resp.num_devices), &resp.ch, libPath
  527. }
  528. }
  529. return 0, nil, ""
  530. }
  531. func LoadNVMLMgmt(nvmlLibPaths []string) (*C.nvml_handle_t, string) {
  532. var resp C.nvml_init_resp_t
  533. resp.ch.verbose = getVerboseState()
  534. for _, libPath := range nvmlLibPaths {
  535. lib := C.CString(libPath)
  536. defer C.free(unsafe.Pointer(lib))
  537. C.nvml_init(lib, &resp)
  538. if resp.err != nil {
  539. slog.Info(fmt.Sprintf("Unable to load NVML management library %s: %s", libPath, C.GoString(resp.err)))
  540. C.free(unsafe.Pointer(resp.err))
  541. } else {
  542. return &resp.ch, libPath
  543. }
  544. }
  545. return nil, ""
  546. }
  547. func LoadOneapiMgmt(oneapiLibPaths []string) (int, *C.oneapi_handle_t, string) {
  548. var resp C.oneapi_init_resp_t
  549. num_devices := 0
  550. resp.oh.verbose = getVerboseState()
  551. for _, libPath := range oneapiLibPaths {
  552. lib := C.CString(libPath)
  553. defer C.free(unsafe.Pointer(lib))
  554. C.oneapi_init(lib, &resp)
  555. if resp.err != nil {
  556. slog.Debug("Unable to load oneAPI management library", "library", libPath, "error", C.GoString(resp.err))
  557. C.free(unsafe.Pointer(resp.err))
  558. } else {
  559. for i := 0; i < int(resp.oh.num_drivers); i++ {
  560. num_devices += int(C.oneapi_get_device_count(resp.oh, C.int(i)))
  561. }
  562. return num_devices, &resp.oh, libPath
  563. }
  564. }
  565. return 0, nil, ""
  566. }
  567. func getVerboseState() C.uint16_t {
  568. if envconfig.Debug {
  569. return C.uint16_t(1)
  570. }
  571. return C.uint16_t(0)
  572. }
  573. // Given the list of GPUs this instantiation is targeted for,
  574. // figure out the visible devices environment variable
  575. //
  576. // If different libraries are detected, the first one is what we use
  577. func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
  578. if len(l) == 0 {
  579. return "", ""
  580. }
  581. switch l[0].Library {
  582. case "cuda":
  583. return cudaGetVisibleDevicesEnv(l)
  584. case "rocm":
  585. return rocmGetVisibleDevicesEnv(l)
  586. case "oneapi":
  587. return oneapiGetVisibleDevicesEnv(l)
  588. default:
  589. slog.Debug("no filter required for library " + l[0].Library)
  590. return "", ""
  591. }
  592. }