amd_linux.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411
  1. package gpu
  2. import (
  3. "bufio"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "log/slog"
  8. "os"
  9. "path/filepath"
  10. "slices"
  11. "strconv"
  12. "strings"
  13. "github.com/jmorganca/ollama/version"
  14. )
  15. // Discovery logic for AMD/ROCm GPUs
  16. const (
  17. curlMsg = "curl -fsSL https://github.com/ollama/ollama/releases/download/v%s/rocm-amd64-deps.tgz | tar -zxf - -C %s"
  18. DriverVersionFile = "/sys/module/amdgpu/version"
  19. AMDNodesSysfsDir = "/sys/class/kfd/kfd/topology/nodes/"
  20. GPUPropertiesFileGlob = AMDNodesSysfsDir + "*/properties"
  21. // Prefix with the node dir
  22. GPUTotalMemoryFileGlob = "mem_banks/*/properties" // size_in_bytes line
  23. GPUUsedMemoryFileGlob = "mem_banks/*/used_memory"
  24. RocmStandardLocation = "/opt/rocm/lib"
  25. )
  26. var (
  27. // Used to validate if the given ROCm lib is usable
  28. ROCmLibGlobs = []string{"libhipblas.so.2*", "rocblas"} // TODO - probably include more coverage of files here...
  29. )
  30. // Gather GPU information from the amdgpu driver if any supported GPUs are detected
  31. // HIP_VISIBLE_DEVICES will be set if we detect a mix of unsupported and supported devices
  32. // and the user hasn't already set this variable
  33. func AMDGetGPUInfo(resp *GpuInfo) {
  34. // TODO - DRY this out with windows
  35. if !AMDDetected() {
  36. return
  37. }
  38. skip := map[int]interface{}{}
  39. // Opportunistic logging of driver version to aid in troubleshooting
  40. ver, err := AMDDriverVersion()
  41. if err == nil {
  42. slog.Info("AMD Driver: " + ver)
  43. } else {
  44. // TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU
  45. slog.Warn(fmt.Sprintf("ollama recommends running the https://www.amd.com/en/support/linux-drivers: %s", err))
  46. }
  47. // If the user has specified exactly which GPUs to use, look up their memory
  48. visibleDevices := os.Getenv("HIP_VISIBLE_DEVICES")
  49. if visibleDevices != "" {
  50. ids := []int{}
  51. for _, idStr := range strings.Split(visibleDevices, ",") {
  52. id, err := strconv.Atoi(idStr)
  53. if err != nil {
  54. slog.Warn(fmt.Sprintf("malformed HIP_VISIBLE_DEVICES=%s %s", visibleDevices, err))
  55. } else {
  56. ids = append(ids, id)
  57. }
  58. }
  59. amdProcMemLookup(resp, nil, ids)
  60. return
  61. }
  62. // Gather GFX version information from all detected cards
  63. gfx := AMDGFXVersions()
  64. verStrings := []string{}
  65. for i, v := range gfx {
  66. verStrings = append(verStrings, v.ToGFXString())
  67. if v.Major == 0 {
  68. // Silently skip CPUs
  69. skip[i] = struct{}{}
  70. continue
  71. }
  72. if v.Major < 9 {
  73. // TODO consider this a build-time setting if we can support 8xx family GPUs
  74. slog.Warn(fmt.Sprintf("amdgpu [%d] too old %s", i, v.ToGFXString()))
  75. skip[i] = struct{}{}
  76. }
  77. }
  78. slog.Info(fmt.Sprintf("detected amdgpu versions %v", verStrings))
  79. // Abort if all GPUs are skipped
  80. if len(skip) >= len(gfx) {
  81. slog.Info("all detected amdgpus are skipped, falling back to CPU")
  82. return
  83. }
  84. // If we got this far, then we have at least 1 GPU that's a ROCm candidate, so make sure we have a lib
  85. libDir, err := AMDValidateLibDir()
  86. if err != nil {
  87. slog.Warn(fmt.Sprintf("unable to verify rocm library, will use cpu: %s", err))
  88. return
  89. }
  90. gfxOverride := os.Getenv("HSA_OVERRIDE_GFX_VERSION")
  91. if gfxOverride == "" {
  92. supported, err := GetSupportedGFX(libDir)
  93. if err != nil {
  94. slog.Warn(fmt.Sprintf("failed to lookup supported GFX types, falling back to CPU mode: %s", err))
  95. return
  96. }
  97. slog.Debug(fmt.Sprintf("rocm supported GPU types %v", supported))
  98. for i, v := range gfx {
  99. if !slices.Contains[[]string, string](supported, v.ToGFXString()) {
  100. slog.Warn(fmt.Sprintf("amdgpu [%d] %s is not supported by %s %v", i, v.ToGFXString(), libDir, supported))
  101. // TODO - consider discrete markdown just for ROCM troubleshooting?
  102. slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage")
  103. skip[i] = struct{}{}
  104. } else {
  105. slog.Info(fmt.Sprintf("amdgpu [%d] %s is supported", i, v.ToGFXString()))
  106. }
  107. }
  108. } else {
  109. slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
  110. }
  111. if len(skip) >= len(gfx) {
  112. slog.Info("all detected amdgpus are skipped, falling back to CPU")
  113. return
  114. }
  115. ids := make([]int, len(gfx))
  116. i := 0
  117. for k := range gfx {
  118. ids[i] = k
  119. i++
  120. }
  121. amdProcMemLookup(resp, skip, ids)
  122. if resp.memInfo.DeviceCount == 0 {
  123. return
  124. }
  125. if len(skip) > 0 {
  126. amdSetVisibleDevices(ids, skip)
  127. }
  128. }
  129. // Walk the sysfs nodes for the available GPUs and gather information from them
  130. // skipping over any devices in the skip map
  131. func amdProcMemLookup(resp *GpuInfo, skip map[int]interface{}, ids []int) {
  132. resp.memInfo.DeviceCount = 0
  133. resp.memInfo.TotalMemory = 0
  134. resp.memInfo.FreeMemory = 0
  135. if len(ids) == 0 {
  136. slog.Debug("discovering all amdgpu devices")
  137. entries, err := os.ReadDir(AMDNodesSysfsDir)
  138. if err != nil {
  139. slog.Warn(fmt.Sprintf("failed to read amdgpu sysfs %s - %s", AMDNodesSysfsDir, err))
  140. return
  141. }
  142. for _, node := range entries {
  143. if !node.IsDir() {
  144. continue
  145. }
  146. id, err := strconv.Atoi(node.Name())
  147. if err != nil {
  148. slog.Warn("malformed amdgpu sysfs node id " + node.Name())
  149. continue
  150. }
  151. ids = append(ids, id)
  152. }
  153. }
  154. slog.Debug(fmt.Sprintf("discovering amdgpu devices %v", ids))
  155. for _, id := range ids {
  156. if _, skipped := skip[id]; skipped {
  157. continue
  158. }
  159. totalMemory := uint64(0)
  160. usedMemory := uint64(0)
  161. propGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id), GPUTotalMemoryFileGlob)
  162. propFiles, err := filepath.Glob(propGlob)
  163. if err != nil {
  164. slog.Warn(fmt.Sprintf("error looking up total GPU memory: %s %s", propGlob, err))
  165. }
  166. // 1 or more memory banks - sum the values of all of them
  167. for _, propFile := range propFiles {
  168. fp, err := os.Open(propFile)
  169. if err != nil {
  170. slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", propFile, err))
  171. continue
  172. }
  173. defer fp.Close()
  174. scanner := bufio.NewScanner(fp)
  175. for scanner.Scan() {
  176. line := strings.TrimSpace(scanner.Text())
  177. if strings.HasPrefix(line, "size_in_bytes") {
  178. ver := strings.Fields(line)
  179. if len(ver) != 2 {
  180. slog.Warn("malformed " + line)
  181. continue
  182. }
  183. bankSizeInBytes, err := strconv.ParseUint(ver[1], 10, 64)
  184. if err != nil {
  185. slog.Warn("malformed int " + line)
  186. continue
  187. }
  188. totalMemory += bankSizeInBytes
  189. }
  190. }
  191. }
  192. if totalMemory == 0 {
  193. continue
  194. }
  195. usedGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id), GPUUsedMemoryFileGlob)
  196. usedFiles, err := filepath.Glob(usedGlob)
  197. if err != nil {
  198. slog.Warn(fmt.Sprintf("error looking up used GPU memory: %s %s", usedGlob, err))
  199. continue
  200. }
  201. for _, usedFile := range usedFiles {
  202. fp, err := os.Open(usedFile)
  203. if err != nil {
  204. slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", usedFile, err))
  205. continue
  206. }
  207. defer fp.Close()
  208. data, err := io.ReadAll(fp)
  209. if err != nil {
  210. slog.Warn(fmt.Sprintf("failed to read sysfs node file %s: %s", usedFile, err))
  211. continue
  212. }
  213. used, err := strconv.ParseUint(strings.TrimSpace(string(data)), 10, 64)
  214. if err != nil {
  215. slog.Warn(fmt.Sprintf("malformed used memory %s: %s", string(data), err))
  216. continue
  217. }
  218. usedMemory += used
  219. }
  220. slog.Info(fmt.Sprintf("[%d] amdgpu totalMemory %d", id, totalMemory))
  221. slog.Info(fmt.Sprintf("[%d] amdgpu freeMemory %d", id, (totalMemory - usedMemory)))
  222. resp.memInfo.DeviceCount++
  223. resp.memInfo.TotalMemory += totalMemory
  224. resp.memInfo.FreeMemory += (totalMemory - usedMemory)
  225. }
  226. if resp.memInfo.DeviceCount > 0 {
  227. resp.Library = "rocm"
  228. }
  229. }
  230. // Quick check for AMD driver so we can skip amdgpu discovery if not present
  231. func AMDDetected() bool {
  232. // Some driver versions (older?) don't have a version file, so just lookup the parent dir
  233. sysfsDir := filepath.Dir(DriverVersionFile)
  234. _, err := os.Stat(sysfsDir)
  235. if errors.Is(err, os.ErrNotExist) {
  236. slog.Debug("amdgpu driver not detected " + sysfsDir)
  237. return false
  238. } else if err != nil {
  239. slog.Debug(fmt.Sprintf("error looking up amd driver %s %s", sysfsDir, err))
  240. return false
  241. }
  242. return true
  243. }
  244. func setupLink(source, target string) error {
  245. if err := os.RemoveAll(target); err != nil {
  246. return fmt.Errorf("failed to remove old rocm directory %s %w", target, err)
  247. }
  248. if err := os.Symlink(source, target); err != nil {
  249. return fmt.Errorf("failed to create link %s => %s %w", source, target, err)
  250. }
  251. slog.Debug(fmt.Sprintf("host rocm linked %s => %s", source, target))
  252. return nil
  253. }
  254. // Ensure the AMD rocm lib dir is wired up
  255. // Prefer to use host installed ROCm, as long as it meets our minimum requirements
  256. // failing that, tell the user how to download it on their own
  257. func AMDValidateLibDir() (string, error) {
  258. // We rely on the rpath compiled into our library to find rocm
  259. // so we establish a symlink to wherever we find it on the system
  260. // to $AssetsDir/rocm
  261. // If we already have a rocm dependency wired, nothing more to do
  262. assetsDir, err := AssetsDir()
  263. if err != nil {
  264. return "", fmt.Errorf("unable to lookup lib dir: %w", err)
  265. }
  266. // Versioned directory
  267. rocmTargetDir := filepath.Join(assetsDir, "rocm")
  268. if rocmLibUsable(rocmTargetDir) {
  269. return rocmTargetDir, nil
  270. }
  271. // Parent dir (unversioned)
  272. commonRocmDir := filepath.Join(filepath.Dir(assetsDir), "rocm")
  273. if rocmLibUsable(commonRocmDir) {
  274. return rocmTargetDir, setupLink(commonRocmDir, rocmTargetDir)
  275. }
  276. // Prefer explicit HIP env var
  277. hipPath := os.Getenv("HIP_PATH")
  278. if hipPath != "" {
  279. hipLibDir := filepath.Join(hipPath, "lib")
  280. if rocmLibUsable(hipLibDir) {
  281. slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
  282. return rocmTargetDir, setupLink(hipLibDir, rocmTargetDir)
  283. }
  284. }
  285. // Scan the library path for potential matches
  286. ldPaths := strings.Split(os.Getenv("LD_LIBRARY_PATH"), ":")
  287. for _, ldPath := range ldPaths {
  288. d, err := filepath.Abs(ldPath)
  289. if err != nil {
  290. continue
  291. }
  292. if rocmLibUsable(d) {
  293. return rocmTargetDir, setupLink(d, rocmTargetDir)
  294. }
  295. }
  296. // Well known location(s)
  297. if rocmLibUsable("/opt/rocm/lib") {
  298. return rocmTargetDir, setupLink("/opt/rocm/lib", rocmTargetDir)
  299. }
  300. err = os.MkdirAll(rocmTargetDir, 0755)
  301. if err != nil {
  302. return "", fmt.Errorf("failed to create empty rocm dir %s %w", rocmTargetDir, err)
  303. }
  304. // If we still haven't found a usable rocm, the user will have to download it on their own
  305. slog.Warn("amdgpu detected, but no compatible rocm library found. Either install rocm v6, or run the following")
  306. slog.Warn(fmt.Sprintf(curlMsg, version.Version, rocmTargetDir))
  307. return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
  308. }
  309. func AMDDriverVersion() (string, error) {
  310. _, err := os.Stat(DriverVersionFile)
  311. if err != nil {
  312. return "", fmt.Errorf("amdgpu version file missing: %s %w", DriverVersionFile, err)
  313. }
  314. fp, err := os.Open(DriverVersionFile)
  315. if err != nil {
  316. return "", err
  317. }
  318. defer fp.Close()
  319. verString, err := io.ReadAll(fp)
  320. if err != nil {
  321. return "", err
  322. }
  323. return strings.TrimSpace(string(verString)), nil
  324. }
  325. func AMDGFXVersions() map[int]Version {
  326. res := map[int]Version{}
  327. matches, _ := filepath.Glob(GPUPropertiesFileGlob)
  328. for _, match := range matches {
  329. fp, err := os.Open(match)
  330. if err != nil {
  331. slog.Debug(fmt.Sprintf("failed to open sysfs node file %s: %s", match, err))
  332. continue
  333. }
  334. defer fp.Close()
  335. i, err := strconv.Atoi(filepath.Base(filepath.Dir(match)))
  336. if err != nil {
  337. slog.Debug(fmt.Sprintf("failed to parse node ID %s", err))
  338. continue
  339. }
  340. scanner := bufio.NewScanner(fp)
  341. for scanner.Scan() {
  342. line := strings.TrimSpace(scanner.Text())
  343. if strings.HasPrefix(line, "gfx_target_version") {
  344. ver := strings.Fields(line)
  345. if len(ver) != 2 || len(ver[1]) < 5 {
  346. if ver[1] == "0" {
  347. // Silently skip the CPU
  348. continue
  349. } else {
  350. slog.Debug("malformed " + line)
  351. }
  352. res[i] = Version{
  353. Major: 0,
  354. Minor: 0,
  355. Patch: 0,
  356. }
  357. continue
  358. }
  359. l := len(ver[1])
  360. patch, err1 := strconv.ParseUint(ver[1][l-2:l], 10, 32)
  361. minor, err2 := strconv.ParseUint(ver[1][l-4:l-2], 10, 32)
  362. major, err3 := strconv.ParseUint(ver[1][:l-4], 10, 32)
  363. if err1 != nil || err2 != nil || err3 != nil {
  364. slog.Debug("malformed int " + line)
  365. continue
  366. }
  367. res[i] = Version{
  368. Major: uint(major),
  369. Minor: uint(minor),
  370. Patch: uint(patch),
  371. }
  372. }
  373. }
  374. }
  375. return res
  376. }
  377. func (v Version) ToGFXString() string {
  378. return fmt.Sprintf("gfx%d%d%d", v.Major, v.Minor, v.Patch)
  379. }