amd_linux.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380
  1. package gpu
  2. import (
  3. "bufio"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "log/slog"
  8. "os"
  9. "path/filepath"
  10. "regexp"
  11. "slices"
  12. "strconv"
  13. "strings"
  14. "github.com/ollama/ollama/format"
  15. )
  // Discovery logic for AMD/ROCm GPUs

// Absolute sysfs locations published by the amdgpu kernel driver.
const (
	// DriverVersionFile holds the amdgpu module version string ("<major>.<minor>...").
	DriverVersionFile = "/sys/module/amdgpu/version"

	// AMDNodesSysfsDir contains one numbered directory per KFD topology node;
	// host CPUs appear here alongside GPUs.
	AMDNodesSysfsDir = "/sys/class/kfd/kfd/topology/nodes/"

	// GPUPropertiesFileGlob matches the properties file of every topology node.
	GPUPropertiesFileGlob = AMDNodesSysfsDir + "*/properties"
)

// Globs relative to a single topology node directory; callers prefix them
// with the node dir.
const (
	// GPUTotalMemoryFileGlob matches each memory bank's properties file, whose
	// size_in_bytes line gives the bank size.
	GPUTotalMemoryFileGlob = "mem_banks/*/properties"

	// GPUUsedMemoryFileGlob matches each memory bank's used_memory file.
	GPUUsedMemoryFileGlob = "mem_banks/*/used_memory"
)

// ROCmLibGlobs are the file patterns used to validate whether a given ROCm
// library directory is usable.
// TODO - probably include more coverage of files here...
var ROCmLibGlobs = []string{"libhipblas.so.2*", "rocblas"}

