|
- package gpu
- import (
- "bufio"
- "errors"
- "fmt"
- "io"
- "log/slog"
- "os"
- "path/filepath"
- "slices"
- "strconv"
- "strings"
- "github.com/jmorganca/ollama/version"
- )
// Discovery logic for AMD/ROCm GPUs.

// curlMsg is the command suggested to users when no usable ROCm library is
// found; it is formatted with the ollama version and the destination dir.
const curlMsg = "curl -fsSL https://github.com/ollama/ollama/releases/download/v%s/rocm-amd64-deps.tgz | tar -zxf - -C %s"

// Locations exposed under /sys by the amdgpu driver and the KFD topology.
const (
	// DriverVersionFile holds the amdgpu driver version string.
	DriverVersionFile = "/sys/module/amdgpu/version"
	// AMDNodesSysfsDir contains one numbered directory per topology node.
	AMDNodesSysfsDir = "/sys/class/kfd/kfd/topology/nodes/"
	// GPUPropertiesFileGlob matches the properties file of every node.
	GPUPropertiesFileGlob = AMDNodesSysfsDir + "*/properties"
)

// Memory-bank globs; these are relative and must be prefixed with a node dir.
const (
	GPUTotalMemoryFileGlob = "mem_banks/*/properties" // size_in_bytes line
	GPUUsedMemoryFileGlob  = "mem_banks/*/used_memory"
)

