amd_linux.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434
  1. package gpu
  2. import (
  3. "bufio"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "log/slog"
  8. "os"
  9. "path/filepath"
  10. "slices"
  11. "strconv"
  12. "strings"
  13. )
  14. // Discovery logic for AMD/ROCm GPUs
  15. const (
  16. DriverVersionFile = "/sys/module/amdgpu/version"
  17. AMDNodesSysfsDir = "/sys/class/kfd/kfd/topology/nodes/"
  18. GPUPropertiesFileGlob = AMDNodesSysfsDir + "*/properties"
  19. // Prefix with the node dir
  20. GPUTotalMemoryFileGlob = "mem_banks/*/properties" // size_in_bytes line
  21. GPUUsedMemoryFileGlob = "mem_banks/*/used_memory"
  22. RocmStandardLocation = "/opt/rocm/lib"
  23. // TODO find a better way to detect iGPU instead of minimum memory
  24. IGPUMemLimit = 1024 * 1024 * 1024 // 512G is what they typically report, so anything less than 1G must be iGPU
  25. )
  26. var (
  27. // Used to validate if the given ROCm lib is usable
  28. ROCmLibGlobs = []string{"libhipblas.so.2*", "rocblas"} // TODO - probably include more coverage of files here...
  29. )
  30. // Gather GPU information from the amdgpu driver if any supported GPUs are detected
  31. // HIP_VISIBLE_DEVICES will be set if we detect a mix of unsupported and supported devices
  32. // and the user hasn't already set this variable
  33. func AMDGetGPUInfo(resp *GpuInfo) {
  34. // TODO - DRY this out with windows
  35. if !AMDDetected() {
  36. return
  37. }
  38. skip := map[int]interface{}{}
  39. // Opportunistic logging of driver version to aid in troubleshooting
  40. ver, err := AMDDriverVersion()
  41. if err == nil {
  42. slog.Info("AMD Driver: " + ver)
  43. } else {
  44. // TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU
  45. slog.Warn(fmt.Sprintf("ollama recommends running the https://www.amd.com/en/support/linux-drivers: %s", err))
  46. }
  47. // If the user has specified exactly which GPUs to use, look up their memory
  48. visibleDevices := os.Getenv("HIP_VISIBLE_DEVICES")
  49. if visibleDevices != "" {
  50. ids := []int{}
  51. for _, idStr := range strings.Split(visibleDevices, ",") {
  52. id, err := strconv.Atoi(idStr)
  53. if err != nil {
  54. slog.Warn(fmt.Sprintf("malformed HIP_VISIBLE_DEVICES=%s %s", visibleDevices, err))
  55. } else {
  56. ids = append(ids, id)
  57. }
  58. }
  59. amdProcMemLookup(resp, nil, ids)
  60. return
  61. }
  62. // Gather GFX version information from all detected cards
  63. gfx := AMDGFXVersions()
  64. verStrings := []string{}
  65. for i, v := range gfx {
  66. verStrings = append(verStrings, v.ToGFXString())
  67. if v.Major == 0 {
  68. // Silently skip CPUs
  69. skip[i] = struct{}{}
  70. continue
  71. }
  72. if v.Major < 9 {
  73. // TODO consider this a build-time setting if we can support 8xx family GPUs
  74. slog.Warn(fmt.Sprintf("amdgpu [%d] too old %s", i, v.ToGFXString()))
  75. skip[i] = struct{}{}
  76. }
  77. }
  78. slog.Info(fmt.Sprintf("detected amdgpu versions %v", verStrings))
  79. // Abort if all GPUs are skipped
  80. if len(skip) >= len(gfx) {
  81. slog.Info("all detected amdgpus are skipped, falling back to CPU")
  82. return
  83. }
  84. // If we got this far, then we have at least 1 GPU that's a ROCm candidate, so make sure we have a lib
  85. libDir, err := AMDValidateLibDir()
  86. if err != nil {
  87. slog.Warn(fmt.Sprintf("unable to verify rocm library, will use cpu: %s", err))
  88. return
  89. }
  90. gfxOverride := os.Getenv("HSA_OVERRIDE_GFX_VERSION")
  91. if gfxOverride == "" {
  92. supported, err := GetSupportedGFX(libDir)
  93. if err != nil {
  94. slog.Warn(fmt.Sprintf("failed to lookup supported GFX types, falling back to CPU mode: %s", err))
  95. return
  96. }
  97. slog.Debug(fmt.Sprintf("rocm supported GPU types %v", supported))
  98. for i, v := range gfx {
  99. if !slices.Contains[[]string, string](supported, v.ToGFXString()) {
  100. slog.Warn(fmt.Sprintf("amdgpu [%d] %s is not supported by %s %v", i, v.ToGFXString(), libDir, supported))
  101. // TODO - consider discrete markdown just for ROCM troubleshooting?
  102. slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage")
  103. skip[i] = struct{}{}
  104. } else {
  105. slog.Info(fmt.Sprintf("amdgpu [%d] %s is supported", i, v.ToGFXString()))
  106. }
  107. }
  108. } else {
  109. slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
  110. }
  111. if len(skip) >= len(gfx) {
  112. slog.Info("all detected amdgpus are skipped, falling back to CPU")
  113. return
  114. }
  115. ids := make([]int, len(gfx))
  116. i := 0
  117. for k := range gfx {
  118. ids[i] = k
  119. i++
  120. }
  121. amdProcMemLookup(resp, skip, ids)
  122. if resp.memInfo.DeviceCount == 0 {
  123. return
  124. }
  125. if len(skip) > 0 {
  126. amdSetVisibleDevices(ids, skip)
  127. }
  128. }
  129. // Walk the sysfs nodes for the available GPUs and gather information from them
  130. // skipping over any devices in the skip map
  131. func amdProcMemLookup(resp *GpuInfo, skip map[int]interface{}, ids []int) {
  132. resp.memInfo.DeviceCount = 0
  133. resp.memInfo.TotalMemory = 0
  134. resp.memInfo.FreeMemory = 0
  135. slog.Debug("discovering VRAM for amdgpu devices")
  136. if len(ids) == 0 {
  137. entries, err := os.ReadDir(AMDNodesSysfsDir)
  138. if err != nil {
  139. slog.Warn(fmt.Sprintf("failed to read amdgpu sysfs %s - %s", AMDNodesSysfsDir, err))
  140. return
  141. }
  142. for _, node := range entries {
  143. if !node.IsDir() {
  144. continue
  145. }
  146. id, err := strconv.Atoi(node.Name())
  147. if err != nil {
  148. slog.Warn("malformed amdgpu sysfs node id " + node.Name())
  149. continue
  150. }
  151. ids = append(ids, id)
  152. }
  153. }
  154. slog.Debug(fmt.Sprintf("amdgpu devices %v", ids))
  155. for _, id := range ids {
  156. if _, skipped := skip[id]; skipped {
  157. continue
  158. }
  159. totalMemory := uint64(0)
  160. usedMemory := uint64(0)
  161. // Adjust for sysfs vs HIP ids
  162. propGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id+1), GPUTotalMemoryFileGlob)
  163. propFiles, err := filepath.Glob(propGlob)
  164. if err != nil {
  165. slog.Warn(fmt.Sprintf("error looking up total GPU memory: %s %s", propGlob, err))
  166. }
  167. // 1 or more memory banks - sum the values of all of them
  168. for _, propFile := range propFiles {
  169. fp, err := os.Open(propFile)
  170. if err != nil {
  171. slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", propFile, err))
  172. continue
  173. }
  174. defer fp.Close()
  175. scanner := bufio.NewScanner(fp)
  176. for scanner.Scan() {
  177. line := strings.TrimSpace(scanner.Text())
  178. if strings.HasPrefix(line, "size_in_bytes") {
  179. ver := strings.Fields(line)
  180. if len(ver) != 2 {
  181. slog.Warn("malformed " + line)
  182. continue
  183. }
  184. bankSizeInBytes, err := strconv.ParseUint(ver[1], 10, 64)
  185. if err != nil {
  186. slog.Warn("malformed int " + line)
  187. continue
  188. }
  189. totalMemory += bankSizeInBytes
  190. }
  191. }
  192. }
  193. if totalMemory == 0 {
  194. slog.Warn(fmt.Sprintf("amdgpu [%d] reports zero total memory, skipping", id))
  195. skip[id] = struct{}{}
  196. continue
  197. }
  198. if totalMemory < IGPUMemLimit {
  199. slog.Info(fmt.Sprintf("amdgpu [%d] appears to be an iGPU with %dM reported total memory, skipping", id, totalMemory/1024/1024))
  200. skip[id] = struct{}{}
  201. continue
  202. }
  203. usedGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id), GPUUsedMemoryFileGlob)
  204. usedFiles, err := filepath.Glob(usedGlob)
  205. if err != nil {
  206. slog.Warn(fmt.Sprintf("error looking up used GPU memory: %s %s", usedGlob, err))
  207. continue
  208. }
  209. for _, usedFile := range usedFiles {
  210. fp, err := os.Open(usedFile)
  211. if err != nil {
  212. slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", usedFile, err))
  213. continue
  214. }
  215. defer fp.Close()
  216. data, err := io.ReadAll(fp)
  217. if err != nil {
  218. slog.Warn(fmt.Sprintf("failed to read sysfs node file %s: %s", usedFile, err))
  219. continue
  220. }
  221. used, err := strconv.ParseUint(strings.TrimSpace(string(data)), 10, 64)
  222. if err != nil {
  223. slog.Warn(fmt.Sprintf("malformed used memory %s: %s", string(data), err))
  224. continue
  225. }
  226. usedMemory += used
  227. }
  228. slog.Info(fmt.Sprintf("[%d] amdgpu totalMemory %dM", id, totalMemory/1024/1024))
  229. slog.Info(fmt.Sprintf("[%d] amdgpu freeMemory %dM", id, (totalMemory-usedMemory)/1024/1024))
  230. resp.memInfo.DeviceCount++
  231. resp.memInfo.TotalMemory += totalMemory
  232. resp.memInfo.FreeMemory += (totalMemory - usedMemory)
  233. }
  234. if resp.memInfo.DeviceCount > 0 {
  235. resp.Library = "rocm"
  236. }
  237. }
  238. // Quick check for AMD driver so we can skip amdgpu discovery if not present
  239. func AMDDetected() bool {
  240. // Some driver versions (older?) don't have a version file, so just lookup the parent dir
  241. sysfsDir := filepath.Dir(DriverVersionFile)
  242. _, err := os.Stat(sysfsDir)
  243. if errors.Is(err, os.ErrNotExist) {
  244. slog.Debug("amdgpu driver not detected " + sysfsDir)
  245. return false
  246. } else if err != nil {
  247. slog.Debug(fmt.Sprintf("error looking up amd driver %s %s", sysfsDir, err))
  248. return false
  249. }
  250. return true
  251. }
  252. func setupLink(source, target string) error {
  253. if err := os.RemoveAll(target); err != nil {
  254. return fmt.Errorf("failed to remove old rocm directory %s %w", target, err)
  255. }
  256. if err := os.Symlink(source, target); err != nil {
  257. return fmt.Errorf("failed to create link %s => %s %w", source, target, err)
  258. }
  259. slog.Debug(fmt.Sprintf("host rocm linked %s => %s", source, target))
  260. return nil
  261. }
  262. // Ensure the AMD rocm lib dir is wired up
  263. // Prefer to use host installed ROCm, as long as it meets our minimum requirements
  264. // failing that, tell the user how to download it on their own
  265. func AMDValidateLibDir() (string, error) {
  266. // We rely on the rpath compiled into our library to find rocm
  267. // so we establish a symlink to wherever we find it on the system
  268. // to <payloads>/rocm
  269. payloadsDir, err := PayloadsDir()
  270. if err != nil {
  271. return "", err
  272. }
  273. // If we already have a rocm dependency wired, nothing more to do
  274. rocmTargetDir := filepath.Clean(filepath.Join(payloadsDir, "..", "rocm"))
  275. if rocmLibUsable(rocmTargetDir) {
  276. return rocmTargetDir, nil
  277. }
  278. // next to the running binary
  279. exe, err := os.Executable()
  280. if err == nil {
  281. peerDir := filepath.Dir(exe)
  282. if rocmLibUsable(peerDir) {
  283. slog.Debug("detected ROCM next to ollama executable " + peerDir)
  284. return rocmTargetDir, setupLink(peerDir, rocmTargetDir)
  285. }
  286. peerDir = filepath.Join(filepath.Dir(exe), "rocm")
  287. if rocmLibUsable(peerDir) {
  288. slog.Debug("detected ROCM next to ollama executable " + peerDir)
  289. return rocmTargetDir, setupLink(peerDir, rocmTargetDir)
  290. }
  291. }
  292. // Well known ollama installer path
  293. installedRocmDir := "/usr/share/ollama/lib/rocm"
  294. if rocmLibUsable(installedRocmDir) {
  295. return rocmTargetDir, setupLink(installedRocmDir, rocmTargetDir)
  296. }
  297. // Prefer explicit HIP env var
  298. hipPath := os.Getenv("HIP_PATH")
  299. if hipPath != "" {
  300. hipLibDir := filepath.Join(hipPath, "lib")
  301. if rocmLibUsable(hipLibDir) {
  302. slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
  303. return rocmTargetDir, setupLink(hipLibDir, rocmTargetDir)
  304. }
  305. }
  306. // Scan the library path for potential matches
  307. ldPaths := strings.Split(os.Getenv("LD_LIBRARY_PATH"), ":")
  308. for _, ldPath := range ldPaths {
  309. d, err := filepath.Abs(ldPath)
  310. if err != nil {
  311. continue
  312. }
  313. if rocmLibUsable(d) {
  314. return rocmTargetDir, setupLink(d, rocmTargetDir)
  315. }
  316. }
  317. // Well known location(s)
  318. if rocmLibUsable("/opt/rocm/lib") {
  319. return rocmTargetDir, setupLink("/opt/rocm/lib", rocmTargetDir)
  320. }
  321. // If we still haven't found a usable rocm, the user will have to install it on their own
  322. slog.Warn("amdgpu detected, but no compatible rocm library found. Either install rocm v6, or follow manual install instructions at https://github.com/ollama/ollama/blob/main/docs/linux.md#manual-install")
  323. return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
  324. }
  325. func AMDDriverVersion() (string, error) {
  326. _, err := os.Stat(DriverVersionFile)
  327. if err != nil {
  328. return "", fmt.Errorf("amdgpu version file missing: %s %w", DriverVersionFile, err)
  329. }
  330. fp, err := os.Open(DriverVersionFile)
  331. if err != nil {
  332. return "", err
  333. }
  334. defer fp.Close()
  335. verString, err := io.ReadAll(fp)
  336. if err != nil {
  337. return "", err
  338. }
  339. return strings.TrimSpace(string(verString)), nil
  340. }
  341. func AMDGFXVersions() map[int]Version {
  342. // The amdgpu driver always exposes the host CPU as node 0, but we have to skip that and subtract one
  343. // from the other IDs to get alignment with the HIP libraries expectations (zero is the first GPU, not the CPU)
  344. res := map[int]Version{}
  345. matches, _ := filepath.Glob(GPUPropertiesFileGlob)
  346. for _, match := range matches {
  347. fp, err := os.Open(match)
  348. if err != nil {
  349. slog.Debug(fmt.Sprintf("failed to open sysfs node file %s: %s", match, err))
  350. continue
  351. }
  352. defer fp.Close()
  353. i, err := strconv.Atoi(filepath.Base(filepath.Dir(match)))
  354. if err != nil {
  355. slog.Debug(fmt.Sprintf("failed to parse node ID %s", err))
  356. continue
  357. }
  358. if i == 0 {
  359. // Skipping the CPU
  360. continue
  361. }
  362. // Align with HIP IDs (zero is first GPU, not CPU)
  363. i -= 1
  364. scanner := bufio.NewScanner(fp)
  365. for scanner.Scan() {
  366. line := strings.TrimSpace(scanner.Text())
  367. if strings.HasPrefix(line, "gfx_target_version") {
  368. ver := strings.Fields(line)
  369. if len(ver) != 2 || len(ver[1]) < 5 {
  370. if ver[1] != "0" {
  371. slog.Debug("malformed " + line)
  372. }
  373. res[i] = Version{
  374. Major: 0,
  375. Minor: 0,
  376. Patch: 0,
  377. }
  378. continue
  379. }
  380. l := len(ver[1])
  381. patch, err1 := strconv.ParseUint(ver[1][l-2:l], 10, 32)
  382. minor, err2 := strconv.ParseUint(ver[1][l-4:l-2], 10, 32)
  383. major, err3 := strconv.ParseUint(ver[1][:l-4], 10, 32)
  384. if err1 != nil || err2 != nil || err3 != nil {
  385. slog.Debug("malformed int " + line)
  386. continue
  387. }
  388. res[i] = Version{
  389. Major: uint(major),
  390. Minor: uint(minor),
  391. Patch: uint(patch),
  392. }
  393. }
  394. }
  395. }
  396. return res
  397. }
  398. func (v Version) ToGFXString() string {
  399. return fmt.Sprintf("gfx%d%d%d", v.Major, v.Minor, v.Patch)
  400. }