// RocmStandardLocations lists conventional system library directories where a
// host ROCm install may live (consumers are outside this file section — presumably
// the host ROCm lookup; verify).
var RocmStandardLocations = []string{"/opt/rocm/lib", "/usr/lib64"}
  30. // Gather GPU information from the amdgpu driver if any supported GPUs are detected
  31. func AMDGetGPUInfo() []GpuInfo {
  32. resp := []GpuInfo{}
  33. if !AMDDetected() {
  34. return resp
  35. }
  36. // Opportunistic logging of driver version to aid in troubleshooting
  37. driverMajor, driverMinor, err := AMDDriverVersion()
  38. if err != nil {
  39. // TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU
  40. slog.Warn("ollama recommends running the https://www.amd.com/en/support/linux-drivers", "error", err)
  41. }
  42. // Determine if the user has already pre-selected which GPUs to look at, then ignore the others
  43. var visibleDevices []string
  44. hipVD := os.Getenv("HIP_VISIBLE_DEVICES") // zero based index only
  45. rocrVD := os.Getenv("ROCR_VISIBLE_DEVICES") // zero based index or UUID, but consumer cards seem to not support UUID
  46. gpuDO := os.Getenv("GPU_DEVICE_ORDINAL") // zero based index
  47. switch {
  48. // TODO is this priorty order right?
  49. case hipVD != "":
  50. visibleDevices = strings.Split(hipVD, ",")
  51. case rocrVD != "":
  52. visibleDevices = strings.Split(rocrVD, ",")
  53. // TODO - since we don't yet support UUIDs, consider detecting and reporting here
  54. // all our test systems show GPU-XX indicating UUID is not supported
  55. case gpuDO != "":
  56. visibleDevices = strings.Split(gpuDO, ",")
  57. }
  58. gfxOverride := os.Getenv("HSA_OVERRIDE_GFX_VERSION")
  59. var supported []string
  60. libDir := ""
  61. // The amdgpu driver always exposes the host CPU(s) first, but we have to skip them and subtract
  62. // from the other IDs to get alignment with the HIP libraries expectations (zero is the first GPU, not the CPU)
  63. matches, _ := filepath.Glob(GPUPropertiesFileGlob)
  64. cpuCount := 0
  65. for _, match := range matches {
  66. slog.Debug("evaluating amdgpu node " + match)
  67. fp, err := os.Open(match)
  68. if err != nil {
  69. slog.Debug("failed to open sysfs node", "file", match, "error", err)
  70. continue
  71. }
  72. defer fp.Close()
  73. nodeID, err := strconv.Atoi(filepath.Base(filepath.Dir(match)))
  74. if err != nil {
  75. slog.Debug("failed to parse node ID", "error", err)
  76. continue
  77. }
  78. scanner := bufio.NewScanner(fp)
  79. isCPU := false
  80. var major, minor, patch uint64
  81. var vendor, device uint64
  82. for scanner.Scan() {
  83. line := strings.TrimSpace(scanner.Text())
  84. // Note: we could also use "cpu_cores_count X" where X is greater than zero to detect CPUs
  85. if strings.HasPrefix(line, "gfx_target_version") {
  86. ver := strings.Fields(line)
  87. // Detect CPUs
  88. if len(ver) == 2 && ver[1] == "0" {
  89. slog.Debug("detected CPU " + match)
  90. isCPU = true
  91. break
  92. }
  93. if len(ver) != 2 || len(ver[1]) < 5 {
  94. slog.Warn("malformed "+match, "gfx_target_version", line)
  95. // If this winds up being a CPU, our offsets may be wrong
  96. continue
  97. }
  98. l := len(ver[1])
  99. var err1, err2, err3 error
  100. patch, err1 = strconv.ParseUint(ver[1][l-2:l], 10, 32)
  101. minor, err2 = strconv.ParseUint(ver[1][l-4:l-2], 10, 32)
  102. major, err3 = strconv.ParseUint(ver[1][:l-4], 10, 32)
  103. if err1 != nil || err2 != nil || err3 != nil {
  104. slog.Debug("malformed int " + line)
  105. continue
  106. }
  107. } else if strings.HasPrefix(line, "vendor_id") {
  108. ver := strings.Fields(line)
  109. if len(ver) != 2 {
  110. slog.Debug("malformed vendor_id", "vendor_id", line)
  111. continue
  112. }
  113. vendor, err = strconv.ParseUint(ver[1], 10, 32)
  114. if err != nil {
  115. slog.Debug("malformed vendor_id" + line)
  116. }
  117. } else if strings.HasPrefix(line, "device_id") {
  118. ver := strings.Fields(line)
  119. if len(ver) != 2 {
  120. slog.Debug("malformed device_id", "device_id", line)
  121. continue
  122. }
  123. device, err = strconv.ParseUint(ver[1], 10, 32)
  124. if err != nil {
  125. slog.Debug("malformed device_id" + line)
  126. }
  127. }
  128. // TODO - any other properties we want to extract and record?
  129. // vendor_id + device_id -> pci lookup for "Name"
  130. // Other metrics that may help us understand relative performance between multiple GPUs
  131. }
  132. if isCPU {
  133. cpuCount++
  134. continue
  135. }
  136. // CPUs are always first in the list
  137. gpuID := nodeID - cpuCount
  138. // Shouldn't happen, but just in case...
  139. if gpuID < 0 {
  140. slog.Error("unexpected amdgpu sysfs data resulted in negative GPU ID, please set OLLAMA_DEBUG=1 and report an issue")
  141. return []GpuInfo{}
  142. }
  143. if int(major) < RocmComputeMin {
  144. slog.Warn(fmt.Sprintf("amdgpu too old gfx%d%x%x", major, minor, patch), "gpu", gpuID)
  145. continue
  146. }
  147. // Look up the memory for the current node
  148. totalMemory := uint64(0)
  149. usedMemory := uint64(0)
  150. propGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(nodeID), GPUTotalMemoryFileGlob)
  151. propFiles, err := filepath.Glob(propGlob)
  152. if err != nil {
  153. slog.Warn("error looking up total GPU memory", "glob", propGlob, "error", err)
  154. }
  155. // 1 or more memory banks - sum the values of all of them
  156. for _, propFile := range propFiles {
  157. fp, err := os.Open(propFile)
  158. if err != nil {
  159. slog.Warn("failed to open sysfs node", "file", propFile, "erroir", err)
  160. continue
  161. }
  162. defer fp.Close()
  163. scanner := bufio.NewScanner(fp)
  164. for scanner.Scan() {
  165. line := strings.TrimSpace(scanner.Text())
  166. if strings.HasPrefix(line, "size_in_bytes") {
  167. ver := strings.Fields(line)
  168. if len(ver) != 2 {
  169. slog.Warn("malformed " + line)
  170. continue
  171. }
  172. bankSizeInBytes, err := strconv.ParseUint(ver[1], 10, 64)
  173. if err != nil {
  174. slog.Warn("malformed int " + line)
  175. continue
  176. }
  177. totalMemory += bankSizeInBytes
  178. }
  179. }
  180. }
  181. if totalMemory == 0 {
  182. slog.Warn("amdgpu reports zero total memory", "gpu", gpuID)
  183. continue
  184. }
  185. usedGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(nodeID), GPUUsedMemoryFileGlob)
  186. usedFiles, err := filepath.Glob(usedGlob)
  187. if err != nil {
  188. slog.Warn("error looking up used GPU memory", "glob", usedGlob, "error", err)
  189. continue
  190. }
  191. for _, usedFile := range usedFiles {
  192. fp, err := os.Open(usedFile)
  193. if err != nil {
  194. slog.Warn("failed to open sysfs node", "file", usedFile, "error", err)
  195. continue
  196. }
  197. defer fp.Close()
  198. data, err := io.ReadAll(fp)
  199. if err != nil {
  200. slog.Warn("failed to read sysfs node", "file", usedFile, "error", err)
  201. continue
  202. }
  203. used, err := strconv.ParseUint(strings.TrimSpace(string(data)), 10, 64)
  204. if err != nil {
  205. slog.Warn("malformed used memory", "data", string(data), "error", err)
  206. continue
  207. }
  208. usedMemory += used
  209. }
  210. // iGPU detection, remove this check once we can support an iGPU variant of the rocm library
  211. if totalMemory < IGPUMemLimit {
  212. slog.Info("unsupported Radeon iGPU detected skipping", "id", gpuID, "total", format.HumanBytes2(totalMemory))
  213. continue
  214. }
  215. var name string
  216. // TODO - PCI ID lookup
  217. if vendor > 0 && device > 0 {
  218. name = fmt.Sprintf("%04x:%04x", vendor, device)
  219. }
  220. slog.Debug("amdgpu memory", "gpu", gpuID, "total", format.HumanBytes2(totalMemory))
  221. slog.Debug("amdgpu memory", "gpu", gpuID, "available", format.HumanBytes2(totalMemory-usedMemory))
  222. gpuInfo := GpuInfo{
  223. Library: "rocm",
  224. memInfo: memInfo{
  225. TotalMemory: totalMemory,
  226. FreeMemory: (totalMemory - usedMemory),
  227. },
  228. ID: fmt.Sprintf("%d", gpuID),
  229. Name: name,
  230. Compute: fmt.Sprintf("gfx%d%x%x", major, minor, patch),
  231. MinimumMemory: rocmMinimumMemory,
  232. DriverMajor: driverMajor,
  233. DriverMinor: driverMinor,
  234. }
  235. // If the user wants to filter to a subset of devices, filter out if we aren't a match
  236. if len(visibleDevices) > 0 {
  237. include := false
  238. for _, visible := range visibleDevices {
  239. if visible == gpuInfo.ID {
  240. include = true
  241. break
  242. }
  243. }
  244. if !include {
  245. slog.Info("filtering out device per user request", "id", gpuInfo.ID, "visible_devices", visibleDevices)
  246. continue
  247. }
  248. }
  249. // Final validation is gfx compatibility - load the library if we haven't already loaded it
  250. // even if the user overrides, we still need to validate the library
  251. if libDir == "" {
  252. libDir, err = AMDValidateLibDir()
  253. if err != nil {
  254. slog.Warn("unable to verify rocm library, will use cpu", "error", err)
  255. return []GpuInfo{}
  256. }
  257. }
  258. gpuInfo.DependencyPath = libDir
  259. if gfxOverride == "" {
  260. // Only load supported list once
  261. if len(supported) == 0 {
  262. supported, err = GetSupportedGFX(libDir)
  263. if err != nil {
  264. slog.Warn("failed to lookup supported GFX types, falling back to CPU mode", "error", err)
  265. return []GpuInfo{}
  266. }
  267. slog.Debug("rocm supported GPUs", "types", supported)
  268. }
  269. gfx := gpuInfo.Compute
  270. if !slices.Contains[[]string, string](supported, gfx) {
  271. slog.Warn("amdgpu is not supported", "gpu", gpuInfo.ID, "gpu_type", gfx, "library", libDir, "supported_types", supported)
  272. // TODO - consider discrete markdown just for ROCM troubleshooting?
  273. slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/gpu.md#overrides for HSA_OVERRIDE_GFX_VERSION usage")
  274. continue
  275. } else {
  276. slog.Info("amdgpu is supported", "gpu", gpuInfo.ID, "gpu_type", gfx)
  277. }
  278. } else {
  279. slog.Info("skipping rocm gfx compatibility check", "HSA_OVERRIDE_GFX_VERSION", gfxOverride)
  280. }
  281. // The GPU has passed all the verification steps and is supported
  282. resp = append(resp, gpuInfo)
  283. }
  284. if len(resp) == 0 {
  285. slog.Info("no compatible amdgpu devices detected")
  286. }
  287. return resp
  288. }
  289. // Quick check for AMD driver so we can skip amdgpu discovery if not present
  290. func AMDDetected() bool {
  291. // Some driver versions (older?) don't have a version file, so just lookup the parent dir
  292. sysfsDir := filepath.Dir(DriverVersionFile)
  293. _, err := os.Stat(sysfsDir)
  294. if errors.Is(err, os.ErrNotExist) {
  295. slog.Debug("amdgpu driver not detected " + sysfsDir)
  296. return false
  297. } else if err != nil {
  298. slog.Debug("error looking up amd driver", "path", sysfsDir, "error", err)
  299. return false
  300. }
  301. return true
  302. }
  303. // Prefer to use host installed ROCm, as long as it meets our minimum requirements
  304. // failing that, tell the user how to download it on their own
  305. func AMDValidateLibDir() (string, error) {
  306. libDir, err := commonAMDValidateLibDir()
  307. if err == nil {
  308. return libDir, nil
  309. }
  310. // Well known ollama installer path
  311. installedRocmDir := "/usr/share/ollama/lib/rocm"
  312. if rocmLibUsable(installedRocmDir) {
  313. return installedRocmDir, nil
  314. }
  315. // If we still haven't found a usable rocm, the user will have to install it on their own
  316. slog.Warn("amdgpu detected, but no compatible rocm library found. Either install rocm v6, or follow manual install instructions at https://github.com/ollama/ollama/blob/main/docs/linux.md#manual-install")
  317. return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
  318. }
  319. func AMDDriverVersion() (driverMajor, driverMinor int, err error) {
  320. _, err = os.Stat(DriverVersionFile)
  321. if err != nil {
  322. return 0, 0, fmt.Errorf("amdgpu version file missing: %s %w", DriverVersionFile, err)
  323. }
  324. fp, err := os.Open(DriverVersionFile)
  325. if err != nil {
  326. return 0, 0, err
  327. }
  328. defer fp.Close()
  329. verString, err := io.ReadAll(fp)
  330. if err != nil {
  331. return 0, 0, err
  332. }
  333. pattern := `\A(\d+)\.(\d+).*`
  334. regex := regexp.MustCompile(pattern)
  335. match := regex.FindStringSubmatch(string(verString))
  336. if len(match) < 2 {
  337. return 0, 0, fmt.Errorf("malformed version string %s", string(verString))
  338. }
  339. driverMajor, err = strconv.Atoi(match[1])
  340. if err != nil {
  341. return 0, 0, err
  342. }
  343. driverMinor, err = strconv.Atoi(match[2])
  344. if err != nil {
  345. return 0, 0, err
  346. }
  347. return driverMajor, driverMinor, nil
  348. }