amd_linux.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457
  1. package gpu
  2. import (
  3. "bufio"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "log/slog"
  8. "os"
  9. "path/filepath"
  10. "regexp"
  11. "slices"
  12. "sort"
  13. "strconv"
  14. "strings"
  15. "github.com/ollama/ollama/envconfig"
  16. "github.com/ollama/ollama/format"
  17. )
  // Discovery logic for AMD/ROCm GPUs.
//
// amdgpu kernel module and KFD topology sysfs locations. Each KFD node
// directory exposes a "properties" file whose values are decimal.
const (
	DriverVersionFile     = "/sys/module/amdgpu/version"
	AMDNodesSysfsDir      = "/sys/class/kfd/kfd/topology/nodes/"
	GPUPropertiesFileGlob = AMDNodesSysfsDir + "*/properties"

	// GPUTotalMemoryFileGlob is relative to a KFD node directory; the matched
	// files carry a size_in_bytes line.
	GPUTotalMemoryFileGlob = "mem_banks/*/properties"
)

// Direct Rendering Manager (DRM) sysfs locations. The vendor/device/unique_id
// files hold hex values, whereas the KFD properties file reports the same IDs
// in decimal.
const (
	DRMDeviceDirGlob   = "/sys/class/drm/card*/device"
	DRMTotalMemoryFile = "mem_info_vram_total"
	DRMUsedMemoryFile  = "mem_info_vram_used"

	DRMUniqueIDFile = "unique_id"
	DRMVendorFile   = "vendor"
	DRMDeviceFile   = "device"
)
  // ROCmLibGlobs lists file-name patterns used to decide whether a candidate
// ROCm library directory is usable.
// TODO - probably include more coverage of files here...
var ROCmLibGlobs = []string{"libhipblas.so.2*", "rocblas"}

