gpu.go 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580
  1. //go:build linux || windows
  2. package gpu
  3. /*
  4. #cgo linux LDFLAGS: -lrt -lpthread -ldl -lstdc++ -lm
  5. #cgo windows LDFLAGS: -lpthread
  6. #include "gpu_info.h"
  7. */
  8. import "C"
import (
	"errors"
	"fmt"
	"log/slog"
	"os"
	"path/filepath"
	"runtime"
	"slices"
	"strings"
	"sync"
	"unsafe"

	"github.com/ollama/ollama/envconfig"
	"github.com/ollama/ollama/format"
)
// cudaHandles bundles the detected device count with whichever NVIDIA
// management library was successfully loaded. At most one of cudart /
// nvcuda is non-nil; the other stays nil.
type cudaHandles struct {
	deviceCount int
	cudart      *C.cudart_handle_t
	nvcuda      *C.nvcuda_handle_t
}

// oneapiHandles bundles the Intel oneAPI management handle with the
// total device count reported across all of its drivers.
type oneapiHandles struct {
	oneapi      *C.oneapi_handle_t
	deviceCount int
}
const (
	// Per-GPU VRAM reserve subtracted before scheduling work onto the device.
	cudaMinimumMemory = 457 * format.MebiByte
	rocmMinimumMemory = 457 * format.MebiByte
	// TODO OneAPI minimum memory
)

// Package-level discovery cache. gpuMutex guards everything below.
var (
	gpuMutex      sync.Mutex
	bootstrapped  bool // set once full discovery has run; later calls only refresh free memory
	cpuCapability CPUCapability
	cpus          []CPUInfo
	cudaGPUs      []CudaGPUInfo
	nvcudaLibPath string // cached library paths let rediscovery short-circuit the glob search
	cudartLibPath string
	oneapiLibPath string
	rocmGPUs      []RocmGPUInfo
	oneapiGPUs    []OneapiGPUInfo
)

// With our current CUDA compile flags, older than 5.0 will not work properly
var CudaComputeMin = [2]C.int{5, 0}

// Minimum supported ROCm compute version (gfx9xx and newer).
var RocmComputeMin = 9

// TODO find a better way to detect iGPU instead of minimum memory
const IGPUMemLimit = 1 * format.GibiByte // 512G is what they typically report, so anything less than 1G must be iGPU
// Search globs for the CUDA runtime library on linux, ordered roughly
// from specific install locations to generic fallbacks.
var CudartLinuxGlobs = []string{
	"/usr/local/cuda/lib64/libcudart.so*",
	"/usr/lib/x86_64-linux-gnu/nvidia/current/libcudart.so*",
	"/usr/lib/x86_64-linux-gnu/libcudart.so*",
	"/usr/lib/wsl/lib/libcudart.so*",
	"/usr/lib/wsl/drivers/*/libcudart.so*",
	"/opt/cuda/lib64/libcudart.so*",
	"/usr/local/cuda*/targets/aarch64-linux/lib/libcudart.so*",
	"/usr/lib/aarch64-linux-gnu/nvidia/current/libcudart.so*",
	"/usr/lib/aarch64-linux-gnu/libcudart.so*",
	"/usr/local/cuda/lib*/libcudart.so*",
	"/usr/lib*/libcudart.so*",
	"/usr/local/lib*/libcudart.so*",
}

// Search globs for the CUDA runtime library on windows.
var CudartWindowsGlobs = []string{
	"c:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v*\\bin\\cudart64_*.dll",
}

// Search globs for the NVIDIA driver library on linux.
var NvcudaLinuxGlobs = []string{
	"/usr/local/cuda*/targets/*/lib/libcuda.so*",
	"/usr/lib/*-linux-gnu/nvidia/current/libcuda.so*",
	"/usr/lib/*-linux-gnu/libcuda.so*",
	"/usr/lib/wsl/lib/libcuda.so*",
	"/usr/lib/wsl/drivers/*/libcuda.so*",
	"/opt/cuda/lib*/libcuda.so*",
	"/usr/local/cuda/lib*/libcuda.so*",
	"/usr/lib*/libcuda.so*",
	"/usr/local/lib*/libcuda.so*",
}

