amd_linux.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403
  1. package gpu
  2. import (
  3. "bufio"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "log/slog"
  8. "os"
  9. "path/filepath"
  10. "slices"
  11. "strconv"
  12. "strings"
  13. )
  14. // Discovery logic for AMD/ROCm GPUs
  15. const (
  16. DriverVersionFile = "/sys/module/amdgpu/version"
  17. AMDNodesSysfsDir = "/sys/class/kfd/kfd/topology/nodes/"
  18. GPUPropertiesFileGlob = AMDNodesSysfsDir + "*/properties"
  19. // Prefix with the node dir
  20. GPUTotalMemoryFileGlob = "mem_banks/*/properties" // size_in_bytes line
  21. GPUUsedMemoryFileGlob = "mem_banks/*/used_memory"
  22. RocmStandardLocation = "/opt/rocm/lib"
  23. )
  24. var (
  25. // Used to validate if the given ROCm lib is usable
  26. ROCmLibGlobs = []string{"libhipblas.so.2*", "rocblas"} // TODO - probably include more coverage of files here...
  27. )
  28. // Gather GPU information from the amdgpu driver if any supported GPUs are detected
  29. // HIP_VISIBLE_DEVICES will be set if we detect a mix of unsupported and supported devices
  30. // and the user hasn't already set this variable
  31. func AMDGetGPUInfo(resp *GpuInfo) {
  32. // TODO - DRY this out with windows
  33. if !AMDDetected() {
  34. return
  35. }
  36. skip := map[int]interface{}{}
  37. // Opportunistic logging of driver version to aid in troubleshooting
  38. ver, err := AMDDriverVersion()
  39. if err == nil {
  40. slog.Info("AMD Driver: " + ver)
  41. } else {
  42. // TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU
  43. slog.Warn(fmt.Sprintf("ollama recommends running the https://www.amd.com/en/support/linux-drivers: %s", err))
  44. }
  45. // If the user has specified exactly which GPUs to use, look up their memory
  46. visibleDevices := os.Getenv("HIP_VISIBLE_DEVICES")
  47. if visibleDevices != "" {
  48. ids := []int{}
  49. for _, idStr := range strings.Split(visibleDevices, ",") {
  50. id, err := strconv.Atoi(idStr)
  51. if err != nil {
  52. slog.Warn(fmt.Sprintf("malformed HIP_VISIBLE_DEVICES=%s %s", visibleDevices, err))
  53. } else {
  54. ids = append(ids, id)
  55. }
  56. }
  57. amdProcMemLookup(resp, nil, ids)
  58. return
  59. }
  60. // Gather GFX version information from all detected cards
  61. gfx := AMDGFXVersions()
  62. verStrings := []string{}
  63. for i, v := range gfx {
  64. verStrings = append(verStrings, v.ToGFXString())
  65. if v.Major == 0 {
  66. // Silently skip CPUs
  67. skip[i] = struct{}{}
  68. continue
  69. }
  70. if v.Major < 9 {
  71. // TODO consider this a build-time setting if we can support 8xx family GPUs
  72. slog.Warn(fmt.Sprintf("amdgpu [%d] too old %s", i, v.ToGFXString()))
  73. skip[i] = struct{}{}
  74. }
  75. }
  76. slog.Info(fmt.Sprintf("detected amdgpu versions %v", verStrings))
  77. // Abort if all GPUs are skipped
  78. if len(skip) >= len(gfx) {
  79. slog.Info("all detected amdgpus are skipped, falling back to CPU")
  80. return
  81. }
  82. // If we got this far, then we have at least 1 GPU that's a ROCm candidate, so make sure we have a lib
  83. libDir, err := AMDValidateLibDir()
  84. if err != nil {
  85. slog.Warn(fmt.Sprintf("unable to verify rocm library, will use cpu: %s", err))
  86. return
  87. }
  88. gfxOverride := os.Getenv("HSA_OVERRIDE_GFX_VERSION")
  89. if gfxOverride == "" {
  90. supported, err := GetSupportedGFX(libDir)
  91. if err != nil {
  92. slog.Warn(fmt.Sprintf("failed to lookup supported GFX types, falling back to CPU mode: %s", err))
  93. return
  94. }
  95. slog.Debug(fmt.Sprintf("rocm supported GPU types %v", supported))
  96. for i, v := range gfx {
  97. if !slices.Contains[[]string, string](supported, v.ToGFXString()) {
  98. slog.Warn(fmt.Sprintf("amdgpu [%d] %s is not supported by %s %v", i, v.ToGFXString(), libDir, supported))
  99. // TODO - consider discrete markdown just for ROCM troubleshooting?
  100. slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage")
  101. skip[i] = struct{}{}
  102. } else {
  103. slog.Info(fmt.Sprintf("amdgpu [%d] %s is supported", i, v.ToGFXString()))
  104. }
  105. }
  106. } else {
  107. slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
  108. }
  109. if len(skip) >= len(gfx) {
  110. slog.Info("all detected amdgpus are skipped, falling back to CPU")
  111. return
  112. }
  113. ids := make([]int, len(gfx))
  114. i := 0
  115. for k := range gfx {
  116. ids[i] = k
  117. i++
  118. }
  119. amdProcMemLookup(resp, skip, ids)
  120. if resp.memInfo.DeviceCount == 0 {
  121. return
  122. }
  123. if len(skip) > 0 {
  124. amdSetVisibleDevices(ids, skip)
  125. }
  126. }
  127. // Walk the sysfs nodes for the available GPUs and gather information from them
  128. // skipping over any devices in the skip map
  129. func amdProcMemLookup(resp *GpuInfo, skip map[int]interface{}, ids []int) {
  130. resp.memInfo.DeviceCount = 0
  131. resp.memInfo.TotalMemory = 0
  132. resp.memInfo.FreeMemory = 0
  133. if len(ids) == 0 {
  134. slog.Debug("discovering all amdgpu devices")
  135. entries, err := os.ReadDir(AMDNodesSysfsDir)
  136. if err != nil {
  137. slog.Warn(fmt.Sprintf("failed to read amdgpu sysfs %s - %s", AMDNodesSysfsDir, err))
  138. return
  139. }
  140. for _, node := range entries {
  141. if !node.IsDir() {
  142. continue
  143. }
  144. id, err := strconv.Atoi(node.Name())
  145. if err != nil {
  146. slog.Warn("malformed amdgpu sysfs node id " + node.Name())
  147. continue
  148. }
  149. ids = append(ids, id)
  150. }
  151. }
  152. slog.Debug(fmt.Sprintf("discovering amdgpu devices %v", ids))
  153. for _, id := range ids {
  154. if _, skipped := skip[id]; skipped {
  155. continue
  156. }
  157. totalMemory := uint64(0)
  158. usedMemory := uint64(0)
  159. propGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id), GPUTotalMemoryFileGlob)
  160. propFiles, err := filepath.Glob(propGlob)
  161. if err != nil {
  162. slog.Warn(fmt.Sprintf("error looking up total GPU memory: %s %s", propGlob, err))
  163. }
  164. // 1 or more memory banks - sum the values of all of them
  165. for _, propFile := range propFiles {
  166. fp, err := os.Open(propFile)
  167. if err != nil {
  168. slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", propFile, err))
  169. continue
  170. }
  171. defer fp.Close()
  172. scanner := bufio.NewScanner(fp)
  173. for scanner.Scan() {
  174. line := strings.TrimSpace(scanner.Text())
  175. if strings.HasPrefix(line, "size_in_bytes") {
  176. ver := strings.Fields(line)
  177. if len(ver) != 2 {
  178. slog.Warn("malformed " + line)
  179. continue
  180. }
  181. bankSizeInBytes, err := strconv.ParseUint(ver[1], 10, 64)
  182. if err != nil {
  183. slog.Warn("malformed int " + line)
  184. continue
  185. }
  186. totalMemory += bankSizeInBytes
  187. }
  188. }
  189. }
  190. if totalMemory == 0 {
  191. continue
  192. }
  193. usedGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id), GPUUsedMemoryFileGlob)
  194. usedFiles, err := filepath.Glob(usedGlob)
  195. if err != nil {
  196. slog.Warn(fmt.Sprintf("error looking up used GPU memory: %s %s", usedGlob, err))
  197. continue
  198. }
  199. for _, usedFile := range usedFiles {
  200. fp, err := os.Open(usedFile)
  201. if err != nil {
  202. slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", usedFile, err))
  203. continue
  204. }
  205. defer fp.Close()
  206. data, err := io.ReadAll(fp)
  207. if err != nil {
  208. slog.Warn(fmt.Sprintf("failed to read sysfs node file %s: %s", usedFile, err))
  209. continue
  210. }
  211. used, err := strconv.ParseUint(strings.TrimSpace(string(data)), 10, 64)
  212. if err != nil {
  213. slog.Warn(fmt.Sprintf("malformed used memory %s: %s", string(data), err))
  214. continue
  215. }
  216. usedMemory += used
  217. }
  218. slog.Info(fmt.Sprintf("[%d] amdgpu totalMemory %d", id, totalMemory))
  219. slog.Info(fmt.Sprintf("[%d] amdgpu freeMemory %d", id, (totalMemory - usedMemory)))
  220. resp.memInfo.DeviceCount++
  221. resp.memInfo.TotalMemory += totalMemory
  222. resp.memInfo.FreeMemory += (totalMemory - usedMemory)
  223. }
  224. if resp.memInfo.DeviceCount > 0 {
  225. resp.Library = "rocm"
  226. }
  227. }
  228. // Quick check for AMD driver so we can skip amdgpu discovery if not present
  229. func AMDDetected() bool {
  230. // Some driver versions (older?) don't have a version file, so just lookup the parent dir
  231. sysfsDir := filepath.Dir(DriverVersionFile)
  232. _, err := os.Stat(sysfsDir)
  233. if errors.Is(err, os.ErrNotExist) {
  234. slog.Debug("amdgpu driver not detected " + sysfsDir)
  235. return false
  236. } else if err != nil {
  237. slog.Debug(fmt.Sprintf("error looking up amd driver %s %s", sysfsDir, err))
  238. return false
  239. }
  240. return true
  241. }
  242. func setupLink(source, target string) error {
  243. if err := os.RemoveAll(target); err != nil {
  244. return fmt.Errorf("failed to remove old rocm directory %s %w", target, err)
  245. }
  246. if err := os.Symlink(source, target); err != nil {
  247. return fmt.Errorf("failed to create link %s => %s %w", source, target, err)
  248. }
  249. slog.Debug(fmt.Sprintf("host rocm linked %s => %s", source, target))
  250. return nil
  251. }
  252. // Ensure the AMD rocm lib dir is wired up
  253. // Prefer to use host installed ROCm, as long as it meets our minimum requirements
  254. // failing that, tell the user how to download it on their own
  255. func AMDValidateLibDir() (string, error) {
  256. // We rely on the rpath compiled into our library to find rocm
  257. // so we establish a symlink to wherever we find it on the system
  258. // to <payloads>/rocm
  259. payloadsDir, err := PayloadsDir()
  260. if err != nil {
  261. return "", err
  262. }
  263. // If we already have a rocm dependency wired, nothing more to do
  264. rocmTargetDir := filepath.Join(payloadsDir, "rocm")
  265. if rocmLibUsable(rocmTargetDir) {
  266. return rocmTargetDir, nil
  267. }
  268. // Well known ollama installer path
  269. installedRocmDir := "/usr/share/ollama/lib/rocm"
  270. if rocmLibUsable(installedRocmDir) {
  271. return rocmTargetDir, setupLink(installedRocmDir, rocmTargetDir)
  272. }
  273. // Prefer explicit HIP env var
  274. hipPath := os.Getenv("HIP_PATH")
  275. if hipPath != "" {
  276. hipLibDir := filepath.Join(hipPath, "lib")
  277. if rocmLibUsable(hipLibDir) {
  278. slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
  279. return rocmTargetDir, setupLink(hipLibDir, rocmTargetDir)
  280. }
  281. }
  282. // Scan the library path for potential matches
  283. ldPaths := strings.Split(os.Getenv("LD_LIBRARY_PATH"), ":")
  284. for _, ldPath := range ldPaths {
  285. d, err := filepath.Abs(ldPath)
  286. if err != nil {
  287. continue
  288. }
  289. if rocmLibUsable(d) {
  290. return rocmTargetDir, setupLink(d, rocmTargetDir)
  291. }
  292. }
  293. // Well known location(s)
  294. if rocmLibUsable("/opt/rocm/lib") {
  295. return rocmTargetDir, setupLink("/opt/rocm/lib", rocmTargetDir)
  296. }
  297. // If we still haven't found a usable rocm, the user will have to install it on their own
  298. slog.Warn("amdgpu detected, but no compatible rocm library found. Either install rocm v6, or follow manual install instructions at https://github.com/ollama/ollama/blob/main/docs/linux.md#manual-install")
  299. return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
  300. }
  301. func AMDDriverVersion() (string, error) {
  302. _, err := os.Stat(DriverVersionFile)
  303. if err != nil {
  304. return "", fmt.Errorf("amdgpu version file missing: %s %w", DriverVersionFile, err)
  305. }
  306. fp, err := os.Open(DriverVersionFile)
  307. if err != nil {
  308. return "", err
  309. }
  310. defer fp.Close()
  311. verString, err := io.ReadAll(fp)
  312. if err != nil {
  313. return "", err
  314. }
  315. return strings.TrimSpace(string(verString)), nil
  316. }
  317. func AMDGFXVersions() map[int]Version {
  318. res := map[int]Version{}
  319. matches, _ := filepath.Glob(GPUPropertiesFileGlob)
  320. for _, match := range matches {
  321. fp, err := os.Open(match)
  322. if err != nil {
  323. slog.Debug(fmt.Sprintf("failed to open sysfs node file %s: %s", match, err))
  324. continue
  325. }
  326. defer fp.Close()
  327. i, err := strconv.Atoi(filepath.Base(filepath.Dir(match)))
  328. if err != nil {
  329. slog.Debug(fmt.Sprintf("failed to parse node ID %s", err))
  330. continue
  331. }
  332. scanner := bufio.NewScanner(fp)
  333. for scanner.Scan() {
  334. line := strings.TrimSpace(scanner.Text())
  335. if strings.HasPrefix(line, "gfx_target_version") {
  336. ver := strings.Fields(line)
  337. if len(ver) != 2 || len(ver[1]) < 5 {
  338. if ver[1] == "0" {
  339. // Silently skip the CPU
  340. continue
  341. } else {
  342. slog.Debug("malformed " + line)
  343. }
  344. res[i] = Version{
  345. Major: 0,
  346. Minor: 0,
  347. Patch: 0,
  348. }
  349. continue
  350. }
  351. l := len(ver[1])
  352. patch, err1 := strconv.ParseUint(ver[1][l-2:l], 10, 32)
  353. minor, err2 := strconv.ParseUint(ver[1][l-4:l-2], 10, 32)
  354. major, err3 := strconv.ParseUint(ver[1][:l-4], 10, 32)
  355. if err1 != nil || err2 != nil || err3 != nil {
  356. slog.Debug("malformed int " + line)
  357. continue
  358. }
  359. res[i] = Version{
  360. Major: uint(major),
  361. Minor: uint(minor),
  362. Patch: uint(patch),
  363. }
  364. }
  365. }
  366. }
  367. return res
  368. }
  369. func (v Version) ToGFXString() string {
  370. return fmt.Sprintf("gfx%d%d%d", v.Major, v.Minor, v.Patch)
  371. }