// RocmStandardLocations are well-known host directories where a system ROCm
// installation may live (consumed outside this section).
var RocmStandardLocations = []string{"/opt/rocm/lib", "/usr/lib64"}
  39. // Gather GPU information from the amdgpu driver if any supported GPUs are detected
  40. func AMDGetGPUInfo() []RocmGPUInfo {
  41. resp := []RocmGPUInfo{}
  42. if !AMDDetected() {
  43. return resp
  44. }
  45. // Opportunistic logging of driver version to aid in troubleshooting
  46. driverMajor, driverMinor, err := AMDDriverVersion()
  47. if err != nil {
  48. // TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU
  49. slog.Warn("ollama recommends running the https://www.amd.com/en/support/linux-drivers", "error", err)
  50. }
  51. // Determine if the user has already pre-selected which GPUs to look at, then ignore the others
  52. var visibleDevices []string
  53. hipVD := envconfig.HipVisibleDevices() // zero based index only
  54. rocrVD := envconfig.RocrVisibleDevices() // zero based index or UUID, but consumer cards seem to not support UUID
  55. gpuDO := envconfig.GpuDeviceOrdinal() // zero based index
  56. switch {
  57. // TODO is this priorty order right?
  58. case hipVD != "":
  59. visibleDevices = strings.Split(hipVD, ",")
  60. case rocrVD != "":
  61. visibleDevices = strings.Split(rocrVD, ",")
  62. // TODO - since we don't yet support UUIDs, consider detecting and reporting here
  63. // all our test systems show GPU-XX indicating UUID is not supported
  64. case gpuDO != "":
  65. visibleDevices = strings.Split(gpuDO, ",")
  66. }
  67. gfxOverride := envconfig.HsaOverrideGfxVersion()
  68. var supported []string
  69. libDir := ""
  70. // The amdgpu driver always exposes the host CPU(s) first, but we have to skip them and subtract
  71. // from the other IDs to get alignment with the HIP libraries expectations (zero is the first GPU, not the CPU)
  72. matches, _ := filepath.Glob(GPUPropertiesFileGlob)
  73. sort.Slice(matches, func(i, j int) bool {
  74. // /sys/class/kfd/kfd/topology/nodes/<number>/properties
  75. a, err := strconv.ParseInt(filepath.Base(filepath.Dir(matches[i])), 10, 64)
  76. if err != nil {
  77. slog.Debug("parse err", "error", err, "match", matches[i])
  78. return false
  79. }
  80. b, err := strconv.ParseInt(filepath.Base(filepath.Dir(matches[j])), 10, 64)
  81. if err != nil {
  82. slog.Debug("parse err", "error", err, "match", matches[i])
  83. return false
  84. }
  85. return a < b
  86. })
  87. cpuCount := 0
  88. for _, match := range matches {
  89. slog.Debug("evaluating amdgpu node " + match)
  90. fp, err := os.Open(match)
  91. if err != nil {
  92. slog.Debug("failed to open sysfs node", "file", match, "error", err)
  93. continue
  94. }
  95. defer fp.Close()
  96. nodeID, err := strconv.Atoi(filepath.Base(filepath.Dir(match)))
  97. if err != nil {
  98. slog.Debug("failed to parse node ID", "error", err)
  99. continue
  100. }
  101. scanner := bufio.NewScanner(fp)
  102. isCPU := false
  103. var major, minor, patch uint64
  104. var vendor, device, uniqueID uint64
  105. for scanner.Scan() {
  106. line := strings.TrimSpace(scanner.Text())
  107. // Note: we could also use "cpu_cores_count X" where X is greater than zero to detect CPUs
  108. if strings.HasPrefix(line, "gfx_target_version") {
  109. ver := strings.Fields(line)
  110. // Detect CPUs
  111. if len(ver) == 2 && ver[1] == "0" {
  112. slog.Debug("detected CPU " + match)
  113. isCPU = true
  114. break
  115. }
  116. if len(ver) != 2 || len(ver[1]) < 5 {
  117. slog.Warn("malformed "+match, "gfx_target_version", line)
  118. // If this winds up being a CPU, our offsets may be wrong
  119. continue
  120. }
  121. l := len(ver[1])
  122. var err1, err2, err3 error
  123. patch, err1 = strconv.ParseUint(ver[1][l-2:l], 10, 32)
  124. minor, err2 = strconv.ParseUint(ver[1][l-4:l-2], 10, 32)
  125. major, err3 = strconv.ParseUint(ver[1][:l-4], 10, 32)
  126. if err1 != nil || err2 != nil || err3 != nil {
  127. slog.Debug("malformed int " + line)
  128. continue
  129. }
  130. } else if strings.HasPrefix(line, "vendor_id") {
  131. ver := strings.Fields(line)
  132. if len(ver) != 2 {
  133. slog.Debug("malformed", "vendor_id", line)
  134. continue
  135. }
  136. vendor, err = strconv.ParseUint(ver[1], 10, 64)
  137. if err != nil {
  138. slog.Debug("malformed", "vendor_id", line, "error", err)
  139. }
  140. } else if strings.HasPrefix(line, "device_id") {
  141. ver := strings.Fields(line)
  142. if len(ver) != 2 {
  143. slog.Debug("malformed", "device_id", line)
  144. continue
  145. }
  146. device, err = strconv.ParseUint(ver[1], 10, 64)
  147. if err != nil {
  148. slog.Debug("malformed", "device_id", line, "error", err)
  149. }
  150. } else if strings.HasPrefix(line, "unique_id") {
  151. ver := strings.Fields(line)
  152. if len(ver) != 2 {
  153. slog.Debug("malformed", "unique_id", line)
  154. continue
  155. }
  156. uniqueID, err = strconv.ParseUint(ver[1], 10, 64)
  157. if err != nil {
  158. slog.Debug("malformed", "unique_id", line, "error", err)
  159. }
  160. }
  161. // TODO - any other properties we want to extract and record?
  162. // vendor_id + device_id -> pci lookup for "Name"
  163. // Other metrics that may help us understand relative performance between multiple GPUs
  164. }
  165. // Note: while ./mem_banks/*/used_memory exists, it doesn't appear to take other VRAM consumers
  166. // into consideration, so we instead map the device over to the DRM driver sysfs nodes which
  167. // do reliably report VRAM usage.
  168. if isCPU {
  169. cpuCount++
  170. continue
  171. }
  172. // CPUs are always first in the list
  173. gpuID := nodeID - cpuCount
  174. // Shouldn't happen, but just in case...
  175. if gpuID < 0 {
  176. slog.Error("unexpected amdgpu sysfs data resulted in negative GPU ID, please set OLLAMA_DEBUG=1 and report an issue")
  177. return nil
  178. }
  179. if int(major) < RocmComputeMin {
  180. slog.Warn(fmt.Sprintf("amdgpu too old gfx%d%x%x", major, minor, patch), "gpu", gpuID)
  181. continue
  182. }
  183. // Look up the memory for the current node
  184. totalMemory := uint64(0)
  185. usedMemory := uint64(0)
  186. var usedFile string
  187. mapping := []struct {
  188. id uint64
  189. filename string
  190. }{
  191. {vendor, DRMVendorFile},
  192. {device, DRMDeviceFile},
  193. {uniqueID, DRMUniqueIDFile}, // Not all devices will report this
  194. }
  195. slog.Debug("mapping amdgpu to drm sysfs nodes", "amdgpu", match, "vendor", vendor, "device", device, "unique_id", uniqueID)
  196. // Map over to DRM location to find the total/free memory
  197. drmMatches, _ := filepath.Glob(DRMDeviceDirGlob)
  198. for _, devDir := range drmMatches {
  199. matched := true
  200. for _, m := range mapping {
  201. if m.id == 0 {
  202. // Null ID means it didn't populate, so we can't use it to match
  203. continue
  204. }
  205. filename := filepath.Join(devDir, m.filename)
  206. buf, err := os.ReadFile(filename)
  207. if err != nil {
  208. slog.Debug("failed to read sysfs node", "file", filename, "error", err)
  209. matched = false
  210. break
  211. }
  212. // values here are in hex, strip off the lead 0x and parse so we can compare the numeric (decimal) values in amdgpu
  213. cmp, err := strconv.ParseUint(strings.TrimPrefix(strings.TrimSpace(string(buf)), "0x"), 16, 64)
  214. if err != nil {
  215. slog.Debug("failed to parse sysfs node", "file", filename, "error", err)
  216. matched = false
  217. break
  218. }
  219. if cmp != m.id {
  220. matched = false
  221. break
  222. }
  223. }
  224. if !matched {
  225. continue
  226. }
  227. // Found the matching DRM directory
  228. slog.Debug("matched", "amdgpu", match, "drm", devDir)
  229. totalFile := filepath.Join(devDir, DRMTotalMemoryFile)
  230. buf, err := os.ReadFile(totalFile)
  231. if err != nil {
  232. slog.Debug("failed to read sysfs node", "file", totalFile, "error", err)
  233. break
  234. }
  235. totalMemory, err = strconv.ParseUint(strings.TrimSpace(string(buf)), 10, 64)
  236. if err != nil {
  237. slog.Debug("failed to parse sysfs node", "file", totalFile, "error", err)
  238. break
  239. }
  240. usedFile = filepath.Join(devDir, DRMUsedMemoryFile)
  241. usedMemory, err = getFreeMemory(usedFile)
  242. if err != nil {
  243. slog.Debug("failed to update used memory", "error", err)
  244. }
  245. break
  246. }
  247. // iGPU detection, remove this check once we can support an iGPU variant of the rocm library
  248. if totalMemory < IGPUMemLimit {
  249. slog.Info("unsupported Radeon iGPU detected skipping", "id", gpuID, "total", format.HumanBytes2(totalMemory))
  250. continue
  251. }
  252. var name string
  253. // TODO - PCI ID lookup
  254. if vendor > 0 && device > 0 {
  255. name = fmt.Sprintf("%04x:%04x", vendor, device)
  256. }
  257. slog.Debug("amdgpu memory", "gpu", gpuID, "total", format.HumanBytes2(totalMemory))
  258. slog.Debug("amdgpu memory", "gpu", gpuID, "available", format.HumanBytes2(totalMemory-usedMemory))
  259. gpuInfo := RocmGPUInfo{
  260. GpuInfo: GpuInfo{
  261. Library: "rocm",
  262. memInfo: memInfo{
  263. TotalMemory: totalMemory,
  264. FreeMemory: (totalMemory - usedMemory),
  265. },
  266. ID: strconv.Itoa(gpuID),
  267. Name: name,
  268. Compute: fmt.Sprintf("gfx%d%x%x", major, minor, patch),
  269. MinimumMemory: rocmMinimumMemory,
  270. DriverMajor: driverMajor,
  271. DriverMinor: driverMinor,
  272. },
  273. usedFilepath: usedFile,
  274. }
  275. // If the user wants to filter to a subset of devices, filter out if we aren't a match
  276. if len(visibleDevices) > 0 {
  277. include := false
  278. for _, visible := range visibleDevices {
  279. if visible == gpuInfo.ID {
  280. include = true
  281. break
  282. }
  283. }
  284. if !include {
  285. slog.Info("filtering out device per user request", "id", gpuInfo.ID, "visible_devices", visibleDevices)
  286. continue
  287. }
  288. }
  289. // Final validation is gfx compatibility - load the library if we haven't already loaded it
  290. // even if the user overrides, we still need to validate the library
  291. if libDir == "" {
  292. libDir, err = AMDValidateLibDir()
  293. if err != nil {
  294. slog.Warn("unable to verify rocm library, will use cpu", "error", err)
  295. return nil
  296. }
  297. }
  298. gpuInfo.DependencyPath = libDir
  299. if gfxOverride == "" {
  300. // Only load supported list once
  301. if len(supported) == 0 {
  302. supported, err = GetSupportedGFX(libDir)
  303. if err != nil {
  304. slog.Warn("failed to lookup supported GFX types, falling back to CPU mode", "error", err)
  305. return nil
  306. }
  307. slog.Debug("rocm supported GPUs", "types", supported)
  308. }
  309. gfx := gpuInfo.Compute
  310. if !slices.Contains[[]string, string](supported, gfx) {
  311. slog.Warn("amdgpu is not supported", "gpu", gpuInfo.ID, "gpu_type", gfx, "library", libDir, "supported_types", supported)
  312. // TODO - consider discrete markdown just for ROCM troubleshooting?
  313. slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/gpu.md#overrides for HSA_OVERRIDE_GFX_VERSION usage")
  314. continue
  315. } else {
  316. slog.Info("amdgpu is supported", "gpu", gpuInfo.ID, "gpu_type", gfx)
  317. }
  318. } else {
  319. slog.Info("skipping rocm gfx compatibility check", "HSA_OVERRIDE_GFX_VERSION", gfxOverride)
  320. }
  321. // Check for env var workarounds
  322. if name == "1002:687f" { // Vega RX 56
  323. gpuInfo.EnvWorkarounds = append(gpuInfo.EnvWorkarounds, [2]string{"HSA_ENABLE_SDMA", "0"})
  324. }
  325. // The GPU has passed all the verification steps and is supported
  326. resp = append(resp, gpuInfo)
  327. }
  328. if len(resp) == 0 {
  329. slog.Info("no compatible amdgpu devices detected")
  330. }
  331. return resp
  332. }
  333. // Quick check for AMD driver so we can skip amdgpu discovery if not present
  334. func AMDDetected() bool {
  335. // Some driver versions (older?) don't have a version file, so just lookup the parent dir
  336. sysfsDir := filepath.Dir(DriverVersionFile)
  337. _, err := os.Stat(sysfsDir)
  338. if errors.Is(err, os.ErrNotExist) {
  339. slog.Debug("amdgpu driver not detected " + sysfsDir)
  340. return false
  341. } else if err != nil {
  342. slog.Debug("error looking up amd driver", "path", sysfsDir, "error", err)
  343. return false
  344. }
  345. return true
  346. }
  347. // Prefer to use host installed ROCm, as long as it meets our minimum requirements
  348. // failing that, tell the user how to download it on their own
  349. func AMDValidateLibDir() (string, error) {
  350. libDir, err := commonAMDValidateLibDir()
  351. if err == nil {
  352. return libDir, nil
  353. }
  354. // Well known ollama installer path
  355. installedRocmDir := "/usr/share/ollama/lib/rocm"
  356. if rocmLibUsable(installedRocmDir) {
  357. return installedRocmDir, nil
  358. }
  359. // If we still haven't found a usable rocm, the user will have to install it on their own
  360. slog.Warn("amdgpu detected, but no compatible rocm library found. Either install rocm v6, or follow manual install instructions at https://github.com/ollama/ollama/blob/main/docs/linux.md#manual-install")
  361. return "", errors.New("no suitable rocm found, falling back to CPU")
  362. }
  363. func AMDDriverVersion() (driverMajor, driverMinor int, err error) {
  364. _, err = os.Stat(DriverVersionFile)
  365. if err != nil {
  366. return 0, 0, fmt.Errorf("amdgpu version file missing: %s %w", DriverVersionFile, err)
  367. }
  368. fp, err := os.Open(DriverVersionFile)
  369. if err != nil {
  370. return 0, 0, err
  371. }
  372. defer fp.Close()
  373. verString, err := io.ReadAll(fp)
  374. if err != nil {
  375. return 0, 0, err
  376. }
  377. pattern := `\A(\d+)\.(\d+).*`
  378. regex := regexp.MustCompile(pattern)
  379. match := regex.FindStringSubmatch(string(verString))
  380. if len(match) < 2 {
  381. return 0, 0, fmt.Errorf("malformed version string %s", string(verString))
  382. }
  383. driverMajor, err = strconv.Atoi(match[1])
  384. if err != nil {
  385. return 0, 0, err
  386. }
  387. driverMinor, err = strconv.Atoi(match[2])
  388. if err != nil {
  389. return 0, 0, err
  390. }
  391. return driverMajor, driverMinor, nil
  392. }
  393. func (gpus RocmGPUInfoList) RefreshFreeMemory() error {
  394. if len(gpus) == 0 {
  395. return nil
  396. }
  397. for i := range gpus {
  398. usedMemory, err := getFreeMemory(gpus[i].usedFilepath)
  399. if err != nil {
  400. return err
  401. }
  402. slog.Debug("updating rocm free memory", "gpu", gpus[i].ID, "name", gpus[i].Name, "before", format.HumanBytes2(gpus[i].FreeMemory), "now", format.HumanBytes2(gpus[i].TotalMemory-usedMemory))
  403. gpus[i].FreeMemory = gpus[i].TotalMemory - usedMemory
  404. }
  405. return nil
  406. }
  // getFreeMemory reads a single decimal counter from a DRM sysfs file and
// returns it. Despite the name, callers pass the mem_info_vram_used file and
// subtract the result from the total to derive free memory.
//
// Read and parse failures are returned wrapped with %w (so errors.Is works on
// the underlying os/strconv error) and are not logged here; callers decide
// whether to log or propagate.
func getFreeMemory(usedFile string) (uint64, error) {
	buf, err := os.ReadFile(usedFile)
	if err != nil {
		return 0, fmt.Errorf("failed to read sysfs node %s %w", usedFile, err)
	}
	usedMemory, err := strconv.ParseUint(strings.TrimSpace(string(buf)), 10, 64)
	if err != nil {
		return 0, fmt.Errorf("failed to parse sysfs node %s %w", usedFile, err)
	}
	return usedMemory, nil
}