// Search globs for the NVIDIA driver library on windows.
var NvcudaWindowsGlobs = []string{
	"c:\\windows\\system*\\nvcuda.dll",
}

// Search globs for the Intel oneAPI Level Zero library on windows.
var OneapiWindowsGlobs = []string{
	"c:\\Windows\\System32\\DriverStore\\FileRepository\\*\\ze_intel_gpu64.dll",
}

// Search globs for the Intel oneAPI Level Zero library on linux.
var OneapiLinuxGlobs = []string{
	"/usr/lib/x86_64-linux-gnu/libze_intel_gpu.so*",
	"/usr/lib*/libze_intel_gpu.so*",
}

// Jetson devices have JETSON_JETPACK="x.y.z" factory set to the Jetpack version installed.
// Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices.
var CudaTegra string = os.Getenv("JETSON_JETPACK")
  93. // Note: gpuMutex must already be held
  94. func initCudaHandles() *cudaHandles {
  95. // TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing
  96. cHandles := &cudaHandles{}
  97. // Short Circuit if we already know which library to use
  98. if nvcudaLibPath != "" {
  99. cHandles.deviceCount, cHandles.nvcuda, _ = LoadNVCUDAMgmt([]string{nvcudaLibPath})
  100. return cHandles
  101. }
  102. if cudartLibPath != "" {
  103. cHandles.deviceCount, cHandles.cudart, _ = LoadCUDARTMgmt([]string{cudartLibPath})
  104. return cHandles
  105. }
  106. slog.Debug("searching for GPU discovery libraries for NVIDIA")
  107. var cudartMgmtName string
  108. var cudartMgmtPatterns []string
  109. var nvcudaMgmtName string
  110. var nvcudaMgmtPatterns []string
  111. tmpDir, _ := PayloadsDir()
  112. switch runtime.GOOS {
  113. case "windows":
  114. cudartMgmtName = "cudart64_*.dll"
  115. localAppData := os.Getenv("LOCALAPPDATA")
  116. cudartMgmtPatterns = []string{filepath.Join(localAppData, "Programs", "Ollama", cudartMgmtName)}
  117. cudartMgmtPatterns = append(cudartMgmtPatterns, CudartWindowsGlobs...)
  118. // Aligned with driver, we can't carry as payloads
  119. nvcudaMgmtName = "nvcuda.dll"
  120. nvcudaMgmtPatterns = NvcudaWindowsGlobs
  121. case "linux":
  122. cudartMgmtName = "libcudart.so*"
  123. if tmpDir != "" {
  124. // TODO - add "payloads" for subprocess
  125. cudartMgmtPatterns = []string{filepath.Join(tmpDir, "cuda*", cudartMgmtName)}
  126. }
  127. cudartMgmtPatterns = append(cudartMgmtPatterns, CudartLinuxGlobs...)
  128. // Aligned with driver, we can't carry as payloads
  129. nvcudaMgmtName = "libcuda.so*"
  130. nvcudaMgmtPatterns = NvcudaLinuxGlobs
  131. default:
  132. return cHandles
  133. }
  134. nvcudaLibPaths := FindGPULibs(nvcudaMgmtName, nvcudaMgmtPatterns)
  135. if len(nvcudaLibPaths) > 0 {
  136. deviceCount, nvcuda, libPath := LoadNVCUDAMgmt(nvcudaLibPaths)
  137. if nvcuda != nil {
  138. slog.Debug("detected GPUs", "count", deviceCount, "library", libPath)
  139. cHandles.nvcuda = nvcuda
  140. cHandles.deviceCount = deviceCount
  141. nvcudaLibPath = libPath
  142. return cHandles
  143. }
  144. }
  145. cudartLibPaths := FindGPULibs(cudartMgmtName, cudartMgmtPatterns)
  146. if len(cudartLibPaths) > 0 {
  147. deviceCount, cudart, libPath := LoadCUDARTMgmt(cudartLibPaths)
  148. if cudart != nil {
  149. slog.Debug("detected GPUs", "library", libPath, "count", deviceCount)
  150. cHandles.cudart = cudart
  151. cHandles.deviceCount = deviceCount
  152. cudartLibPath = libPath
  153. return cHandles
  154. }
  155. }
  156. return cHandles
  157. }
  158. // Note: gpuMutex must already be held
  159. func initOneAPIHandles() *oneapiHandles {
  160. oHandles := &oneapiHandles{}
  161. var oneapiMgmtName string
  162. var oneapiMgmtPatterns []string
  163. // Short Circuit if we already know which library to use
  164. if oneapiLibPath != "" {
  165. oHandles.deviceCount, oHandles.oneapi, _ = LoadOneapiMgmt([]string{oneapiLibPath})
  166. return oHandles
  167. }
  168. switch runtime.GOOS {
  169. case "windows":
  170. oneapiMgmtName = "ze_intel_gpu64.dll"
  171. oneapiMgmtPatterns = OneapiWindowsGlobs
  172. case "linux":
  173. oneapiMgmtName = "libze_intel_gpu.so"
  174. oneapiMgmtPatterns = OneapiLinuxGlobs
  175. default:
  176. return oHandles
  177. }
  178. oneapiLibPaths := FindGPULibs(oneapiMgmtName, oneapiMgmtPatterns)
  179. if len(oneapiLibPaths) > 0 {
  180. oHandles.deviceCount, oHandles.oneapi, oneapiLibPath = LoadOneapiMgmt(oneapiLibPaths)
  181. }
  182. return oHandles
  183. }
  184. func GetGPUInfo() GpuInfoList {
  185. // TODO - consider exploring lspci (and equivalent on windows) to check for
  186. // GPUs so we can report warnings if we see Nvidia/AMD but fail to load the libraries
  187. gpuMutex.Lock()
  188. defer gpuMutex.Unlock()
  189. needRefresh := true
  190. var cHandles *cudaHandles
  191. var oHandles *oneapiHandles
  192. defer func() {
  193. if cHandles != nil {
  194. if cHandles.cudart != nil {
  195. C.cudart_release(*cHandles.cudart)
  196. }
  197. if cHandles.nvcuda != nil {
  198. C.nvcuda_release(*cHandles.nvcuda)
  199. }
  200. }
  201. if oHandles != nil {
  202. if oHandles.oneapi != nil {
  203. // TODO - is this needed?
  204. C.oneapi_release(*oHandles.oneapi)
  205. }
  206. }
  207. }()
  208. if !bootstrapped {
  209. slog.Debug("Detecting GPUs")
  210. needRefresh = false
  211. cpuCapability = getCPUCapability()
  212. var memInfo C.mem_info_t
  213. C.cpu_check_ram(&memInfo)
  214. if memInfo.err != nil {
  215. slog.Info("error looking up CPU memory", "error", C.GoString(memInfo.err))
  216. C.free(unsafe.Pointer(memInfo.err))
  217. return []GpuInfo{}
  218. }
  219. cpuInfo := CPUInfo{
  220. GpuInfo: GpuInfo{
  221. Library: "cpu",
  222. Variant: cpuCapability.ToVariant(),
  223. },
  224. }
  225. cpuInfo.TotalMemory = uint64(memInfo.total)
  226. cpuInfo.FreeMemory = uint64(memInfo.free)
  227. cpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
  228. cpus = []CPUInfo{cpuInfo}
  229. // Fallback to CPU mode if we're lacking required vector extensions on x86
  230. if cpuCapability < GPURunnerCPUCapability && runtime.GOARCH == "amd64" {
  231. slog.Warn("CPU does not have minimum vector extensions, GPU inference disabled", "required", GPURunnerCPUCapability.ToString(), "detected", cpuCapability.ToString())
  232. bootstrapped = true
  233. // No need to do any GPU discovery, since we can't run on them
  234. return GpuInfoList{cpus[0].GpuInfo}
  235. }
  236. // On windows we bundle the nvidia library one level above the runner dir
  237. depPath := ""
  238. if runtime.GOOS == "windows" && envconfig.RunnersDir != "" {
  239. depPath = filepath.Dir(envconfig.RunnersDir)
  240. }
  241. // Load ALL libraries
  242. cHandles = initCudaHandles()
  243. // NVIDIA
  244. for i := range cHandles.deviceCount {
  245. if cHandles.cudart != nil || cHandles.nvcuda != nil {
  246. gpuInfo := CudaGPUInfo{
  247. GpuInfo: GpuInfo{
  248. Library: "cuda",
  249. },
  250. index: i,
  251. }
  252. var driverMajor int
  253. var driverMinor int
  254. if cHandles.cudart != nil {
  255. C.cudart_bootstrap(*cHandles.cudart, C.int(i), &memInfo)
  256. } else {
  257. C.nvcuda_bootstrap(*cHandles.nvcuda, C.int(i), &memInfo)
  258. driverMajor = int(cHandles.nvcuda.driver_major)
  259. driverMinor = int(cHandles.nvcuda.driver_minor)
  260. }
  261. if memInfo.err != nil {
  262. slog.Info("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
  263. C.free(unsafe.Pointer(memInfo.err))
  264. continue
  265. }
  266. if memInfo.major < CudaComputeMin[0] || (memInfo.major == CudaComputeMin[0] && memInfo.minor < CudaComputeMin[1]) {
  267. slog.Info(fmt.Sprintf("[%d] CUDA GPU is too old. Compute Capability detected: %d.%d", i, memInfo.major, memInfo.minor))
  268. continue
  269. }
  270. gpuInfo.TotalMemory = uint64(memInfo.total)
  271. gpuInfo.FreeMemory = uint64(memInfo.free)
  272. gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
  273. gpuInfo.Compute = fmt.Sprintf("%d.%d", memInfo.major, memInfo.minor)
  274. gpuInfo.MinimumMemory = cudaMinimumMemory
  275. gpuInfo.DependencyPath = depPath
  276. gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
  277. gpuInfo.DriverMajor = int(driverMajor)
  278. gpuInfo.DriverMinor = int(driverMinor)
  279. // TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
  280. cudaGPUs = append(cudaGPUs, gpuInfo)
  281. }
  282. }
  283. // Intel
  284. oHandles = initOneAPIHandles()
  285. for d := 0; oHandles.oneapi != nil && d < int(oHandles.oneapi.num_drivers); d++ {
  286. if oHandles.oneapi == nil {
  287. // shouldn't happen
  288. slog.Warn("nil oneapi handle with driver count", "count", int(oHandles.oneapi.num_drivers))
  289. continue
  290. }
  291. devCount := C.oneapi_get_device_count(*oHandles.oneapi, C.int(d))
  292. for i := 0; i < int(devCount); i++ {
  293. gpuInfo := OneapiGPUInfo{
  294. GpuInfo: GpuInfo{
  295. Library: "oneapi",
  296. },
  297. driverIndex: d,
  298. gpuIndex: i,
  299. }
  300. // TODO - split bootstrapping from updating free memory
  301. C.oneapi_check_vram(*oHandles.oneapi, C.int(d), C.int(i), &memInfo)
  302. // TODO - convert this to MinimumMemory based on testing...
  303. var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend.
  304. memInfo.free = C.uint64_t(totalFreeMem)
  305. gpuInfo.TotalMemory = uint64(memInfo.total)
  306. gpuInfo.FreeMemory = uint64(memInfo.free)
  307. gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
  308. gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
  309. // TODO dependency path?
  310. oneapiGPUs = append(oneapiGPUs, gpuInfo)
  311. }
  312. }
  313. rocmGPUs = AMDGetGPUInfo()
  314. bootstrapped = true
  315. }
  316. // For detected GPUs, load library if not loaded
  317. // Refresh free memory usage
  318. if needRefresh {
  319. // TODO - CPU system memory tracking/refresh
  320. var memInfo C.mem_info_t
  321. if cHandles == nil && len(cudaGPUs) > 0 {
  322. cHandles = initCudaHandles()
  323. }
  324. for i, gpu := range cudaGPUs {
  325. if cHandles.cudart != nil {
  326. C.cudart_bootstrap(*cHandles.cudart, C.int(gpu.index), &memInfo)
  327. } else {
  328. C.nvcuda_get_free(*cHandles.nvcuda, C.int(gpu.index), &memInfo.free)
  329. }
  330. if memInfo.err != nil {
  331. slog.Warn("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
  332. C.free(unsafe.Pointer(memInfo.err))
  333. continue
  334. }
  335. if memInfo.free == 0 {
  336. slog.Warn("error looking up nvidia GPU memory")
  337. continue
  338. }
  339. slog.Debug("updating cuda free memory", "gpu", gpu.ID, "name", gpu.Name, "before", format.HumanBytes2(gpu.FreeMemory), "now", format.HumanBytes2(uint64(memInfo.free)))
  340. cudaGPUs[i].FreeMemory = uint64(memInfo.free)
  341. }
  342. if oHandles == nil && len(oneapiGPUs) > 0 {
  343. oHandles = initOneAPIHandles()
  344. }
  345. for i, gpu := range oneapiGPUs {
  346. if oHandles.oneapi == nil {
  347. // shouldn't happen
  348. slog.Warn("nil oneapi handle with device count", "count", oHandles.deviceCount)
  349. continue
  350. }
  351. C.oneapi_check_vram(*oHandles.oneapi, C.int(gpu.driverIndex), C.int(gpu.gpuIndex), &memInfo)
  352. // TODO - convert this to MinimumMemory based on testing...
  353. var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend.
  354. memInfo.free = C.uint64_t(totalFreeMem)
  355. oneapiGPUs[i].FreeMemory = uint64(memInfo.free)
  356. }
  357. err := RocmGPUInfoList(rocmGPUs).RefreshFreeMemory()
  358. if err != nil {
  359. slog.Debug("problem refreshing ROCm free memory", "error", err)
  360. }
  361. }
  362. resp := []GpuInfo{}
  363. for _, gpu := range cudaGPUs {
  364. resp = append(resp, gpu.GpuInfo)
  365. }
  366. for _, gpu := range rocmGPUs {
  367. resp = append(resp, gpu.GpuInfo)
  368. }
  369. for _, gpu := range oneapiGPUs {
  370. resp = append(resp, gpu.GpuInfo)
  371. }
  372. if len(resp) == 0 {
  373. resp = append(resp, cpus[0].GpuInfo)
  374. }
  375. return resp
  376. }
  377. func GetCPUMem() (memInfo, error) {
  378. var ret memInfo
  379. var info C.mem_info_t
  380. C.cpu_check_ram(&info)
  381. if info.err != nil {
  382. defer C.free(unsafe.Pointer(info.err))
  383. return ret, fmt.Errorf(C.GoString(info.err))
  384. }
  385. ret.FreeMemory = uint64(info.free)
  386. ret.TotalMemory = uint64(info.total)
  387. return ret, nil
  388. }
  389. func FindGPULibs(baseLibName string, defaultPatterns []string) []string {
  390. // Multiple GPU libraries may exist, and some may not work, so keep trying until we exhaust them
  391. var ldPaths []string
  392. var patterns []string
  393. gpuLibPaths := []string{}
  394. slog.Debug("Searching for GPU library", "name", baseLibName)
  395. switch runtime.GOOS {
  396. case "windows":
  397. ldPaths = strings.Split(os.Getenv("PATH"), ";")
  398. case "linux":
  399. ldPaths = strings.Split(os.Getenv("LD_LIBRARY_PATH"), ":")
  400. default:
  401. return gpuLibPaths
  402. }
  403. // Start with whatever we find in the PATH/LD_LIBRARY_PATH
  404. for _, ldPath := range ldPaths {
  405. d, err := filepath.Abs(ldPath)
  406. if err != nil {
  407. continue
  408. }
  409. patterns = append(patterns, filepath.Join(d, baseLibName+"*"))
  410. }
  411. patterns = append(patterns, defaultPatterns...)
  412. slog.Debug("gpu library search", "globs", patterns)
  413. for _, pattern := range patterns {
  414. // Nvidia PhysX known to return bogus results
  415. if strings.Contains(pattern, "PhysX") {
  416. slog.Debug("skipping PhysX cuda library path", "path", pattern)
  417. continue
  418. }
  419. // Ignore glob discovery errors
  420. matches, _ := filepath.Glob(pattern)
  421. for _, match := range matches {
  422. // Resolve any links so we don't try the same lib multiple times
  423. // and weed out any dups across globs
  424. libPath := match
  425. tmp := match
  426. var err error
  427. for ; err == nil; tmp, err = os.Readlink(libPath) {
  428. if !filepath.IsAbs(tmp) {
  429. tmp = filepath.Join(filepath.Dir(libPath), tmp)
  430. }
  431. libPath = tmp
  432. }
  433. new := true
  434. for _, cmp := range gpuLibPaths {
  435. if cmp == libPath {
  436. new = false
  437. break
  438. }
  439. }
  440. if new {
  441. gpuLibPaths = append(gpuLibPaths, libPath)
  442. }
  443. }
  444. }
  445. slog.Debug("discovered GPU libraries", "paths", gpuLibPaths)
  446. return gpuLibPaths
  447. }
  448. func LoadCUDARTMgmt(cudartLibPaths []string) (int, *C.cudart_handle_t, string) {
  449. var resp C.cudart_init_resp_t
  450. resp.ch.verbose = getVerboseState()
  451. for _, libPath := range cudartLibPaths {
  452. lib := C.CString(libPath)
  453. defer C.free(unsafe.Pointer(lib))
  454. C.cudart_init(lib, &resp)
  455. if resp.err != nil {
  456. slog.Debug("Unable to load cudart", "library", libPath, "error", C.GoString(resp.err))
  457. C.free(unsafe.Pointer(resp.err))
  458. } else {
  459. return int(resp.num_devices), &resp.ch, libPath
  460. }
  461. }
  462. return 0, nil, ""
  463. }
  464. func LoadNVCUDAMgmt(nvcudaLibPaths []string) (int, *C.nvcuda_handle_t, string) {
  465. var resp C.nvcuda_init_resp_t
  466. resp.ch.verbose = getVerboseState()
  467. for _, libPath := range nvcudaLibPaths {
  468. lib := C.CString(libPath)
  469. defer C.free(unsafe.Pointer(lib))
  470. C.nvcuda_init(lib, &resp)
  471. if resp.err != nil {
  472. slog.Debug("Unable to load nvcuda", "library", libPath, "error", C.GoString(resp.err))
  473. C.free(unsafe.Pointer(resp.err))
  474. } else {
  475. return int(resp.num_devices), &resp.ch, libPath
  476. }
  477. }
  478. return 0, nil, ""
  479. }
  480. func LoadOneapiMgmt(oneapiLibPaths []string) (int, *C.oneapi_handle_t, string) {
  481. var resp C.oneapi_init_resp_t
  482. num_devices := 0
  483. resp.oh.verbose = getVerboseState()
  484. for _, libPath := range oneapiLibPaths {
  485. lib := C.CString(libPath)
  486. defer C.free(unsafe.Pointer(lib))
  487. C.oneapi_init(lib, &resp)
  488. if resp.err != nil {
  489. slog.Debug("Unable to load oneAPI management library", "library", libPath, "error", C.GoString(resp.err))
  490. C.free(unsafe.Pointer(resp.err))
  491. } else {
  492. for i := 0; i < int(resp.oh.num_drivers); i++ {
  493. num_devices += int(C.oneapi_get_device_count(resp.oh, C.int(i)))
  494. }
  495. return num_devices, &resp.oh, libPath
  496. }
  497. }
  498. return 0, nil, ""
  499. }
  500. func getVerboseState() C.uint16_t {
  501. if envconfig.Debug {
  502. return C.uint16_t(1)
  503. }
  504. return C.uint16_t(0)
  505. }
  506. // Given the list of GPUs this instantiation is targeted for,
  507. // figure out the visible devices environment variable
  508. //
  509. // If different libraries are detected, the first one is what we use
  510. func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
  511. if len(l) == 0 {
  512. return "", ""
  513. }
  514. switch l[0].Library {
  515. case "cuda":
  516. return cudaGetVisibleDevicesEnv(l)
  517. case "rocm":
  518. return rocmGetVisibleDevicesEnv(l)
  519. case "oneapi":
  520. return oneapiGetVisibleDevicesEnv(l)
  521. default:
  522. slog.Debug("no filter required for library " + l[0].Library)
  523. return "", ""
  524. }
  525. }