// RocmStandardLocation is the well-known system install path of ROCm libs.
const RocmStandardLocation = "/opt/rocm/lib"
// ROCmLibGlobs lists the file-name patterns used to validate whether a given
// ROCm library directory is usable.
// TODO - probably include more coverage of files here...
var ROCmLibGlobs = []string{"libhipblas.so.2*", "rocblas"}
- // Gather GPU information from the amdgpu driver if any supported GPUs are detected
- // HIP_VISIBLE_DEVICES will be set if we detect a mix of unsupported and supported devices
- // and the user hasn't already set this variable
- func AMDGetGPUInfo(resp *GpuInfo) {
- // TODO - DRY this out with windows
- if !AMDDetected() {
- return
- }
- skip := map[int]interface{}{}
- // Opportunistic logging of driver version to aid in troubleshooting
- ver, err := AMDDriverVersion()
- if err == nil {
- slog.Info("AMD Driver: " + ver)
- } else {
- // TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU
- slog.Warn(fmt.Sprintf("ollama recommends running the https://www.amd.com/en/support/linux-drivers: %s", err))
- }
- // If the user has specified exactly which GPUs to use, look up their memory
- visibleDevices := os.Getenv("HIP_VISIBLE_DEVICES")
- if visibleDevices != "" {
- ids := []int{}
- for _, idStr := range strings.Split(visibleDevices, ",") {
- id, err := strconv.Atoi(idStr)
- if err != nil {
- slog.Warn(fmt.Sprintf("malformed HIP_VISIBLE_DEVICES=%s %s", visibleDevices, err))
- } else {
- ids = append(ids, id)
- }
- }
- amdProcMemLookup(resp, nil, ids)
- return
- }
- // Gather GFX version information from all detected cards
- gfx := AMDGFXVersions()
- verStrings := []string{}
- for i, v := range gfx {
- verStrings = append(verStrings, v.ToGFXString())
- if v.Major == 0 {
- // Silently skip CPUs
- skip[i] = struct{}{}
- continue
- }
- if v.Major < 9 {
- // TODO consider this a build-time setting if we can support 8xx family GPUs
- slog.Warn(fmt.Sprintf("amdgpu [%d] too old %s", i, v.ToGFXString()))
- skip[i] = struct{}{}
- }
- }
- slog.Info(fmt.Sprintf("detected amdgpu versions %v", verStrings))
- // Abort if all GPUs are skipped
- if len(skip) >= len(gfx) {
- slog.Info("all detected amdgpus are skipped, falling back to CPU")
- return
- }
- // If we got this far, then we have at least 1 GPU that's a ROCm candidate, so make sure we have a lib
- libDir, err := AMDValidateLibDir()
- if err != nil {
- slog.Warn(fmt.Sprintf("unable to verify rocm library, will use cpu: %s", err))
- return
- }
- gfxOverride := os.Getenv("HSA_OVERRIDE_GFX_VERSION")
- if gfxOverride == "" {
- supported, err := GetSupportedGFX(libDir)
- if err != nil {
- slog.Warn(fmt.Sprintf("failed to lookup supported GFX types, falling back to CPU mode: %s", err))
- return
- }
- slog.Debug(fmt.Sprintf("rocm supported GPU types %v", supported))
- for i, v := range gfx {
- if !slices.Contains[[]string, string](supported, v.ToGFXString()) {
- slog.Warn(fmt.Sprintf("amdgpu [%d] %s is not supported by %s %v", i, v.ToGFXString(), libDir, supported))
- // TODO - consider discrete markdown just for ROCM troubleshooting?
- slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage")
- skip[i] = struct{}{}
- } else {
- slog.Info(fmt.Sprintf("amdgpu [%d] %s is supported", i, v.ToGFXString()))
- }
- }
- } else {
- slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
- }
- if len(skip) >= len(gfx) {
- slog.Info("all detected amdgpus are skipped, falling back to CPU")
- return
- }
- ids := make([]int, len(gfx))
- i := 0
- for k := range gfx {
- ids[i] = k
- i++
- }
- amdProcMemLookup(resp, skip, ids)
- if resp.memInfo.DeviceCount == 0 {
- return
- }
- if len(skip) > 0 {
- amdSetVisibleDevices(ids, skip)
- }
- }
- // Walk the sysfs nodes for the available GPUs and gather information from them
- // skipping over any devices in the skip map
- func amdProcMemLookup(resp *GpuInfo, skip map[int]interface{}, ids []int) {
- resp.memInfo.DeviceCount = 0
- resp.memInfo.TotalMemory = 0
- resp.memInfo.FreeMemory = 0
- if len(ids) == 0 {
- slog.Debug("discovering all amdgpu devices")
- entries, err := os.ReadDir(AMDNodesSysfsDir)
- if err != nil {
- slog.Warn(fmt.Sprintf("failed to read amdgpu sysfs %s - %s", AMDNodesSysfsDir, err))
- return
- }
- for _, node := range entries {
- if !node.IsDir() {
- continue
- }
- id, err := strconv.Atoi(node.Name())
- if err != nil {
- slog.Warn("malformed amdgpu sysfs node id " + node.Name())
- continue
- }
- ids = append(ids, id)
- }
- }
- slog.Debug(fmt.Sprintf("discovering amdgpu devices %v", ids))
- for _, id := range ids {
- if _, skipped := skip[id]; skipped {
- continue
- }
- totalMemory := uint64(0)
- usedMemory := uint64(0)
- propGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id), GPUTotalMemoryFileGlob)
- propFiles, err := filepath.Glob(propGlob)
- if err != nil {
- slog.Warn(fmt.Sprintf("error looking up total GPU memory: %s %s", propGlob, err))
- }
- // 1 or more memory banks - sum the values of all of them
- for _, propFile := range propFiles {
- fp, err := os.Open(propFile)
- if err != nil {
- slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", propFile, err))
- continue
- }
- defer fp.Close()
- scanner := bufio.NewScanner(fp)
- for scanner.Scan() {
- line := strings.TrimSpace(scanner.Text())
- if strings.HasPrefix(line, "size_in_bytes") {
- ver := strings.Fields(line)
- if len(ver) != 2 {
- slog.Warn("malformed " + line)
- continue
- }
- bankSizeInBytes, err := strconv.ParseUint(ver[1], 10, 64)
- if err != nil {
- slog.Warn("malformed int " + line)
- continue
- }
- totalMemory += bankSizeInBytes
- }
- }
- }
- if totalMemory == 0 {
- continue
- }
- usedGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id), GPUUsedMemoryFileGlob)
- usedFiles, err := filepath.Glob(usedGlob)
- if err != nil {
- slog.Warn(fmt.Sprintf("error looking up used GPU memory: %s %s", usedGlob, err))
- continue
- }
- for _, usedFile := range usedFiles {
- fp, err := os.Open(usedFile)
- if err != nil {
- slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", usedFile, err))
- continue
- }
- defer fp.Close()
- data, err := io.ReadAll(fp)
- if err != nil {
- slog.Warn(fmt.Sprintf("failed to read sysfs node file %s: %s", usedFile, err))
- continue
- }
- used, err := strconv.ParseUint(strings.TrimSpace(string(data)), 10, 64)
- if err != nil {
- slog.Warn(fmt.Sprintf("malformed used memory %s: %s", string(data), err))
- continue
- }
- usedMemory += used
- }
- slog.Info(fmt.Sprintf("[%d] amdgpu totalMemory %d", id, totalMemory))
- slog.Info(fmt.Sprintf("[%d] amdgpu freeMemory %d", id, (totalMemory - usedMemory)))
- resp.memInfo.DeviceCount++
- resp.memInfo.TotalMemory += totalMemory
- resp.memInfo.FreeMemory += (totalMemory - usedMemory)
- }
- if resp.memInfo.DeviceCount > 0 {
- resp.Library = "rocm"
- }
- }
- // Quick check for AMD driver so we can skip amdgpu discovery if not present
- func AMDDetected() bool {
- // Some driver versions (older?) don't have a version file, so just lookup the parent dir
- sysfsDir := filepath.Dir(DriverVersionFile)
- _, err := os.Stat(sysfsDir)
- if errors.Is(err, os.ErrNotExist) {
- slog.Debug("amdgpu driver not detected " + sysfsDir)
- return false
- } else if err != nil {
- slog.Debug(fmt.Sprintf("error looking up amd driver %s %s", sysfsDir, err))
- return false
- }
- return true
- }
// setupLink replaces whatever currently exists at target (file, symlink, or
// whole directory tree) with a symlink pointing at source.
func setupLink(source, target string) error {
	err := os.RemoveAll(target)
	if err != nil {
		return fmt.Errorf("failed to remove old rocm directory %s %w", target, err)
	}
	err = os.Symlink(source, target)
	if err != nil {
		return fmt.Errorf("failed to create link %s => %s %w", source, target, err)
	}
	slog.Debug(fmt.Sprintf("host rocm linked %s => %s", source, target))
	return nil
}
- // Ensure the AMD rocm lib dir is wired up
- // Prefer to use host installed ROCm, as long as it meets our minimum requirements
- // failing that, tell the user how to download it on their own
- func AMDValidateLibDir() (string, error) {
- // We rely on the rpath compiled into our library to find rocm
- // so we establish a symlink to wherever we find it on the system
- // to $AssetsDir/rocm
- // If we already have a rocm dependency wired, nothing more to do
- assetsDir, err := AssetsDir()
- if err != nil {
- return "", fmt.Errorf("unable to lookup lib dir: %w", err)
- }
- // Versioned directory
- rocmTargetDir := filepath.Join(assetsDir, "rocm")
- if rocmLibUsable(rocmTargetDir) {
- return rocmTargetDir, nil
- }
- // Parent dir (unversioned)
- commonRocmDir := filepath.Join(filepath.Dir(assetsDir), "rocm")
- if rocmLibUsable(commonRocmDir) {
- return rocmTargetDir, setupLink(commonRocmDir, rocmTargetDir)
- }
- // Prefer explicit HIP env var
- hipPath := os.Getenv("HIP_PATH")
- if hipPath != "" {
- hipLibDir := filepath.Join(hipPath, "lib")
- if rocmLibUsable(hipLibDir) {
- slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
- return rocmTargetDir, setupLink(hipLibDir, rocmTargetDir)
- }
- }
- // Scan the library path for potential matches
- ldPaths := strings.Split(os.Getenv("LD_LIBRARY_PATH"), ":")
- for _, ldPath := range ldPaths {
- d, err := filepath.Abs(ldPath)
- if err != nil {
- continue
- }
- if rocmLibUsable(d) {
- return rocmTargetDir, setupLink(d, rocmTargetDir)
- }
- }
- // Well known location(s)
- if rocmLibUsable("/opt/rocm/lib") {
- return rocmTargetDir, setupLink("/opt/rocm/lib", rocmTargetDir)
- }
- err = os.MkdirAll(rocmTargetDir, 0755)
- if err != nil {
- return "", fmt.Errorf("failed to create empty rocm dir %s %w", rocmTargetDir, err)
- }
- // If we still haven't found a usable rocm, the user will have to download it on their own
- slog.Warn("amdgpu detected, but no compatible rocm library found. Either install rocm v6, or run the following")
- slog.Warn(fmt.Sprintf(curlMsg, version.Version, rocmTargetDir))
- return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
- }
- func AMDDriverVersion() (string, error) {
- _, err := os.Stat(DriverVersionFile)
- if err != nil {
- return "", fmt.Errorf("amdgpu version file missing: %s %w", DriverVersionFile, err)
- }
- fp, err := os.Open(DriverVersionFile)
- if err != nil {
- return "", err
- }
- defer fp.Close()
- verString, err := io.ReadAll(fp)
- if err != nil {
- return "", err
- }
- return strings.TrimSpace(string(verString)), nil
- }
- func AMDGFXVersions() map[int]Version {
- res := map[int]Version{}
- matches, _ := filepath.Glob(GPUPropertiesFileGlob)
- for _, match := range matches {
- fp, err := os.Open(match)
- if err != nil {
- slog.Debug(fmt.Sprintf("failed to open sysfs node file %s: %s", match, err))
- continue
- }
- defer fp.Close()
- i, err := strconv.Atoi(filepath.Base(filepath.Dir(match)))
- if err != nil {
- slog.Debug(fmt.Sprintf("failed to parse node ID %s", err))
- continue
- }
- scanner := bufio.NewScanner(fp)
- for scanner.Scan() {
- line := strings.TrimSpace(scanner.Text())
- if strings.HasPrefix(line, "gfx_target_version") {
- ver := strings.Fields(line)
- if len(ver) != 2 || len(ver[1]) < 5 {
- if ver[1] == "0" {
- // Silently skip the CPU
- continue
- } else {
- slog.Debug("malformed " + line)
- }
- res[i] = Version{
- Major: 0,
- Minor: 0,
- Patch: 0,
- }
- continue
- }
- l := len(ver[1])
- patch, err1 := strconv.ParseUint(ver[1][l-2:l], 10, 32)
- minor, err2 := strconv.ParseUint(ver[1][l-4:l-2], 10, 32)
- major, err3 := strconv.ParseUint(ver[1][:l-4], 10, 32)
- if err1 != nil || err2 != nil || err3 != nil {
- slog.Debug("malformed int " + line)
- continue
- }
- res[i] = Version{
- Major: uint(major),
- Minor: uint(minor),
- Patch: uint(patch),
- }
- }
- }
- }
- return res
- }
- func (v Version) ToGFXString() string {
- return fmt.Sprintf("gfx%d%d%d", v.Major, v.Minor, v.Patch)
- }
|