amd_linux.go 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540
  1. package discover
  2. import (
  3. "bufio"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "io/fs"
  8. "log/slog"
  9. "os"
  10. "path/filepath"
  11. "regexp"
  12. "slices"
  13. "sort"
  14. "strconv"
  15. "strings"
  16. "github.com/ollama/ollama/envconfig"
  17. "github.com/ollama/ollama/format"
  18. )
  // Discovery logic for AMD/ROCm GPUs

  // sysfs locations exposed by the amdgpu driver and its KFD topology.
  const (
  	// DriverVersionFile holds the loaded amdgpu module version; AMDDriverVersion parses
  	// its leading "major.minor" digits.
  	DriverVersionFile = "/sys/module/amdgpu/version"
  	// AMDNodesSysfsDir contains one numbered directory per topology node (host CPUs first, then GPUs).
  	AMDNodesSysfsDir = "/sys/class/kfd/kfd/topology/nodes/"
  	// GPUPropertiesFileGlob matches every node's properties file (IDs in it are decimal).
  	GPUPropertiesFileGlob = AMDNodesSysfsDir + "*/properties"
  	// GPUTotalMemoryFileGlob must be prefixed with a node dir; its files carry a size_in_bytes line.
  	GPUTotalMemoryFileGlob = "mem_banks/*/properties"
  )

  // Direct Rendering Manager (DRM) sysfs locations, which reliably report VRAM usage.
  const (
  	DRMDeviceDirGlob   = "/sys/class/drm/card*/device"
  	DRMTotalMemoryFile = "mem_info_vram_total"
  	DRMUsedMemoryFile  = "mem_info_vram_used"
  	// The ID files below are in hex, whereas the KFD properties file reports them in decimal.
  	DRMUniqueIDFile = "unique_id"
  	DRMVendorFile   = "vendor"
  	DRMDeviceFile   = "device"
  )
  // ROCmLibGlobs lists file patterns used to validate if a given ROCm lib directory is usable.
  // TODO - probably include more coverage of files here...
  var ROCmLibGlobs = []string{"libhipblas.so.2*", "rocblas"}

  // RocmStandardLocations lists the standard directories for a ROCm library install.
  var RocmStandardLocations = []string{"/opt/rocm/lib", "/usr/lib64"}
  40. // Gather GPU information from the amdgpu driver if any supported GPUs are detected
  41. // Only called once during bootstrap
  42. func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
  43. resp := []RocmGPUInfo{}
  44. if !AMDDetected() {
  45. return resp, fmt.Errorf("AMD GPUs not detected")
  46. }
  47. // Opportunistic logging of driver version to aid in troubleshooting
  48. driverMajor, driverMinor, err := AMDDriverVersion()
  49. if err != nil {
  50. // TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU
  51. slog.Warn("ollama recommends running the https://www.amd.com/en/support/linux-drivers", "error", err)
  52. }
  53. // Determine if the user has already pre-selected which GPUs to look at, then ignore the others
  54. var visibleDevices []string
  55. hipVD := envconfig.HipVisibleDevices() // zero based index only
  56. rocrVD := envconfig.RocrVisibleDevices() // zero based index or UUID
  57. gpuDO := envconfig.GpuDeviceOrdinal() // zero based index
  58. switch {
  59. case rocrVD != "":
  60. visibleDevices = strings.Split(rocrVD, ",")
  61. case hipVD != "":
  62. visibleDevices = strings.Split(hipVD, ",")
  63. case gpuDO != "":
  64. visibleDevices = strings.Split(gpuDO, ",")
  65. }
  66. gfxOverride := envconfig.HsaOverrideGfxVersion()
  67. var supported []string
  68. depPaths := LibraryDirs()
  69. libDir := ""
  70. // The amdgpu driver always exposes the host CPU(s) first, but we have to skip them and subtract
  71. // from the other IDs to get alignment with the HIP libraries expectations (zero is the first GPU, not the CPU)
  72. matches, _ := filepath.Glob(GPUPropertiesFileGlob)
  73. sort.Slice(matches, func(i, j int) bool {
  74. // /sys/class/kfd/kfd/topology/nodes/<number>/properties
  75. a, err := strconv.ParseInt(filepath.Base(filepath.Dir(matches[i])), 10, 64)
  76. if err != nil {
  77. slog.Debug("parse err", "error", err, "match", matches[i])
  78. return false
  79. }
  80. b, err := strconv.ParseInt(filepath.Base(filepath.Dir(matches[j])), 10, 64)
  81. if err != nil {
  82. slog.Debug("parse err", "error", err, "match", matches[i])
  83. return false
  84. }
  85. return a < b
  86. })
  87. gpuCount := 0
  88. for _, match := range matches {
  89. slog.Debug("evaluating amdgpu node " + match)
  90. fp, err := os.Open(match)
  91. if err != nil {
  92. slog.Debug("failed to open sysfs node", "file", match, "error", err)
  93. continue
  94. }
  95. defer fp.Close()
  96. scanner := bufio.NewScanner(fp)
  97. isCPU := false
  98. var major, minor, patch uint64
  99. var vendor, device, uniqueID uint64
  100. for scanner.Scan() {
  101. line := strings.TrimSpace(scanner.Text())
  102. // Note: we could also use "cpu_cores_count X" where X is greater than zero to detect CPUs
  103. if strings.HasPrefix(line, "gfx_target_version") {
  104. ver := strings.Fields(line)
  105. // Detect CPUs
  106. if len(ver) == 2 && ver[1] == "0" {
  107. slog.Debug("detected CPU " + match)
  108. isCPU = true
  109. break
  110. }
  111. if len(ver) != 2 || len(ver[1]) < 5 {
  112. slog.Warn("malformed "+match, "gfx_target_version", line)
  113. // If this winds up being a CPU, our offsets may be wrong
  114. continue
  115. }
  116. l := len(ver[1])
  117. var err1, err2, err3 error
  118. patch, err1 = strconv.ParseUint(ver[1][l-2:l], 10, 32)
  119. minor, err2 = strconv.ParseUint(ver[1][l-4:l-2], 10, 32)
  120. major, err3 = strconv.ParseUint(ver[1][:l-4], 10, 32)
  121. if err1 != nil || err2 != nil || err3 != nil {
  122. slog.Debug("malformed int " + line)
  123. continue
  124. }
  125. } else if strings.HasPrefix(line, "vendor_id") {
  126. ver := strings.Fields(line)
  127. if len(ver) != 2 {
  128. slog.Debug("malformed", "vendor_id", line)
  129. continue
  130. }
  131. vendor, err = strconv.ParseUint(ver[1], 10, 64)
  132. if err != nil {
  133. slog.Debug("malformed", "vendor_id", line, "error", err)
  134. }
  135. } else if strings.HasPrefix(line, "device_id") {
  136. ver := strings.Fields(line)
  137. if len(ver) != 2 {
  138. slog.Debug("malformed", "device_id", line)
  139. continue
  140. }
  141. device, err = strconv.ParseUint(ver[1], 10, 64)
  142. if err != nil {
  143. slog.Debug("malformed", "device_id", line, "error", err)
  144. }
  145. } else if strings.HasPrefix(line, "unique_id") {
  146. ver := strings.Fields(line)
  147. if len(ver) != 2 {
  148. slog.Debug("malformed", "unique_id", line)
  149. continue
  150. }
  151. uniqueID, err = strconv.ParseUint(ver[1], 10, 64)
  152. if err != nil {
  153. slog.Debug("malformed", "unique_id", line, "error", err)
  154. }
  155. }
  156. // TODO - any other properties we want to extract and record?
  157. // vendor_id + device_id -> pci lookup for "Name"
  158. // Other metrics that may help us understand relative performance between multiple GPUs
  159. }
  160. // Note: while ./mem_banks/*/used_memory exists, it doesn't appear to take other VRAM consumers
  161. // into consideration, so we instead map the device over to the DRM driver sysfs nodes which
  162. // do reliably report VRAM usage.
  163. if isCPU {
  164. continue
  165. }
  166. // Skip over any GPUs that are masked
  167. if major == 0 && minor == 0 && patch == 0 {
  168. slog.Debug("skipping gpu with gfx000")
  169. continue
  170. }
  171. // Keep track of numeric IDs based on valid GPUs
  172. gpuID := gpuCount
  173. gpuCount += 1
  174. // Look up the memory for the current node
  175. totalMemory := uint64(0)
  176. usedMemory := uint64(0)
  177. var usedFile string
  178. mapping := []struct {
  179. id uint64
  180. filename string
  181. }{
  182. {vendor, DRMVendorFile},
  183. {device, DRMDeviceFile},
  184. {uniqueID, DRMUniqueIDFile}, // Not all devices will report this
  185. }
  186. slog.Debug("mapping amdgpu to drm sysfs nodes", "amdgpu", match, "vendor", vendor, "device", device, "unique_id", uniqueID)
  187. // Map over to DRM location to find the total/free memory
  188. drmMatches, _ := filepath.Glob(DRMDeviceDirGlob)
  189. for _, devDir := range drmMatches {
  190. matched := true
  191. for _, m := range mapping {
  192. if m.id == 0 {
  193. // Null ID means it didn't populate, so we can't use it to match
  194. continue
  195. }
  196. filename := filepath.Join(devDir, m.filename)
  197. buf, err := os.ReadFile(filename)
  198. if err != nil {
  199. slog.Debug("failed to read sysfs node", "file", filename, "error", err)
  200. matched = false
  201. break
  202. }
  203. // values here are in hex, strip off the lead 0x and parse so we can compare the numeric (decimal) values in amdgpu
  204. cmp, err := strconv.ParseUint(strings.TrimPrefix(strings.TrimSpace(string(buf)), "0x"), 16, 64)
  205. if err != nil {
  206. slog.Debug("failed to parse sysfs node", "file", filename, "error", err)
  207. matched = false
  208. break
  209. }
  210. if cmp != m.id {
  211. matched = false
  212. break
  213. }
  214. }
  215. if !matched {
  216. continue
  217. }
  218. // Found the matching DRM directory
  219. slog.Debug("matched", "amdgpu", match, "drm", devDir)
  220. totalFile := filepath.Join(devDir, DRMTotalMemoryFile)
  221. buf, err := os.ReadFile(totalFile)
  222. if err != nil {
  223. slog.Debug("failed to read sysfs node", "file", totalFile, "error", err)
  224. break
  225. }
  226. totalMemory, err = strconv.ParseUint(strings.TrimSpace(string(buf)), 10, 64)
  227. if err != nil {
  228. slog.Debug("failed to parse sysfs node", "file", totalFile, "error", err)
  229. break
  230. }
  231. usedFile = filepath.Join(devDir, DRMUsedMemoryFile)
  232. usedMemory, err = getFreeMemory(usedFile)
  233. if err != nil {
  234. slog.Debug("failed to update used memory", "error", err)
  235. }
  236. break
  237. }
  238. var name string
  239. // TODO - PCI ID lookup
  240. if vendor > 0 && device > 0 {
  241. name = fmt.Sprintf("%04x:%04x", vendor, device)
  242. }
  243. // Favor UUIDs if available to reduce possibility of getting the numeric IDs wrong
  244. var ID string
  245. if uniqueID != 0 {
  246. ID = fmt.Sprintf("GPU-%016x", uniqueID)
  247. } else {
  248. ID = strconv.Itoa(gpuID)
  249. }
  250. gpuInfo := RocmGPUInfo{
  251. GpuInfo: GpuInfo{
  252. Library: "rocm",
  253. memInfo: memInfo{
  254. TotalMemory: totalMemory,
  255. FreeMemory: (totalMemory - usedMemory),
  256. },
  257. ID: ID,
  258. Name: name,
  259. Compute: fmt.Sprintf("gfx%d%x%x", major, minor, patch),
  260. MinimumMemory: rocmMinimumMemory,
  261. DriverMajor: driverMajor,
  262. DriverMinor: driverMinor,
  263. },
  264. usedFilepath: usedFile,
  265. index: gpuID,
  266. }
  267. // iGPU detection, remove this check once we can support an iGPU variant of the rocm library
  268. if totalMemory < IGPUMemLimit {
  269. reason := "unsupported Radeon iGPU detected skipping"
  270. slog.Info(reason, "id", gpuID, "total", format.HumanBytes2(totalMemory))
  271. unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{
  272. GpuInfo: gpuInfo.GpuInfo,
  273. Reason: reason,
  274. })
  275. continue
  276. }
  277. minVer, err := strconv.Atoi(RocmComputeMajorMin)
  278. if err != nil {
  279. slog.Error("invalid RocmComputeMajorMin setting", "value", RocmComputeMajorMin, "error", err)
  280. }
  281. if int(major) < minVer {
  282. reason := fmt.Sprintf("amdgpu too old gfx%d%x%x", major, minor, patch)
  283. slog.Warn(reason, "gpu", gpuID)
  284. unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{
  285. GpuInfo: gpuInfo.GpuInfo,
  286. Reason: reason,
  287. })
  288. continue
  289. }
  290. slog.Debug("amdgpu memory", "gpu", gpuID, "total", format.HumanBytes2(totalMemory))
  291. slog.Debug("amdgpu memory", "gpu", gpuID, "available", format.HumanBytes2(totalMemory-usedMemory))
  292. // If the user wants to filter to a subset of devices, filter out if we aren't a match
  293. if len(visibleDevices) > 0 {
  294. include := false
  295. for _, visible := range visibleDevices {
  296. if visible == gpuInfo.ID || visible == strconv.Itoa(gpuInfo.index) {
  297. include = true
  298. break
  299. }
  300. }
  301. if !include {
  302. reason := "filtering out device per user request"
  303. slog.Info(reason, "id", gpuInfo.ID, "visible_devices", visibleDevices)
  304. unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{
  305. GpuInfo: gpuInfo.GpuInfo,
  306. Reason: reason,
  307. })
  308. continue
  309. }
  310. }
  311. // Final validation is gfx compatibility - load the library if we haven't already loaded it
  312. // even if the user overrides, we still need to validate the library
  313. if libDir == "" {
  314. libDir, err = AMDValidateLibDir()
  315. if err != nil {
  316. err = fmt.Errorf("unable to verify rocm library: %w", err)
  317. slog.Warn(err.Error())
  318. unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{
  319. GpuInfo: gpuInfo.GpuInfo,
  320. Reason: err.Error(),
  321. })
  322. return nil, err
  323. }
  324. depPaths = append(depPaths, libDir)
  325. }
  326. gpuInfo.DependencyPath = depPaths
  327. if gfxOverride == "" {
  328. // Only load supported list once
  329. if len(supported) == 0 {
  330. supported, err = GetSupportedGFX(libDir)
  331. if err != nil {
  332. err = fmt.Errorf("failed to lookup supported GFX types: %w", err)
  333. slog.Warn(err.Error())
  334. unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{
  335. GpuInfo: gpuInfo.GpuInfo,
  336. Reason: err.Error(),
  337. })
  338. return nil, err
  339. }
  340. slog.Debug("rocm supported GPUs", "types", supported)
  341. }
  342. gfx := gpuInfo.Compute
  343. if !slices.Contains[[]string, string](supported, gfx) {
  344. reason := fmt.Sprintf("amdgpu is not supported (supported types:%s)", supported)
  345. slog.Warn(reason, "gpu_type", gfx, "gpu", gpuInfo.ID, "library", libDir)
  346. unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{
  347. GpuInfo: gpuInfo.GpuInfo,
  348. Reason: reason,
  349. })
  350. // TODO - consider discrete markdown just for ROCM troubleshooting?
  351. slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/gpu.md#overrides for HSA_OVERRIDE_GFX_VERSION usage")
  352. continue
  353. } else {
  354. slog.Info("amdgpu is supported", "gpu", gpuInfo.ID, "gpu_type", gfx)
  355. }
  356. } else {
  357. slog.Info("skipping rocm gfx compatibility check", "HSA_OVERRIDE_GFX_VERSION", gfxOverride)
  358. }
  359. // Check for env var workarounds
  360. if name == "1002:687f" { // Vega RX 56
  361. gpuInfo.EnvWorkarounds = append(gpuInfo.EnvWorkarounds, [2]string{"HSA_ENABLE_SDMA", "0"})
  362. }
  363. // The GPU has passed all the verification steps and is supported
  364. resp = append(resp, gpuInfo)
  365. }
  366. if len(resp) == 0 {
  367. err := fmt.Errorf("no compatible amdgpu devices detected")
  368. slog.Info(err.Error())
  369. return nil, err
  370. }
  371. if err := verifyKFDDriverAccess(); err != nil {
  372. err = fmt.Errorf("amdgpu devices detected but permission problems block access: %w", err)
  373. slog.Error(err.Error())
  374. return nil, err
  375. }
  376. return resp, nil
  377. }
  378. // Quick check for AMD driver so we can skip amdgpu discovery if not present
  379. func AMDDetected() bool {
  380. // Some driver versions (older?) don't have a version file, so just lookup the parent dir
  381. sysfsDir := filepath.Dir(DriverVersionFile)
  382. _, err := os.Stat(sysfsDir)
  383. if errors.Is(err, os.ErrNotExist) {
  384. slog.Debug("amdgpu driver not detected " + sysfsDir)
  385. return false
  386. } else if err != nil {
  387. slog.Debug("error looking up amd driver", "path", sysfsDir, "error", err)
  388. return false
  389. }
  390. return true
  391. }
  392. // Prefer to use host installed ROCm, as long as it meets our minimum requirements
  393. // failing that, tell the user how to download it on their own
  394. func AMDValidateLibDir() (string, error) {
  395. libDir, err := commonAMDValidateLibDir()
  396. if err == nil {
  397. return libDir, nil
  398. }
  399. // Well known ollama installer path
  400. installedRocmDir := "/usr/share/ollama/lib/rocm"
  401. if rocmLibUsable(installedRocmDir) {
  402. return installedRocmDir, nil
  403. }
  404. // If we still haven't found a usable rocm, the user will have to install it on their own
  405. slog.Warn("amdgpu detected, but no compatible rocm library found. Either install rocm v6, or follow manual install instructions at https://github.com/ollama/ollama/blob/main/docs/linux.md#manual-install")
  406. return "", errors.New("no suitable rocm found, falling back to CPU")
  407. }
  408. func AMDDriverVersion() (driverMajor, driverMinor int, err error) {
  409. _, err = os.Stat(DriverVersionFile)
  410. if err != nil {
  411. return 0, 0, fmt.Errorf("amdgpu version file missing: %s %w", DriverVersionFile, err)
  412. }
  413. fp, err := os.Open(DriverVersionFile)
  414. if err != nil {
  415. return 0, 0, err
  416. }
  417. defer fp.Close()
  418. verString, err := io.ReadAll(fp)
  419. if err != nil {
  420. return 0, 0, err
  421. }
  422. pattern := `\A(\d+)\.(\d+).*`
  423. regex := regexp.MustCompile(pattern)
  424. match := regex.FindStringSubmatch(string(verString))
  425. if len(match) < 2 {
  426. return 0, 0, fmt.Errorf("malformed version string %s", string(verString))
  427. }
  428. driverMajor, err = strconv.Atoi(match[1])
  429. if err != nil {
  430. return 0, 0, err
  431. }
  432. driverMinor, err = strconv.Atoi(match[2])
  433. if err != nil {
  434. return 0, 0, err
  435. }
  436. return driverMajor, driverMinor, nil
  437. }
  438. func (gpus RocmGPUInfoList) RefreshFreeMemory() error {
  439. if len(gpus) == 0 {
  440. return nil
  441. }
  442. for i := range gpus {
  443. usedMemory, err := getFreeMemory(gpus[i].usedFilepath)
  444. if err != nil {
  445. return err
  446. }
  447. slog.Debug("updating rocm free memory", "gpu", gpus[i].ID, "name", gpus[i].Name, "before", format.HumanBytes2(gpus[i].FreeMemory), "now", format.HumanBytes2(gpus[i].TotalMemory-usedMemory))
  448. gpus[i].FreeMemory = gpus[i].TotalMemory - usedMemory
  449. }
  450. return nil
  451. }
  // getFreeMemory reads the sysfs file at usedFile and returns its trimmed contents parsed
  // as a decimal uint64. Despite the name, callers in this file pass the DRM
  // mem_info_vram_used file and subtract the result from the total, so the value is
  // *used* bytes. On a read or parse failure it returns 0 and an error naming the file.
  func getFreeMemory(usedFile string) (uint64, error) {
  	raw, readErr := os.ReadFile(usedFile)
  	if readErr != nil {
  		return 0, fmt.Errorf("failed to read sysfs node %s %w", usedFile, readErr)
  	}
  	value, parseErr := strconv.ParseUint(strings.TrimSpace(string(raw)), 10, 64)
  	if parseErr == nil {
  		return value, nil
  	}
  	slog.Debug("failed to parse sysfs node", "file", usedFile, "error", parseErr)
  	return 0, fmt.Errorf("failed to parse sysfs node %s %w", usedFile, parseErr)
  }
  // verifyKFDDriverAccess checks that this process can open /dev/kfd read-write, i.e. it
  // is either running as root or has group access to the driver. It returns nil on
  // success; otherwise it returns a user-facing error for one of three cases:
  //   - a permission problem;
  //   - a missing device node, e.g. a container started without the device mapped;
  //   - any other open failure.
  func verifyKFDDriverAccess() error {
  	fd, err := os.OpenFile("/dev/kfd", os.O_RDWR, 0o666)
  	switch {
  	case err == nil:
  		fd.Close()
  		return nil
  	case errors.Is(err, fs.ErrPermission):
  		return fmt.Errorf("permissions not set up properly. Either run ollama as root, or add your user account to the render group. %w", err)
  	case errors.Is(err, fs.ErrNotExist):
  		// Container runtime failure?
  		return errors.New("kfd driver not loaded. If running in a container, remember to include '--device /dev/kfd --device /dev/dri'")
  	default:
  		return fmt.Errorf("failed to check permission on /dev/kfd: %w", err)
  	}
  }
  479. func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
  480. ids := []string{}
  481. for _, info := range gpuInfo {
  482. if info.Library != "rocm" {
  483. // TODO shouldn't happen if things are wired correctly...
  484. slog.Debug("rocmGetVisibleDevicesEnv skipping over non-rocm device", "library", info.Library)
  485. continue
  486. }
  487. ids = append(ids, info.ID)
  488. }
  489. // There are 3 potential env vars to use to select GPUs.
  490. // ROCR_VISIBLE_DEVICES supports UUID or numeric so is our preferred on linux
  491. // GPU_DEVICE_ORDINAL supports numeric IDs only
  492. // HIP_VISIBLE_DEVICES supports numeric IDs only
  493. return "ROCR_VISIBLE_DEVICES", strings.Join(ids, ",")
  494. }