amd_linux.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478
  1. package gpu
  2. import (
  3. "bufio"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "io/fs"
  8. "log/slog"
  9. "os"
  10. "path/filepath"
  11. "regexp"
  12. "slices"
  13. "sort"
  14. "strconv"
  15. "strings"
  16. "github.com/ollama/ollama/envconfig"
  17. "github.com/ollama/ollama/format"
  18. )
  19. // Discovery logic for AMD/ROCm GPUs
const (
	// DriverVersionFile is the sysfs file holding the amdgpu kernel driver version.
	// AMDDriverVersion parses it, and AMDDetected checks for its parent directory.
	DriverVersionFile = "/sys/module/amdgpu/version"

	// AMDNodesSysfsDir is the KFD topology directory; it contains one numbered
	// sub-directory per node (<number>/properties).
	AMDNodesSysfsDir = "/sys/class/kfd/kfd/topology/nodes/"

	// GPUPropertiesFileGlob matches the properties file of every topology node.
	GPUPropertiesFileGlob = AMDNodesSysfsDir + "*/properties"

	// Prefix with the node dir
	GPUTotalMemoryFileGlob = "mem_banks/*/properties" // size_in_bytes line

	// Direct Rendering Manager sysfs location
	DRMDeviceDirGlob = "/sys/class/drm/card*/device"

	// DRMTotalMemoryFile is the file, relative to a DRM device directory, that is
	// read as a decimal byte count for the device's total VRAM.
	DRMTotalMemoryFile = "mem_info_vram_total"

	// DRMUsedMemoryFile is the file, relative to a DRM device directory, that is
	// read as a decimal byte count for the VRAM currently in use.
	DRMUsedMemoryFile = "mem_info_vram_used"

	// The three files below are relative to a DRM device directory and are used
	// to match a KFD node to its DRM device.
	// In hex; properties file is in decimal
	DRMUniqueIDFile = "unique_id"

	DRMVendorFile = "vendor"

	DRMDeviceFile = "device"
)
var (
	// Used to validate if the given ROCm lib is usable
	ROCmLibGlobs = []string{"libhipblas.so.2*", "rocblas"} // TODO - probably include more coverage of files here...

	// RocmStandardLocations lists directories where a host ROCm install is
	// commonly found.
	// NOTE(review): not referenced in this part of the file; presumably consumed
	// by the shared lib-dir validation helper - confirm.
	RocmStandardLocations = []string{"/opt/rocm/lib", "/usr/lib64"}
)
  40. // Gather GPU information from the amdgpu driver if any supported GPUs are detected
  41. func AMDGetGPUInfo() []RocmGPUInfo {
  42. resp := []RocmGPUInfo{}
  43. if !AMDDetected() {
  44. return resp
  45. }
  46. // Opportunistic logging of driver version to aid in troubleshooting
  47. driverMajor, driverMinor, err := AMDDriverVersion()
  48. if err != nil {
  49. // TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU
  50. slog.Warn("ollama recommends running the https://www.amd.com/en/support/linux-drivers", "error", err)
  51. }
  52. // Determine if the user has already pre-selected which GPUs to look at, then ignore the others
  53. var visibleDevices []string
  54. hipVD := envconfig.HipVisibleDevices() // zero based index only
  55. rocrVD := envconfig.RocrVisibleDevices() // zero based index or UUID, but consumer cards seem to not support UUID
  56. gpuDO := envconfig.GpuDeviceOrdinal() // zero based index
  57. switch {
  58. // TODO is this priorty order right?
  59. case hipVD != "":
  60. visibleDevices = strings.Split(hipVD, ",")
  61. case rocrVD != "":
  62. visibleDevices = strings.Split(rocrVD, ",")
  63. // TODO - since we don't yet support UUIDs, consider detecting and reporting here
  64. // all our test systems show GPU-XX indicating UUID is not supported
  65. case gpuDO != "":
  66. visibleDevices = strings.Split(gpuDO, ",")
  67. }
  68. gfxOverride := envconfig.HsaOverrideGfxVersion()
  69. var supported []string
  70. libDir := ""
  71. // The amdgpu driver always exposes the host CPU(s) first, but we have to skip them and subtract
  72. // from the other IDs to get alignment with the HIP libraries expectations (zero is the first GPU, not the CPU)
  73. matches, _ := filepath.Glob(GPUPropertiesFileGlob)
  74. sort.Slice(matches, func(i, j int) bool {
  75. // /sys/class/kfd/kfd/topology/nodes/<number>/properties
  76. a, err := strconv.ParseInt(filepath.Base(filepath.Dir(matches[i])), 10, 64)
  77. if err != nil {
  78. slog.Debug("parse err", "error", err, "match", matches[i])
  79. return false
  80. }
  81. b, err := strconv.ParseInt(filepath.Base(filepath.Dir(matches[j])), 10, 64)
  82. if err != nil {
  83. slog.Debug("parse err", "error", err, "match", matches[i])
  84. return false
  85. }
  86. return a < b
  87. })
  88. cpuCount := 0
  89. for _, match := range matches {
  90. slog.Debug("evaluating amdgpu node " + match)
  91. fp, err := os.Open(match)
  92. if err != nil {
  93. slog.Debug("failed to open sysfs node", "file", match, "error", err)
  94. continue
  95. }
  96. defer fp.Close()
  97. nodeID, err := strconv.Atoi(filepath.Base(filepath.Dir(match)))
  98. if err != nil {
  99. slog.Debug("failed to parse node ID", "error", err)
  100. continue
  101. }
  102. scanner := bufio.NewScanner(fp)
  103. isCPU := false
  104. var major, minor, patch uint64
  105. var vendor, device, uniqueID uint64
  106. for scanner.Scan() {
  107. line := strings.TrimSpace(scanner.Text())
  108. // Note: we could also use "cpu_cores_count X" where X is greater than zero to detect CPUs
  109. if strings.HasPrefix(line, "gfx_target_version") {
  110. ver := strings.Fields(line)
  111. // Detect CPUs
  112. if len(ver) == 2 && ver[1] == "0" {
  113. slog.Debug("detected CPU " + match)
  114. isCPU = true
  115. break
  116. }
  117. if len(ver) != 2 || len(ver[1]) < 5 {
  118. slog.Warn("malformed "+match, "gfx_target_version", line)
  119. // If this winds up being a CPU, our offsets may be wrong
  120. continue
  121. }
  122. l := len(ver[1])
  123. var err1, err2, err3 error
  124. patch, err1 = strconv.ParseUint(ver[1][l-2:l], 10, 32)
  125. minor, err2 = strconv.ParseUint(ver[1][l-4:l-2], 10, 32)
  126. major, err3 = strconv.ParseUint(ver[1][:l-4], 10, 32)
  127. if err1 != nil || err2 != nil || err3 != nil {
  128. slog.Debug("malformed int " + line)
  129. continue
  130. }
  131. } else if strings.HasPrefix(line, "vendor_id") {
  132. ver := strings.Fields(line)
  133. if len(ver) != 2 {
  134. slog.Debug("malformed", "vendor_id", line)
  135. continue
  136. }
  137. vendor, err = strconv.ParseUint(ver[1], 10, 64)
  138. if err != nil {
  139. slog.Debug("malformed", "vendor_id", line, "error", err)
  140. }
  141. } else if strings.HasPrefix(line, "device_id") {
  142. ver := strings.Fields(line)
  143. if len(ver) != 2 {
  144. slog.Debug("malformed", "device_id", line)
  145. continue
  146. }
  147. device, err = strconv.ParseUint(ver[1], 10, 64)
  148. if err != nil {
  149. slog.Debug("malformed", "device_id", line, "error", err)
  150. }
  151. } else if strings.HasPrefix(line, "unique_id") {
  152. ver := strings.Fields(line)
  153. if len(ver) != 2 {
  154. slog.Debug("malformed", "unique_id", line)
  155. continue
  156. }
  157. uniqueID, err = strconv.ParseUint(ver[1], 10, 64)
  158. if err != nil {
  159. slog.Debug("malformed", "unique_id", line, "error", err)
  160. }
  161. }
  162. // TODO - any other properties we want to extract and record?
  163. // vendor_id + device_id -> pci lookup for "Name"
  164. // Other metrics that may help us understand relative performance between multiple GPUs
  165. }
  166. // Note: while ./mem_banks/*/used_memory exists, it doesn't appear to take other VRAM consumers
  167. // into consideration, so we instead map the device over to the DRM driver sysfs nodes which
  168. // do reliably report VRAM usage.
  169. if isCPU {
  170. cpuCount++
  171. continue
  172. }
  173. // CPUs are always first in the list
  174. gpuID := nodeID - cpuCount
  175. // Shouldn't happen, but just in case...
  176. if gpuID < 0 {
  177. slog.Error("unexpected amdgpu sysfs data resulted in negative GPU ID, please set OLLAMA_DEBUG=1 and report an issue")
  178. return nil
  179. }
  180. if int(major) < RocmComputeMin {
  181. slog.Warn(fmt.Sprintf("amdgpu too old gfx%d%x%x", major, minor, patch), "gpu", gpuID)
  182. continue
  183. }
  184. // Look up the memory for the current node
  185. totalMemory := uint64(0)
  186. usedMemory := uint64(0)
  187. var usedFile string
  188. mapping := []struct {
  189. id uint64
  190. filename string
  191. }{
  192. {vendor, DRMVendorFile},
  193. {device, DRMDeviceFile},
  194. {uniqueID, DRMUniqueIDFile}, // Not all devices will report this
  195. }
  196. slog.Debug("mapping amdgpu to drm sysfs nodes", "amdgpu", match, "vendor", vendor, "device", device, "unique_id", uniqueID)
  197. // Map over to DRM location to find the total/free memory
  198. drmMatches, _ := filepath.Glob(DRMDeviceDirGlob)
  199. for _, devDir := range drmMatches {
  200. matched := true
  201. for _, m := range mapping {
  202. if m.id == 0 {
  203. // Null ID means it didn't populate, so we can't use it to match
  204. continue
  205. }
  206. filename := filepath.Join(devDir, m.filename)
  207. buf, err := os.ReadFile(filename)
  208. if err != nil {
  209. slog.Debug("failed to read sysfs node", "file", filename, "error", err)
  210. matched = false
  211. break
  212. }
  213. // values here are in hex, strip off the lead 0x and parse so we can compare the numeric (decimal) values in amdgpu
  214. cmp, err := strconv.ParseUint(strings.TrimPrefix(strings.TrimSpace(string(buf)), "0x"), 16, 64)
  215. if err != nil {
  216. slog.Debug("failed to parse sysfs node", "file", filename, "error", err)
  217. matched = false
  218. break
  219. }
  220. if cmp != m.id {
  221. matched = false
  222. break
  223. }
  224. }
  225. if !matched {
  226. continue
  227. }
  228. // Found the matching DRM directory
  229. slog.Debug("matched", "amdgpu", match, "drm", devDir)
  230. totalFile := filepath.Join(devDir, DRMTotalMemoryFile)
  231. buf, err := os.ReadFile(totalFile)
  232. if err != nil {
  233. slog.Debug("failed to read sysfs node", "file", totalFile, "error", err)
  234. break
  235. }
  236. totalMemory, err = strconv.ParseUint(strings.TrimSpace(string(buf)), 10, 64)
  237. if err != nil {
  238. slog.Debug("failed to parse sysfs node", "file", totalFile, "error", err)
  239. break
  240. }
  241. usedFile = filepath.Join(devDir, DRMUsedMemoryFile)
  242. usedMemory, err = getFreeMemory(usedFile)
  243. if err != nil {
  244. slog.Debug("failed to update used memory", "error", err)
  245. }
  246. break
  247. }
  248. // iGPU detection, remove this check once we can support an iGPU variant of the rocm library
  249. if totalMemory < IGPUMemLimit {
  250. slog.Info("unsupported Radeon iGPU detected skipping", "id", gpuID, "total", format.HumanBytes2(totalMemory))
  251. continue
  252. }
  253. var name string
  254. // TODO - PCI ID lookup
  255. if vendor > 0 && device > 0 {
  256. name = fmt.Sprintf("%04x:%04x", vendor, device)
  257. }
  258. slog.Debug("amdgpu memory", "gpu", gpuID, "total", format.HumanBytes2(totalMemory))
  259. slog.Debug("amdgpu memory", "gpu", gpuID, "available", format.HumanBytes2(totalMemory-usedMemory))
  260. gpuInfo := RocmGPUInfo{
  261. GpuInfo: GpuInfo{
  262. Library: "rocm",
  263. memInfo: memInfo{
  264. TotalMemory: totalMemory,
  265. FreeMemory: (totalMemory - usedMemory),
  266. },
  267. ID: strconv.Itoa(gpuID),
  268. Name: name,
  269. Compute: fmt.Sprintf("gfx%d%x%x", major, minor, patch),
  270. MinimumMemory: rocmMinimumMemory,
  271. DriverMajor: driverMajor,
  272. DriverMinor: driverMinor,
  273. },
  274. usedFilepath: usedFile,
  275. }
  276. // If the user wants to filter to a subset of devices, filter out if we aren't a match
  277. if len(visibleDevices) > 0 {
  278. include := false
  279. for _, visible := range visibleDevices {
  280. if visible == gpuInfo.ID {
  281. include = true
  282. break
  283. }
  284. }
  285. if !include {
  286. slog.Info("filtering out device per user request", "id", gpuInfo.ID, "visible_devices", visibleDevices)
  287. continue
  288. }
  289. }
  290. // Final validation is gfx compatibility - load the library if we haven't already loaded it
  291. // even if the user overrides, we still need to validate the library
  292. if libDir == "" {
  293. libDir, err = AMDValidateLibDir()
  294. if err != nil {
  295. slog.Warn("unable to verify rocm library, will use cpu", "error", err)
  296. return nil
  297. }
  298. }
  299. gpuInfo.DependencyPath = libDir
  300. if gfxOverride == "" {
  301. // Only load supported list once
  302. if len(supported) == 0 {
  303. supported, err = GetSupportedGFX(libDir)
  304. if err != nil {
  305. slog.Warn("failed to lookup supported GFX types, falling back to CPU mode", "error", err)
  306. return nil
  307. }
  308. slog.Debug("rocm supported GPUs", "types", supported)
  309. }
  310. gfx := gpuInfo.Compute
  311. if !slices.Contains[[]string, string](supported, gfx) {
  312. slog.Warn("amdgpu is not supported", "gpu", gpuInfo.ID, "gpu_type", gfx, "library", libDir, "supported_types", supported)
  313. // TODO - consider discrete markdown just for ROCM troubleshooting?
  314. slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/gpu.md#overrides for HSA_OVERRIDE_GFX_VERSION usage")
  315. continue
  316. } else {
  317. slog.Info("amdgpu is supported", "gpu", gpuInfo.ID, "gpu_type", gfx)
  318. }
  319. } else {
  320. slog.Info("skipping rocm gfx compatibility check", "HSA_OVERRIDE_GFX_VERSION", gfxOverride)
  321. }
  322. // Check for env var workarounds
  323. if name == "1002:687f" { // Vega RX 56
  324. gpuInfo.EnvWorkarounds = append(gpuInfo.EnvWorkarounds, [2]string{"HSA_ENABLE_SDMA", "0"})
  325. }
  326. // The GPU has passed all the verification steps and is supported
  327. resp = append(resp, gpuInfo)
  328. }
  329. if len(resp) == 0 {
  330. slog.Info("no compatible amdgpu devices detected")
  331. }
  332. if err := verifyKFDDriverAccess(); err != nil {
  333. slog.Error("amdgpu devices detected but permission problems block access", "error", err)
  334. return nil
  335. }
  336. return resp
  337. }
  338. // Quick check for AMD driver so we can skip amdgpu discovery if not present
  339. func AMDDetected() bool {
  340. // Some driver versions (older?) don't have a version file, so just lookup the parent dir
  341. sysfsDir := filepath.Dir(DriverVersionFile)
  342. _, err := os.Stat(sysfsDir)
  343. if errors.Is(err, os.ErrNotExist) {
  344. slog.Debug("amdgpu driver not detected " + sysfsDir)
  345. return false
  346. } else if err != nil {
  347. slog.Debug("error looking up amd driver", "path", sysfsDir, "error", err)
  348. return false
  349. }
  350. return true
  351. }
  352. // Prefer to use host installed ROCm, as long as it meets our minimum requirements
  353. // failing that, tell the user how to download it on their own
  354. func AMDValidateLibDir() (string, error) {
  355. libDir, err := commonAMDValidateLibDir()
  356. if err == nil {
  357. return libDir, nil
  358. }
  359. // Well known ollama installer path
  360. installedRocmDir := "/usr/share/ollama/lib/rocm"
  361. if rocmLibUsable(installedRocmDir) {
  362. return installedRocmDir, nil
  363. }
  364. // If we still haven't found a usable rocm, the user will have to install it on their own
  365. slog.Warn("amdgpu detected, but no compatible rocm library found. Either install rocm v6, or follow manual install instructions at https://github.com/ollama/ollama/blob/main/docs/linux.md#manual-install")
  366. return "", errors.New("no suitable rocm found, falling back to CPU")
  367. }
  368. func AMDDriverVersion() (driverMajor, driverMinor int, err error) {
  369. _, err = os.Stat(DriverVersionFile)
  370. if err != nil {
  371. return 0, 0, fmt.Errorf("amdgpu version file missing: %s %w", DriverVersionFile, err)
  372. }
  373. fp, err := os.Open(DriverVersionFile)
  374. if err != nil {
  375. return 0, 0, err
  376. }
  377. defer fp.Close()
  378. verString, err := io.ReadAll(fp)
  379. if err != nil {
  380. return 0, 0, err
  381. }
  382. pattern := `\A(\d+)\.(\d+).*`
  383. regex := regexp.MustCompile(pattern)
  384. match := regex.FindStringSubmatch(string(verString))
  385. if len(match) < 2 {
  386. return 0, 0, fmt.Errorf("malformed version string %s", string(verString))
  387. }
  388. driverMajor, err = strconv.Atoi(match[1])
  389. if err != nil {
  390. return 0, 0, err
  391. }
  392. driverMinor, err = strconv.Atoi(match[2])
  393. if err != nil {
  394. return 0, 0, err
  395. }
  396. return driverMajor, driverMinor, nil
  397. }
  398. func (gpus RocmGPUInfoList) RefreshFreeMemory() error {
  399. if len(gpus) == 0 {
  400. return nil
  401. }
  402. for i := range gpus {
  403. usedMemory, err := getFreeMemory(gpus[i].usedFilepath)
  404. if err != nil {
  405. return err
  406. }
  407. slog.Debug("updating rocm free memory", "gpu", gpus[i].ID, "name", gpus[i].Name, "before", format.HumanBytes2(gpus[i].FreeMemory), "now", format.HumanBytes2(gpus[i].TotalMemory-usedMemory))
  408. gpus[i].FreeMemory = gpus[i].TotalMemory - usedMemory
  409. }
  410. return nil
  411. }
  412. func getFreeMemory(usedFile string) (uint64, error) {
  413. buf, err := os.ReadFile(usedFile)
  414. if err != nil {
  415. return 0, fmt.Errorf("failed to read sysfs node %s %w", usedFile, err)
  416. }
  417. usedMemory, err := strconv.ParseUint(strings.TrimSpace(string(buf)), 10, 64)
  418. if err != nil {
  419. slog.Debug("failed to parse sysfs node", "file", usedFile, "error", err)
  420. return 0, fmt.Errorf("failed to parse sysfs node %s %w", usedFile, err)
  421. }
  422. return usedMemory, nil
  423. }
  424. func verifyKFDDriverAccess() error {
  425. // Verify we have permissions - either running as root, or we have group access to the driver
  426. fd, err := os.OpenFile("/dev/kfd", os.O_RDWR, 0o666)
  427. if err != nil {
  428. if errors.Is(err, fs.ErrPermission) {
  429. return fmt.Errorf("permissions not set up properly. Either run ollama as root, or add you user account to the render group. %w", err)
  430. } else if errors.Is(err, fs.ErrNotExist) {
  431. // Container runtime failure?
  432. return fmt.Errorf("kfd driver not loaded. If running in a container, remember to include '--device /dev/kfd --device /dev/dri'")
  433. }
  434. return fmt.Errorf("failed to check permission on /dev/kfd: %w", err)
  435. }
  436. fd.Close()
  437. return nil
  438. }