123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434 |
- package gpu
- import (
- "bufio"
- "errors"
- "fmt"
- "io"
- "log/slog"
- "os"
- "path/filepath"
- "slices"
- "strconv"
- "strings"
- )
- // Discovery logic for AMD/ROCm GPUs
// Sysfs locations and thresholds used during AMD/ROCm GPU discovery.
const (
	// Version file of the amdgpu driver
	DriverVersionFile = "/sys/module/amdgpu/version"

	// Directory holding one sub-directory per topology node
	AMDNodesSysfsDir = "/sys/class/kfd/kfd/topology/nodes/"

	// Matches the properties file of every node
	GPUPropertiesFileGlob = AMDNodesSysfsDir + "*/properties"

	// The two memory globs are relative - prefix with the node dir
	GPUTotalMemoryFileGlob = "mem_banks/*/properties" // size_in_bytes line
	GPUUsedMemoryFileGlob  = "mem_banks/*/used_memory"

	RocmStandardLocation = "/opt/rocm/lib"

	// TODO find a better way to detect iGPU instead of minimum memory
	// 512M is what they typically report, so anything less than 1G must be iGPU
	IGPUMemLimit = 1 << 30
)
// ROCmLibGlobs holds the glob patterns used to validate if the given ROCm lib is usable
// TODO - probably include more coverage of files here...
var ROCmLibGlobs = []string{
	"libhipblas.so.2*",
	"rocblas",
}
- // Gather GPU information from the amdgpu driver if any supported GPUs are detected
- // HIP_VISIBLE_DEVICES will be set if we detect a mix of unsupported and supported devices
- // and the user hasn't already set this variable
- func AMDGetGPUInfo(resp *GpuInfo) {
- // TODO - DRY this out with windows
- if !AMDDetected() {
- return
- }
- skip := map[int]interface{}{}
- // Opportunistic logging of driver version to aid in troubleshooting
- ver, err := AMDDriverVersion()
- if err == nil {
- slog.Info("AMD Driver: " + ver)
- } else {
- // TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU
- slog.Warn(fmt.Sprintf("ollama recommends running the https://www.amd.com/en/support/linux-drivers: %s", err))
- }
- // If the user has specified exactly which GPUs to use, look up their memory
- visibleDevices := os.Getenv("HIP_VISIBLE_DEVICES")
- if visibleDevices != "" {
- ids := []int{}
- for _, idStr := range strings.Split(visibleDevices, ",") {
- id, err := strconv.Atoi(idStr)
- if err != nil {
- slog.Warn(fmt.Sprintf("malformed HIP_VISIBLE_DEVICES=%s %s", visibleDevices, err))
- } else {
- ids = append(ids, id)
- }
- }
- amdProcMemLookup(resp, nil, ids)
- return
- }
- // Gather GFX version information from all detected cards
- gfx := AMDGFXVersions()
- verStrings := []string{}
- for i, v := range gfx {
- verStrings = append(verStrings, v.ToGFXString())
- if v.Major == 0 {
- // Silently skip CPUs
- skip[i] = struct{}{}
- continue
- }
- if v.Major < 9 {
- // TODO consider this a build-time setting if we can support 8xx family GPUs
- slog.Warn(fmt.Sprintf("amdgpu [%d] too old %s", i, v.ToGFXString()))
- skip[i] = struct{}{}
- }
- }
- slog.Info(fmt.Sprintf("detected amdgpu versions %v", verStrings))
- // Abort if all GPUs are skipped
- if len(skip) >= len(gfx) {
- slog.Info("all detected amdgpus are skipped, falling back to CPU")
- return
- }
- // If we got this far, then we have at least 1 GPU that's a ROCm candidate, so make sure we have a lib
- libDir, err := AMDValidateLibDir()
- if err != nil {
- slog.Warn(fmt.Sprintf("unable to verify rocm library, will use cpu: %s", err))
- return
- }
- gfxOverride := os.Getenv("HSA_OVERRIDE_GFX_VERSION")
- if gfxOverride == "" {
- supported, err := GetSupportedGFX(libDir)
- if err != nil {
- slog.Warn(fmt.Sprintf("failed to lookup supported GFX types, falling back to CPU mode: %s", err))
- return
- }
- slog.Debug(fmt.Sprintf("rocm supported GPU types %v", supported))
- for i, v := range gfx {
- if !slices.Contains[[]string, string](supported, v.ToGFXString()) {
- slog.Warn(fmt.Sprintf("amdgpu [%d] %s is not supported by %s %v", i, v.ToGFXString(), libDir, supported))
- // TODO - consider discrete markdown just for ROCM troubleshooting?
- slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage")
- skip[i] = struct{}{}
- } else {
- slog.Info(fmt.Sprintf("amdgpu [%d] %s is supported", i, v.ToGFXString()))
- }
- }
- } else {
- slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
- }
- if len(skip) >= len(gfx) {
- slog.Info("all detected amdgpus are skipped, falling back to CPU")
- return
- }
- ids := make([]int, len(gfx))
- i := 0
- for k := range gfx {
- ids[i] = k
- i++
- }
- amdProcMemLookup(resp, skip, ids)
- if resp.memInfo.DeviceCount == 0 {
- return
- }
- if len(skip) > 0 {
- amdSetVisibleDevices(ids, skip)
- }
- }
- // Walk the sysfs nodes for the available GPUs and gather information from them
- // skipping over any devices in the skip map
- func amdProcMemLookup(resp *GpuInfo, skip map[int]interface{}, ids []int) {
- resp.memInfo.DeviceCount = 0
- resp.memInfo.TotalMemory = 0
- resp.memInfo.FreeMemory = 0
- slog.Debug("discovering VRAM for amdgpu devices")
- if len(ids) == 0 {
- entries, err := os.ReadDir(AMDNodesSysfsDir)
- if err != nil {
- slog.Warn(fmt.Sprintf("failed to read amdgpu sysfs %s - %s", AMDNodesSysfsDir, err))
- return
- }
- for _, node := range entries {
- if !node.IsDir() {
- continue
- }
- id, err := strconv.Atoi(node.Name())
- if err != nil {
- slog.Warn("malformed amdgpu sysfs node id " + node.Name())
- continue
- }
- ids = append(ids, id)
- }
- }
- slog.Debug(fmt.Sprintf("amdgpu devices %v", ids))
- for _, id := range ids {
- if _, skipped := skip[id]; skipped {
- continue
- }
- totalMemory := uint64(0)
- usedMemory := uint64(0)
- // Adjust for sysfs vs HIP ids
- propGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id+1), GPUTotalMemoryFileGlob)
- propFiles, err := filepath.Glob(propGlob)
- if err != nil {
- slog.Warn(fmt.Sprintf("error looking up total GPU memory: %s %s", propGlob, err))
- }
- // 1 or more memory banks - sum the values of all of them
- for _, propFile := range propFiles {
- fp, err := os.Open(propFile)
- if err != nil {
- slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", propFile, err))
- continue
- }
- defer fp.Close()
- scanner := bufio.NewScanner(fp)
- for scanner.Scan() {
- line := strings.TrimSpace(scanner.Text())
- if strings.HasPrefix(line, "size_in_bytes") {
- ver := strings.Fields(line)
- if len(ver) != 2 {
- slog.Warn("malformed " + line)
- continue
- }
- bankSizeInBytes, err := strconv.ParseUint(ver[1], 10, 64)
- if err != nil {
- slog.Warn("malformed int " + line)
- continue
- }
- totalMemory += bankSizeInBytes
- }
- }
- }
- if totalMemory == 0 {
- slog.Warn(fmt.Sprintf("amdgpu [%d] reports zero total memory, skipping", id))
- skip[id] = struct{}{}
- continue
- }
- if totalMemory < IGPUMemLimit {
- slog.Info(fmt.Sprintf("amdgpu [%d] appears to be an iGPU with %dM reported total memory, skipping", id, totalMemory/1024/1024))
- skip[id] = struct{}{}
- continue
- }
- usedGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id), GPUUsedMemoryFileGlob)
- usedFiles, err := filepath.Glob(usedGlob)
- if err != nil {
- slog.Warn(fmt.Sprintf("error looking up used GPU memory: %s %s", usedGlob, err))
- continue
- }
- for _, usedFile := range usedFiles {
- fp, err := os.Open(usedFile)
- if err != nil {
- slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", usedFile, err))
- continue
- }
- defer fp.Close()
- data, err := io.ReadAll(fp)
- if err != nil {
- slog.Warn(fmt.Sprintf("failed to read sysfs node file %s: %s", usedFile, err))
- continue
- }
- used, err := strconv.ParseUint(strings.TrimSpace(string(data)), 10, 64)
- if err != nil {
- slog.Warn(fmt.Sprintf("malformed used memory %s: %s", string(data), err))
- continue
- }
- usedMemory += used
- }
- slog.Info(fmt.Sprintf("[%d] amdgpu totalMemory %dM", id, totalMemory/1024/1024))
- slog.Info(fmt.Sprintf("[%d] amdgpu freeMemory %dM", id, (totalMemory-usedMemory)/1024/1024))
- resp.memInfo.DeviceCount++
- resp.memInfo.TotalMemory += totalMemory
- resp.memInfo.FreeMemory += (totalMemory - usedMemory)
- }
- if resp.memInfo.DeviceCount > 0 {
- resp.Library = "rocm"
- }
- }
- // Quick check for AMD driver so we can skip amdgpu discovery if not present
- func AMDDetected() bool {
- // Some driver versions (older?) don't have a version file, so just lookup the parent dir
- sysfsDir := filepath.Dir(DriverVersionFile)
- _, err := os.Stat(sysfsDir)
- if errors.Is(err, os.ErrNotExist) {
- slog.Debug("amdgpu driver not detected " + sysfsDir)
- return false
- } else if err != nil {
- slog.Debug(fmt.Sprintf("error looking up amd driver %s %s", sysfsDir, err))
- return false
- }
- return true
- }
// setupLink removes whatever currently exists at target (recursively) and then
// creates a symlink at target pointing to source.
func setupLink(source, target string) error {
	removeErr := os.RemoveAll(target)
	if removeErr != nil {
		return fmt.Errorf("failed to remove old rocm directory %s %w", target, removeErr)
	}

	linkErr := os.Symlink(source, target)
	if linkErr != nil {
		return fmt.Errorf("failed to create link %s => %s %w", source, target, linkErr)
	}

	slog.Debug("host rocm linked " + source + " => " + target)
	return nil
}
- // Ensure the AMD rocm lib dir is wired up
- // Prefer to use host installed ROCm, as long as it meets our minimum requirements
- // failing that, tell the user how to download it on their own
- func AMDValidateLibDir() (string, error) {
- // We rely on the rpath compiled into our library to find rocm
- // so we establish a symlink to wherever we find it on the system
- // to <payloads>/rocm
- payloadsDir, err := PayloadsDir()
- if err != nil {
- return "", err
- }
- // If we already have a rocm dependency wired, nothing more to do
- rocmTargetDir := filepath.Clean(filepath.Join(payloadsDir, "..", "rocm"))
- if rocmLibUsable(rocmTargetDir) {
- return rocmTargetDir, nil
- }
- // next to the running binary
- exe, err := os.Executable()
- if err == nil {
- peerDir := filepath.Dir(exe)
- if rocmLibUsable(peerDir) {
- slog.Debug("detected ROCM next to ollama executable " + peerDir)
- return rocmTargetDir, setupLink(peerDir, rocmTargetDir)
- }
- peerDir = filepath.Join(filepath.Dir(exe), "rocm")
- if rocmLibUsable(peerDir) {
- slog.Debug("detected ROCM next to ollama executable " + peerDir)
- return rocmTargetDir, setupLink(peerDir, rocmTargetDir)
- }
- }
- // Well known ollama installer path
- installedRocmDir := "/usr/share/ollama/lib/rocm"
- if rocmLibUsable(installedRocmDir) {
- return rocmTargetDir, setupLink(installedRocmDir, rocmTargetDir)
- }
- // Prefer explicit HIP env var
- hipPath := os.Getenv("HIP_PATH")
- if hipPath != "" {
- hipLibDir := filepath.Join(hipPath, "lib")
- if rocmLibUsable(hipLibDir) {
- slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
- return rocmTargetDir, setupLink(hipLibDir, rocmTargetDir)
- }
- }
- // Scan the library path for potential matches
- ldPaths := strings.Split(os.Getenv("LD_LIBRARY_PATH"), ":")
- for _, ldPath := range ldPaths {
- d, err := filepath.Abs(ldPath)
- if err != nil {
- continue
- }
- if rocmLibUsable(d) {
- return rocmTargetDir, setupLink(d, rocmTargetDir)
- }
- }
- // Well known location(s)
- if rocmLibUsable("/opt/rocm/lib") {
- return rocmTargetDir, setupLink("/opt/rocm/lib", rocmTargetDir)
- }
- // If we still haven't found a usable rocm, the user will have to install it on their own
- slog.Warn("amdgpu detected, but no compatible rocm library found. Either install rocm v6, or follow manual install instructions at https://github.com/ollama/ollama/blob/main/docs/linux.md#manual-install")
- return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
- }
- func AMDDriverVersion() (string, error) {
- _, err := os.Stat(DriverVersionFile)
- if err != nil {
- return "", fmt.Errorf("amdgpu version file missing: %s %w", DriverVersionFile, err)
- }
- fp, err := os.Open(DriverVersionFile)
- if err != nil {
- return "", err
- }
- defer fp.Close()
- verString, err := io.ReadAll(fp)
- if err != nil {
- return "", err
- }
- return strings.TrimSpace(string(verString)), nil
- }
- func AMDGFXVersions() map[int]Version {
- // The amdgpu driver always exposes the host CPU as node 0, but we have to skip that and subtract one
- // from the other IDs to get alignment with the HIP libraries expectations (zero is the first GPU, not the CPU)
- res := map[int]Version{}
- matches, _ := filepath.Glob(GPUPropertiesFileGlob)
- for _, match := range matches {
- fp, err := os.Open(match)
- if err != nil {
- slog.Debug(fmt.Sprintf("failed to open sysfs node file %s: %s", match, err))
- continue
- }
- defer fp.Close()
- i, err := strconv.Atoi(filepath.Base(filepath.Dir(match)))
- if err != nil {
- slog.Debug(fmt.Sprintf("failed to parse node ID %s", err))
- continue
- }
- if i == 0 {
- // Skipping the CPU
- continue
- }
- // Align with HIP IDs (zero is first GPU, not CPU)
- i -= 1
- scanner := bufio.NewScanner(fp)
- for scanner.Scan() {
- line := strings.TrimSpace(scanner.Text())
- if strings.HasPrefix(line, "gfx_target_version") {
- ver := strings.Fields(line)
- if len(ver) != 2 || len(ver[1]) < 5 {
- if ver[1] != "0" {
- slog.Debug("malformed " + line)
- }
- res[i] = Version{
- Major: 0,
- Minor: 0,
- Patch: 0,
- }
- continue
- }
- l := len(ver[1])
- patch, err1 := strconv.ParseUint(ver[1][l-2:l], 10, 32)
- minor, err2 := strconv.ParseUint(ver[1][l-4:l-2], 10, 32)
- major, err3 := strconv.ParseUint(ver[1][:l-4], 10, 32)
- if err1 != nil || err2 != nil || err3 != nil {
- slog.Debug("malformed int " + line)
- continue
- }
- res[i] = Version{
- Major: uint(major),
- Minor: uint(minor),
- Patch: uint(patch),
- }
- }
- }
- }
- return res
- }
- func (v Version) ToGFXString() string {
- return fmt.Sprintf("gfx%d%d%d", v.Major, v.Minor, v.Patch)
- }
|