amd_linux.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452
  1. package gpu
  2. import (
  3. "bufio"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "log/slog"
  8. "os"
  9. "path/filepath"
  10. "regexp"
  11. "slices"
  12. "strconv"
  13. "strings"
  14. "github.com/ollama/ollama/format"
  15. )
  16. // Discovery logic for AMD/ROCm GPUs
  17. const (
  18. DriverVersionFile = "/sys/module/amdgpu/version"
  19. AMDNodesSysfsDir = "/sys/class/kfd/kfd/topology/nodes/"
  20. GPUPropertiesFileGlob = AMDNodesSysfsDir + "*/properties"
  21. // Prefix with the node dir
  22. GPUTotalMemoryFileGlob = "mem_banks/*/properties" // size_in_bytes line
  23. // Direct Rendering Manager sysfs location
  24. DRMDeviceDirGlob = "/sys/class/drm/card*/device"
  25. DRMTotalMemoryFile = "mem_info_vram_total"
  26. DRMUsedMemoryFile = "mem_info_vram_used"
  27. // In hex; properties file is in decimal
  28. DRMUniqueIDFile = "unique_id"
  29. DRMVendorFile = "vendor"
  30. DRMDeviceFile = "device"
  31. )
  32. var (
  33. // Used to validate if the given ROCm lib is usable
  34. ROCmLibGlobs = []string{"libhipblas.so.2*", "rocblas"} // TODO - probably include more coverage of files here...
  35. RocmStandardLocations = []string{"/opt/rocm/lib", "/usr/lib64"}
  36. )
  37. // Gather GPU information from the amdgpu driver if any supported GPUs are detected
  38. func AMDGetGPUInfo() []RocmGPUInfo {
  39. resp := []RocmGPUInfo{}
  40. if !AMDDetected() {
  41. return resp
  42. }
  43. // Opportunistic logging of driver version to aid in troubleshooting
  44. driverMajor, driverMinor, err := AMDDriverVersion()
  45. if err != nil {
  46. // TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU
  47. slog.Warn("ollama recommends running the https://www.amd.com/en/support/linux-drivers", "error", err)
  48. }
  49. // Determine if the user has already pre-selected which GPUs to look at, then ignore the others
  50. var visibleDevices []string
  51. hipVD := os.Getenv("HIP_VISIBLE_DEVICES") // zero based index only
  52. rocrVD := os.Getenv("ROCR_VISIBLE_DEVICES") // zero based index or UUID, but consumer cards seem to not support UUID
  53. gpuDO := os.Getenv("GPU_DEVICE_ORDINAL") // zero based index
  54. switch {
  55. // TODO is this priorty order right?
  56. case hipVD != "":
  57. visibleDevices = strings.Split(hipVD, ",")
  58. case rocrVD != "":
  59. visibleDevices = strings.Split(rocrVD, ",")
  60. // TODO - since we don't yet support UUIDs, consider detecting and reporting here
  61. // all our test systems show GPU-XX indicating UUID is not supported
  62. case gpuDO != "":
  63. visibleDevices = strings.Split(gpuDO, ",")
  64. }
  65. gfxOverride := os.Getenv("HSA_OVERRIDE_GFX_VERSION")
  66. var supported []string
  67. libDir := ""
  68. // The amdgpu driver always exposes the host CPU(s) first, but we have to skip them and subtract
  69. // from the other IDs to get alignment with the HIP libraries expectations (zero is the first GPU, not the CPU)
  70. matches, _ := filepath.Glob(GPUPropertiesFileGlob)
  71. cpuCount := 0
  72. for _, match := range matches {
  73. slog.Debug("evaluating amdgpu node " + match)
  74. fp, err := os.Open(match)
  75. if err != nil {
  76. slog.Debug("failed to open sysfs node", "file", match, "error", err)
  77. continue
  78. }
  79. defer fp.Close()
  80. nodeID, err := strconv.Atoi(filepath.Base(filepath.Dir(match)))
  81. if err != nil {
  82. slog.Debug("failed to parse node ID", "error", err)
  83. continue
  84. }
  85. scanner := bufio.NewScanner(fp)
  86. isCPU := false
  87. var major, minor, patch uint64
  88. var vendor, device, uniqueID uint64
  89. for scanner.Scan() {
  90. line := strings.TrimSpace(scanner.Text())
  91. // Note: we could also use "cpu_cores_count X" where X is greater than zero to detect CPUs
  92. if strings.HasPrefix(line, "gfx_target_version") {
  93. ver := strings.Fields(line)
  94. // Detect CPUs
  95. if len(ver) == 2 && ver[1] == "0" {
  96. slog.Debug("detected CPU " + match)
  97. isCPU = true
  98. break
  99. }
  100. if len(ver) != 2 || len(ver[1]) < 5 {
  101. slog.Warn("malformed "+match, "gfx_target_version", line)
  102. // If this winds up being a CPU, our offsets may be wrong
  103. continue
  104. }
  105. l := len(ver[1])
  106. var err1, err2, err3 error
  107. patch, err1 = strconv.ParseUint(ver[1][l-2:l], 10, 32)
  108. minor, err2 = strconv.ParseUint(ver[1][l-4:l-2], 10, 32)
  109. major, err3 = strconv.ParseUint(ver[1][:l-4], 10, 32)
  110. if err1 != nil || err2 != nil || err3 != nil {
  111. slog.Debug("malformed int " + line)
  112. continue
  113. }
  114. } else if strings.HasPrefix(line, "vendor_id") {
  115. ver := strings.Fields(line)
  116. if len(ver) != 2 {
  117. slog.Debug("malformed", "vendor_id", line)
  118. continue
  119. }
  120. vendor, err = strconv.ParseUint(ver[1], 10, 64)
  121. if err != nil {
  122. slog.Debug("malformed", "vendor_id", line, "error", err)
  123. }
  124. } else if strings.HasPrefix(line, "device_id") {
  125. ver := strings.Fields(line)
  126. if len(ver) != 2 {
  127. slog.Debug("malformed", "device_id", line)
  128. continue
  129. }
  130. device, err = strconv.ParseUint(ver[1], 10, 64)
  131. if err != nil {
  132. slog.Debug("malformed", "device_id", line, "error", err)
  133. }
  134. } else if strings.HasPrefix(line, "unique_id") {
  135. ver := strings.Fields(line)
  136. if len(ver) != 2 {
  137. slog.Debug("malformed", "unique_id", line)
  138. continue
  139. }
  140. uniqueID, err = strconv.ParseUint(ver[1], 10, 64)
  141. if err != nil {
  142. slog.Debug("malformed", "unique_id", line, "error", err)
  143. }
  144. }
  145. // TODO - any other properties we want to extract and record?
  146. // vendor_id + device_id -> pci lookup for "Name"
  147. // Other metrics that may help us understand relative performance between multiple GPUs
  148. }
  149. // Note: while ./mem_banks/*/used_memory exists, it doesn't appear to take other VRAM consumers
  150. // into consideration, so we instead map the device over to the DRM driver sysfs nodes which
  151. // do reliably report VRAM usage.
  152. if isCPU {
  153. cpuCount++
  154. continue
  155. }
  156. // CPUs are always first in the list
  157. gpuID := nodeID - cpuCount
  158. // Shouldn't happen, but just in case...
  159. if gpuID < 0 {
  160. slog.Error("unexpected amdgpu sysfs data resulted in negative GPU ID, please set OLLAMA_DEBUG=1 and report an issue")
  161. return []RocmGPUInfo{}
  162. }
  163. if int(major) < RocmComputeMin {
  164. slog.Warn(fmt.Sprintf("amdgpu too old gfx%d%x%x", major, minor, patch), "gpu", gpuID)
  165. continue
  166. }
  167. // Look up the memory for the current node
  168. totalMemory := uint64(0)
  169. usedMemory := uint64(0)
  170. var usedFile string
  171. mapping := []struct {
  172. id uint64
  173. filename string
  174. }{
  175. {vendor, DRMVendorFile},
  176. {device, DRMDeviceFile},
  177. {uniqueID, DRMUniqueIDFile}, // Not all devices will report this
  178. }
  179. slog.Debug("mapping amdgpu to drm sysfs nodes", "amdgpu", match, "vendor", vendor, "device", device, "unique_id", uniqueID)
  180. // Map over to DRM location to find the total/free memory
  181. drmMatches, _ := filepath.Glob(DRMDeviceDirGlob)
  182. for _, devDir := range drmMatches {
  183. matched := true
  184. for _, m := range mapping {
  185. if m.id == 0 {
  186. continue
  187. }
  188. filename := filepath.Join(devDir, m.filename)
  189. fp, err := os.Open(filename)
  190. if err != nil {
  191. slog.Debug("failed to open sysfs node", "file", filename, "error", err)
  192. matched = false
  193. break
  194. }
  195. defer fp.Close()
  196. buf, err := io.ReadAll(fp)
  197. if err != nil {
  198. slog.Debug("failed to read sysfs node", "file", filename, "error", err)
  199. matched = false
  200. break
  201. }
  202. cmp, err := strconv.ParseUint(strings.TrimPrefix(strings.TrimSpace(string(buf)), "0x"), 16, 64)
  203. if err != nil {
  204. slog.Debug("failed to parse sysfs node", "file", filename, "error", err)
  205. matched = false
  206. break
  207. }
  208. if cmp != m.id {
  209. matched = false
  210. break
  211. }
  212. }
  213. if !matched {
  214. continue
  215. }
  216. // Found the matching DRM directory
  217. slog.Debug("matched", "amdgpu", match, "drm", devDir)
  218. totalFile := filepath.Join(devDir, DRMTotalMemoryFile)
  219. totalFp, err := os.Open(totalFile)
  220. if err != nil {
  221. slog.Debug("failed to open sysfs node", "file", totalFile, "error", err)
  222. break
  223. }
  224. defer totalFp.Close()
  225. buf, err := io.ReadAll(totalFp)
  226. if err != nil {
  227. slog.Debug("failed to read sysfs node", "file", totalFile, "error", err)
  228. break
  229. }
  230. totalMemory, err = strconv.ParseUint(strings.TrimSpace(string(buf)), 10, 64)
  231. if err != nil {
  232. slog.Debug("failed to parse sysfs node", "file", totalFile, "error", err)
  233. break
  234. }
  235. usedFile = filepath.Join(devDir, DRMUsedMemoryFile)
  236. usedMemory, err = getFreeMemory(usedFile)
  237. if err != nil {
  238. slog.Debug("failed to update used memory", "error", err)
  239. }
  240. break
  241. }
  242. // iGPU detection, remove this check once we can support an iGPU variant of the rocm library
  243. if totalMemory < IGPUMemLimit {
  244. slog.Info("unsupported Radeon iGPU detected skipping", "id", gpuID, "total", format.HumanBytes2(totalMemory))
  245. continue
  246. }
  247. var name string
  248. // TODO - PCI ID lookup
  249. if vendor > 0 && device > 0 {
  250. name = fmt.Sprintf("%04x:%04x", vendor, device)
  251. }
  252. slog.Debug("amdgpu memory", "gpu", gpuID, "total", format.HumanBytes2(totalMemory))
  253. slog.Debug("amdgpu memory", "gpu", gpuID, "available", format.HumanBytes2(totalMemory-usedMemory))
  254. gpuInfo := RocmGPUInfo{
  255. GpuInfo: GpuInfo{
  256. Library: "rocm",
  257. memInfo: memInfo{
  258. TotalMemory: totalMemory,
  259. FreeMemory: (totalMemory - usedMemory),
  260. },
  261. ID: fmt.Sprintf("%d", gpuID),
  262. Name: name,
  263. Compute: fmt.Sprintf("gfx%d%x%x", major, minor, patch),
  264. MinimumMemory: rocmMinimumMemory,
  265. DriverMajor: driverMajor,
  266. DriverMinor: driverMinor,
  267. },
  268. usedFilepath: usedFile,
  269. }
  270. // If the user wants to filter to a subset of devices, filter out if we aren't a match
  271. if len(visibleDevices) > 0 {
  272. include := false
  273. for _, visible := range visibleDevices {
  274. if visible == gpuInfo.ID {
  275. include = true
  276. break
  277. }
  278. }
  279. if !include {
  280. slog.Info("filtering out device per user request", "id", gpuInfo.ID, "visible_devices", visibleDevices)
  281. continue
  282. }
  283. }
  284. // Final validation is gfx compatibility - load the library if we haven't already loaded it
  285. // even if the user overrides, we still need to validate the library
  286. if libDir == "" {
  287. libDir, err = AMDValidateLibDir()
  288. if err != nil {
  289. slog.Warn("unable to verify rocm library, will use cpu", "error", err)
  290. return []RocmGPUInfo{}
  291. }
  292. }
  293. gpuInfo.DependencyPath = libDir
  294. if gfxOverride == "" {
  295. // Only load supported list once
  296. if len(supported) == 0 {
  297. supported, err = GetSupportedGFX(libDir)
  298. if err != nil {
  299. slog.Warn("failed to lookup supported GFX types, falling back to CPU mode", "error", err)
  300. return []RocmGPUInfo{}
  301. }
  302. slog.Debug("rocm supported GPUs", "types", supported)
  303. }
  304. gfx := gpuInfo.Compute
  305. if !slices.Contains[[]string, string](supported, gfx) {
  306. slog.Warn("amdgpu is not supported", "gpu", gpuInfo.ID, "gpu_type", gfx, "library", libDir, "supported_types", supported)
  307. // TODO - consider discrete markdown just for ROCM troubleshooting?
  308. slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/gpu.md#overrides for HSA_OVERRIDE_GFX_VERSION usage")
  309. continue
  310. } else {
  311. slog.Info("amdgpu is supported", "gpu", gpuInfo.ID, "gpu_type", gfx)
  312. }
  313. } else {
  314. slog.Info("skipping rocm gfx compatibility check", "HSA_OVERRIDE_GFX_VERSION", gfxOverride)
  315. }
  316. // The GPU has passed all the verification steps and is supported
  317. resp = append(resp, gpuInfo)
  318. }
  319. if len(resp) == 0 {
  320. slog.Info("no compatible amdgpu devices detected")
  321. }
  322. return resp
  323. }
  324. // Quick check for AMD driver so we can skip amdgpu discovery if not present
  325. func AMDDetected() bool {
  326. // Some driver versions (older?) don't have a version file, so just lookup the parent dir
  327. sysfsDir := filepath.Dir(DriverVersionFile)
  328. _, err := os.Stat(sysfsDir)
  329. if errors.Is(err, os.ErrNotExist) {
  330. slog.Debug("amdgpu driver not detected " + sysfsDir)
  331. return false
  332. } else if err != nil {
  333. slog.Debug("error looking up amd driver", "path", sysfsDir, "error", err)
  334. return false
  335. }
  336. return true
  337. }
  338. // Prefer to use host installed ROCm, as long as it meets our minimum requirements
  339. // failing that, tell the user how to download it on their own
  340. func AMDValidateLibDir() (string, error) {
  341. libDir, err := commonAMDValidateLibDir()
  342. if err == nil {
  343. return libDir, nil
  344. }
  345. // Well known ollama installer path
  346. installedRocmDir := "/usr/share/ollama/lib/rocm"
  347. if rocmLibUsable(installedRocmDir) {
  348. return installedRocmDir, nil
  349. }
  350. // If we still haven't found a usable rocm, the user will have to install it on their own
  351. slog.Warn("amdgpu detected, but no compatible rocm library found. Either install rocm v6, or follow manual install instructions at https://github.com/ollama/ollama/blob/main/docs/linux.md#manual-install")
  352. return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
  353. }
  354. func AMDDriverVersion() (driverMajor, driverMinor int, err error) {
  355. _, err = os.Stat(DriverVersionFile)
  356. if err != nil {
  357. return 0, 0, fmt.Errorf("amdgpu version file missing: %s %w", DriverVersionFile, err)
  358. }
  359. fp, err := os.Open(DriverVersionFile)
  360. if err != nil {
  361. return 0, 0, err
  362. }
  363. defer fp.Close()
  364. verString, err := io.ReadAll(fp)
  365. if err != nil {
  366. return 0, 0, err
  367. }
  368. pattern := `\A(\d+)\.(\d+).*`
  369. regex := regexp.MustCompile(pattern)
  370. match := regex.FindStringSubmatch(string(verString))
  371. if len(match) < 2 {
  372. return 0, 0, fmt.Errorf("malformed version string %s", string(verString))
  373. }
  374. driverMajor, err = strconv.Atoi(match[1])
  375. if err != nil {
  376. return 0, 0, err
  377. }
  378. driverMinor, err = strconv.Atoi(match[2])
  379. if err != nil {
  380. return 0, 0, err
  381. }
  382. return driverMajor, driverMinor, nil
  383. }
  384. func (gpus RocmGPUInfoList) RefreshFreeMemory() error {
  385. if len(gpus) == 0 {
  386. return nil
  387. }
  388. for i := range gpus {
  389. usedMemory, err := getFreeMemory(gpus[i].usedFilepath)
  390. if err != nil {
  391. return err
  392. }
  393. slog.Debug("updating rocm free memory", "gpu", gpus[i].ID, "name", gpus[i].Name, "before", format.HumanBytes2(gpus[i].FreeMemory), "now", format.HumanBytes2(gpus[i].TotalMemory-usedMemory))
  394. gpus[i].FreeMemory = gpus[i].TotalMemory - usedMemory
  395. }
  396. return nil
  397. }
  398. func getFreeMemory(usedFile string) (uint64, error) {
  399. usedFp, err := os.Open(usedFile)
  400. if err != nil {
  401. return 0, fmt.Errorf("failed to open sysfs node %s %w", usedFile, err)
  402. }
  403. defer usedFp.Close()
  404. buf, err := io.ReadAll(usedFp)
  405. if err != nil {
  406. return 0, fmt.Errorf("failed to read sysfs node %s %w", usedFile, err)
  407. }
  408. usedMemory, err := strconv.ParseUint(strings.TrimSpace(string(buf)), 10, 64)
  409. if err != nil {
  410. slog.Debug("failed to parse sysfs node", "file", usedFile, "error", err)
  411. return 0, fmt.Errorf("failed to parse sysfs node %s %w", usedFile, err)
  412. }
  413. return usedMemory, nil
  414. }