amd_linux.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451
  1. package gpu
  2. import (
  3. "bufio"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "log/slog"
  8. "os"
  9. "path/filepath"
  10. "slices"
  11. "strconv"
  12. "strings"
  13. )
  14. // Discovery logic for AMD/ROCm GPUs
  15. const (
  16. DriverVersionFile = "/sys/module/amdgpu/version"
  17. AMDNodesSysfsDir = "/sys/class/kfd/kfd/topology/nodes/"
  18. GPUPropertiesFileGlob = AMDNodesSysfsDir + "*/properties"
  19. // Prefix with the node dir
  20. GPUTotalMemoryFileGlob = "mem_banks/*/properties" // size_in_bytes line
  21. GPUUsedMemoryFileGlob = "mem_banks/*/used_memory"
  22. RocmStandardLocation = "/opt/rocm/lib"
  23. // TODO find a better way to detect iGPU instead of minimum memory
  24. IGPUMemLimit = 1024 * 1024 * 1024 // 512G is what they typically report, so anything less than 1G must be iGPU
  25. )
  26. var (
  27. // Used to validate if the given ROCm lib is usable
  28. ROCmLibGlobs = []string{"libhipblas.so.2*", "rocblas"} // TODO - probably include more coverage of files here...
  29. )
  30. // Gather GPU information from the amdgpu driver if any supported GPUs are detected
  31. // HIP_VISIBLE_DEVICES will be set if we detect a mix of unsupported and supported devices
  32. // and the user hasn't already set this variable
  33. func AMDGetGPUInfo(resp *GpuInfo) {
  34. // TODO - DRY this out with windows
  35. if !AMDDetected() {
  36. return
  37. }
  38. skip := map[int]interface{}{}
  39. // Opportunistic logging of driver version to aid in troubleshooting
  40. ver, err := AMDDriverVersion()
  41. if err == nil {
  42. slog.Info("AMD Driver: " + ver)
  43. } else {
  44. // TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU
  45. slog.Warn(fmt.Sprintf("ollama recommends running the https://www.amd.com/en/support/linux-drivers: %s", err))
  46. }
  47. // If the user has specified exactly which GPUs to use, look up their memory
  48. visibleDevices := os.Getenv("HIP_VISIBLE_DEVICES")
  49. if visibleDevices != "" {
  50. ids := []int{}
  51. for _, idStr := range strings.Split(visibleDevices, ",") {
  52. id, err := strconv.Atoi(idStr)
  53. if err != nil {
  54. slog.Warn(fmt.Sprintf("malformed HIP_VISIBLE_DEVICES=%s %s", visibleDevices, err))
  55. } else {
  56. ids = append(ids, id)
  57. }
  58. }
  59. amdProcMemLookup(resp, nil, ids)
  60. return
  61. }
  62. // Gather GFX version information from all detected cards
  63. gfx := AMDGFXVersions()
  64. verStrings := []string{}
  65. for i, v := range gfx {
  66. verStrings = append(verStrings, v.ToGFXString())
  67. if v.Major == 0 {
  68. // Silently skip CPUs
  69. skip[i] = struct{}{}
  70. continue
  71. }
  72. if v.Major < 9 {
  73. // TODO consider this a build-time setting if we can support 8xx family GPUs
  74. slog.Warn(fmt.Sprintf("amdgpu [%d] too old %s", i, v.ToGFXString()))
  75. skip[i] = struct{}{}
  76. }
  77. }
  78. slog.Info(fmt.Sprintf("detected amdgpu versions %v", verStrings))
  79. // Abort if all GPUs are skipped
  80. if len(skip) >= len(gfx) {
  81. slog.Info("all detected amdgpus are skipped, falling back to CPU")
  82. return
  83. }
  84. // If we got this far, then we have at least 1 GPU that's a ROCm candidate, so make sure we have a lib
  85. libDir, err := AMDValidateLibDir()
  86. if err != nil {
  87. slog.Warn(fmt.Sprintf("unable to verify rocm library, will use cpu: %s", err))
  88. return
  89. }
  90. updateLibPath(libDir)
  91. gfxOverride := os.Getenv("HSA_OVERRIDE_GFX_VERSION")
  92. if gfxOverride == "" {
  93. supported, err := GetSupportedGFX(libDir)
  94. if err != nil {
  95. slog.Warn(fmt.Sprintf("failed to lookup supported GFX types, falling back to CPU mode: %s", err))
  96. return
  97. }
  98. slog.Debug(fmt.Sprintf("rocm supported GPU types %v", supported))
  99. for i, v := range gfx {
  100. if !slices.Contains[[]string, string](supported, v.ToGFXString()) {
  101. slog.Warn(fmt.Sprintf("amdgpu [%d] %s is not supported by %s %v", i, v.ToGFXString(), libDir, supported))
  102. // TODO - consider discrete markdown just for ROCM troubleshooting?
  103. slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/gpu.md#overrides for HSA_OVERRIDE_GFX_VERSION usage")
  104. skip[i] = struct{}{}
  105. } else {
  106. slog.Info(fmt.Sprintf("amdgpu [%d] %s is supported", i, v.ToGFXString()))
  107. }
  108. }
  109. } else {
  110. slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
  111. }
  112. if len(skip) >= len(gfx) {
  113. slog.Info("all detected amdgpus are skipped, falling back to CPU")
  114. return
  115. }
  116. ids := make([]int, len(gfx))
  117. i := 0
  118. for k := range gfx {
  119. ids[i] = k
  120. i++
  121. }
  122. amdProcMemLookup(resp, skip, ids)
  123. if resp.memInfo.DeviceCount == 0 {
  124. return
  125. }
  126. if len(skip) > 0 {
  127. amdSetVisibleDevices(ids, skip)
  128. }
  129. }
  130. func updateLibPath(libDir string) {
  131. ldPaths := []string{}
  132. if val, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
  133. ldPaths = strings.Split(val, ":")
  134. }
  135. for _, d := range ldPaths {
  136. if d == libDir {
  137. return
  138. }
  139. }
  140. val := strings.Join(append(ldPaths, libDir), ":")
  141. slog.Debug("updated lib path", "LD_LIBRARY_PATH", val)
  142. os.Setenv("LD_LIBRARY_PATH", val)
  143. }
  144. // Walk the sysfs nodes for the available GPUs and gather information from them
  145. // skipping over any devices in the skip map
  146. func amdProcMemLookup(resp *GpuInfo, skip map[int]interface{}, ids []int) {
  147. resp.memInfo.DeviceCount = 0
  148. resp.memInfo.TotalMemory = 0
  149. resp.memInfo.FreeMemory = 0
  150. slog.Debug("discovering VRAM for amdgpu devices")
  151. if len(ids) == 0 {
  152. entries, err := os.ReadDir(AMDNodesSysfsDir)
  153. if err != nil {
  154. slog.Warn(fmt.Sprintf("failed to read amdgpu sysfs %s - %s", AMDNodesSysfsDir, err))
  155. return
  156. }
  157. for _, node := range entries {
  158. if !node.IsDir() {
  159. continue
  160. }
  161. id, err := strconv.Atoi(node.Name())
  162. if err != nil {
  163. slog.Warn("malformed amdgpu sysfs node id " + node.Name())
  164. continue
  165. }
  166. ids = append(ids, id)
  167. }
  168. }
  169. slog.Debug(fmt.Sprintf("amdgpu devices %v", ids))
  170. for _, id := range ids {
  171. if _, skipped := skip[id]; skipped {
  172. continue
  173. }
  174. totalMemory := uint64(0)
  175. usedMemory := uint64(0)
  176. // Adjust for sysfs vs HIP ids
  177. propGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id+1), GPUTotalMemoryFileGlob)
  178. propFiles, err := filepath.Glob(propGlob)
  179. if err != nil {
  180. slog.Warn(fmt.Sprintf("error looking up total GPU memory: %s %s", propGlob, err))
  181. }
  182. // 1 or more memory banks - sum the values of all of them
  183. for _, propFile := range propFiles {
  184. fp, err := os.Open(propFile)
  185. if err != nil {
  186. slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", propFile, err))
  187. continue
  188. }
  189. defer fp.Close()
  190. scanner := bufio.NewScanner(fp)
  191. for scanner.Scan() {
  192. line := strings.TrimSpace(scanner.Text())
  193. if strings.HasPrefix(line, "size_in_bytes") {
  194. ver := strings.Fields(line)
  195. if len(ver) != 2 {
  196. slog.Warn("malformed " + line)
  197. continue
  198. }
  199. bankSizeInBytes, err := strconv.ParseUint(ver[1], 10, 64)
  200. if err != nil {
  201. slog.Warn("malformed int " + line)
  202. continue
  203. }
  204. totalMemory += bankSizeInBytes
  205. }
  206. }
  207. }
  208. if totalMemory == 0 {
  209. slog.Warn(fmt.Sprintf("amdgpu [%d] reports zero total memory, skipping", id))
  210. skip[id] = struct{}{}
  211. continue
  212. }
  213. if totalMemory < IGPUMemLimit {
  214. slog.Info(fmt.Sprintf("amdgpu [%d] appears to be an iGPU with %dM reported total memory, skipping", id, totalMemory/1024/1024))
  215. skip[id] = struct{}{}
  216. continue
  217. }
  218. usedGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id), GPUUsedMemoryFileGlob)
  219. usedFiles, err := filepath.Glob(usedGlob)
  220. if err != nil {
  221. slog.Warn(fmt.Sprintf("error looking up used GPU memory: %s %s", usedGlob, err))
  222. continue
  223. }
  224. for _, usedFile := range usedFiles {
  225. fp, err := os.Open(usedFile)
  226. if err != nil {
  227. slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", usedFile, err))
  228. continue
  229. }
  230. defer fp.Close()
  231. data, err := io.ReadAll(fp)
  232. if err != nil {
  233. slog.Warn(fmt.Sprintf("failed to read sysfs node file %s: %s", usedFile, err))
  234. continue
  235. }
  236. used, err := strconv.ParseUint(strings.TrimSpace(string(data)), 10, 64)
  237. if err != nil {
  238. slog.Warn(fmt.Sprintf("malformed used memory %s: %s", string(data), err))
  239. continue
  240. }
  241. usedMemory += used
  242. }
  243. slog.Info(fmt.Sprintf("[%d] amdgpu totalMemory %dM", id, totalMemory/1024/1024))
  244. slog.Info(fmt.Sprintf("[%d] amdgpu freeMemory %dM", id, (totalMemory-usedMemory)/1024/1024))
  245. resp.memInfo.DeviceCount++
  246. resp.memInfo.TotalMemory += totalMemory
  247. resp.memInfo.FreeMemory += (totalMemory - usedMemory)
  248. }
  249. if resp.memInfo.DeviceCount > 0 {
  250. resp.Library = "rocm"
  251. }
  252. }
  253. // Quick check for AMD driver so we can skip amdgpu discovery if not present
  254. func AMDDetected() bool {
  255. // Some driver versions (older?) don't have a version file, so just lookup the parent dir
  256. sysfsDir := filepath.Dir(DriverVersionFile)
  257. _, err := os.Stat(sysfsDir)
  258. if errors.Is(err, os.ErrNotExist) {
  259. slog.Debug("amdgpu driver not detected " + sysfsDir)
  260. return false
  261. } else if err != nil {
  262. slog.Debug(fmt.Sprintf("error looking up amd driver %s %s", sysfsDir, err))
  263. return false
  264. }
  265. return true
  266. }
  267. func setupLink(source, target string) error {
  268. if err := os.RemoveAll(target); err != nil {
  269. return fmt.Errorf("failed to remove old rocm directory %s %w", target, err)
  270. }
  271. if err := os.Symlink(source, target); err != nil {
  272. return fmt.Errorf("failed to create link %s => %s %w", source, target, err)
  273. }
  274. slog.Debug(fmt.Sprintf("host rocm linked %s => %s", source, target))
  275. return nil
  276. }
  277. // Ensure the AMD rocm lib dir is wired up
  278. // Prefer to use host installed ROCm, as long as it meets our minimum requirements
  279. // failing that, tell the user how to download it on their own
  280. func AMDValidateLibDir() (string, error) {
  281. // We rely on the rpath compiled into our library to find rocm
  282. // so we establish a symlink to wherever we find it on the system
  283. // to <payloads>/rocm
  284. payloadsDir, err := PayloadsDir()
  285. if err != nil {
  286. return "", err
  287. }
  288. // If we already have a rocm dependency wired, nothing more to do
  289. rocmTargetDir := filepath.Clean(filepath.Join(payloadsDir, "..", "rocm"))
  290. if rocmLibUsable(rocmTargetDir) {
  291. return rocmTargetDir, nil
  292. }
  293. // next to the running binary
  294. exe, err := os.Executable()
  295. if err == nil {
  296. peerDir := filepath.Dir(exe)
  297. if rocmLibUsable(peerDir) {
  298. slog.Debug("detected ROCM next to ollama executable " + peerDir)
  299. return rocmTargetDir, setupLink(peerDir, rocmTargetDir)
  300. }
  301. peerDir = filepath.Join(filepath.Dir(exe), "rocm")
  302. if rocmLibUsable(peerDir) {
  303. slog.Debug("detected ROCM next to ollama executable " + peerDir)
  304. return rocmTargetDir, setupLink(peerDir, rocmTargetDir)
  305. }
  306. }
  307. // Well known ollama installer path
  308. installedRocmDir := "/usr/share/ollama/lib/rocm"
  309. if rocmLibUsable(installedRocmDir) {
  310. return rocmTargetDir, setupLink(installedRocmDir, rocmTargetDir)
  311. }
  312. // Prefer explicit HIP env var
  313. hipPath := os.Getenv("HIP_PATH")
  314. if hipPath != "" {
  315. hipLibDir := filepath.Join(hipPath, "lib")
  316. if rocmLibUsable(hipLibDir) {
  317. slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
  318. return rocmTargetDir, setupLink(hipLibDir, rocmTargetDir)
  319. }
  320. }
  321. // Scan the library path for potential matches
  322. ldPaths := strings.Split(os.Getenv("LD_LIBRARY_PATH"), ":")
  323. for _, ldPath := range ldPaths {
  324. d, err := filepath.Abs(ldPath)
  325. if err != nil {
  326. continue
  327. }
  328. if rocmLibUsable(d) {
  329. return rocmTargetDir, setupLink(d, rocmTargetDir)
  330. }
  331. }
  332. // Well known location(s)
  333. if rocmLibUsable("/opt/rocm/lib") {
  334. return rocmTargetDir, setupLink("/opt/rocm/lib", rocmTargetDir)
  335. }
  336. // If we still haven't found a usable rocm, the user will have to install it on their own
  337. slog.Warn("amdgpu detected, but no compatible rocm library found. Either install rocm v6, or follow manual install instructions at https://github.com/ollama/ollama/blob/main/docs/linux.md#manual-install")
  338. return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
  339. }
  340. func AMDDriverVersion() (string, error) {
  341. _, err := os.Stat(DriverVersionFile)
  342. if err != nil {
  343. return "", fmt.Errorf("amdgpu version file missing: %s %w", DriverVersionFile, err)
  344. }
  345. fp, err := os.Open(DriverVersionFile)
  346. if err != nil {
  347. return "", err
  348. }
  349. defer fp.Close()
  350. verString, err := io.ReadAll(fp)
  351. if err != nil {
  352. return "", err
  353. }
  354. return strings.TrimSpace(string(verString)), nil
  355. }
  356. func AMDGFXVersions() map[int]Version {
  357. // The amdgpu driver always exposes the host CPU as node 0, but we have to skip that and subtract one
  358. // from the other IDs to get alignment with the HIP libraries expectations (zero is the first GPU, not the CPU)
  359. res := map[int]Version{}
  360. matches, _ := filepath.Glob(GPUPropertiesFileGlob)
  361. for _, match := range matches {
  362. fp, err := os.Open(match)
  363. if err != nil {
  364. slog.Debug(fmt.Sprintf("failed to open sysfs node file %s: %s", match, err))
  365. continue
  366. }
  367. defer fp.Close()
  368. i, err := strconv.Atoi(filepath.Base(filepath.Dir(match)))
  369. if err != nil {
  370. slog.Debug(fmt.Sprintf("failed to parse node ID %s", err))
  371. continue
  372. }
  373. if i == 0 {
  374. // Skipping the CPU
  375. continue
  376. }
  377. // Align with HIP IDs (zero is first GPU, not CPU)
  378. i -= 1
  379. scanner := bufio.NewScanner(fp)
  380. for scanner.Scan() {
  381. line := strings.TrimSpace(scanner.Text())
  382. if strings.HasPrefix(line, "gfx_target_version") {
  383. ver := strings.Fields(line)
  384. if len(ver) != 2 || len(ver[1]) < 5 {
  385. if ver[1] != "0" {
  386. slog.Debug("malformed " + line)
  387. }
  388. res[i] = Version{
  389. Major: 0,
  390. Minor: 0,
  391. Patch: 0,
  392. }
  393. continue
  394. }
  395. l := len(ver[1])
  396. patch, err1 := strconv.ParseUint(ver[1][l-2:l], 10, 32)
  397. minor, err2 := strconv.ParseUint(ver[1][l-4:l-2], 10, 32)
  398. major, err3 := strconv.ParseUint(ver[1][:l-4], 10, 32)
  399. if err1 != nil || err2 != nil || err3 != nil {
  400. slog.Debug("malformed int " + line)
  401. continue
  402. }
  403. res[i] = Version{
  404. Major: uint(major),
  405. Minor: uint(minor),
  406. Patch: uint(patch),
  407. }
  408. }
  409. }
  410. }
  411. return res
  412. }
  413. func (v Version) ToGFXString() string {
  414. return fmt.Sprintf("gfx%d%d%d", v.Major, v.Minor, v.Patch)
  415. }