|
@@ -8,6 +8,7 @@ import (
|
|
|
"log/slog"
|
|
|
"os"
|
|
|
"path/filepath"
|
|
|
+ "regexp"
|
|
|
"slices"
|
|
|
"strconv"
|
|
|
"strings"
|
|
@@ -41,10 +42,8 @@ func AMDGetGPUInfo() []GpuInfo {
|
|
|
}
|
|
|
|
|
|
// Opportunistic logging of driver version to aid in troubleshooting
|
|
|
- ver, err := AMDDriverVersion()
|
|
|
- if err == nil {
|
|
|
- slog.Info("AMD Driver: " + ver)
|
|
|
- } else {
|
|
|
+ driverMajor, driverMinor, err := AMDDriverVersion()
|
|
|
+ if err != nil {
|
|
|
// TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU
|
|
|
slog.Warn("ollama recommends running the https://www.amd.com/en/support/linux-drivers", "error", err)
|
|
|
}
|
|
@@ -91,6 +90,7 @@ func AMDGetGPUInfo() []GpuInfo {
|
|
|
scanner := bufio.NewScanner(fp)
|
|
|
isCPU := false
|
|
|
var major, minor, patch uint64
|
|
|
+ var vendor, device uint64
|
|
|
for scanner.Scan() {
|
|
|
line := strings.TrimSpace(scanner.Text())
|
|
|
// Note: we could also use "cpu_cores_count X" where X is greater than zero to detect CPUs
|
|
@@ -118,6 +118,26 @@ func AMDGetGPUInfo() []GpuInfo {
|
|
|
slog.Debug("malformed int " + line)
|
|
|
continue
|
|
|
}
|
|
|
+ } else if strings.HasPrefix(line, "vendor_id") {
|
|
|
+ ver := strings.Fields(line)
|
|
|
+ if len(ver) != 2 {
|
|
|
+ slog.Debug("malformed vendor_id", "vendor_id", line)
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ vendor, err = strconv.ParseUint(ver[1], 10, 32)
|
|
|
+ if err != nil {
|
|
|
+ slog.Debug("malformed vendor_id" + line)
|
|
|
+ }
|
|
|
+ } else if strings.HasPrefix(line, "device_id") {
|
|
|
+ ver := strings.Fields(line)
|
|
|
+ if len(ver) != 2 {
|
|
|
+ slog.Debug("malformed device_id", "device_id", line)
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ device, err = strconv.ParseUint(ver[1], 10, 32)
|
|
|
+ if err != nil {
|
|
|
+ slog.Debug("malformed device_id" + line)
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
// TODO - any other properties we want to extract and record?
|
|
@@ -140,7 +160,7 @@ func AMDGetGPUInfo() []GpuInfo {
|
|
|
}
|
|
|
|
|
|
if int(major) < RocmComputeMin {
|
|
|
- slog.Warn(fmt.Sprintf("amdgpu too old gfx%d%d%x", major, minor, patch), "gpu", gpuID)
|
|
|
+ slog.Warn(fmt.Sprintf("amdgpu too old gfx%d%x%x", major, minor, patch), "gpu", gpuID)
|
|
|
continue
|
|
|
}
|
|
|
|
|
@@ -210,24 +230,29 @@ func AMDGetGPUInfo() []GpuInfo {
|
|
|
|
|
|
// iGPU detection, remove this check once we can support an iGPU variant of the rocm library
|
|
|
if totalMemory < IGPUMemLimit {
|
|
|
- slog.Info("amdgpu appears to be an iGPU, skipping", "gpu", gpuID, "total", format.HumanBytes2(totalMemory))
|
|
|
+ slog.Info("unsupported Radeon iGPU detected skipping", "id", gpuID, "total", format.HumanBytes2(totalMemory))
|
|
|
continue
|
|
|
}
|
|
|
+ var name string
|
|
|
+ // TODO - PCI ID lookup
|
|
|
+ if vendor > 0 && device > 0 {
|
|
|
+ name = fmt.Sprintf("%04x:%04x", vendor, device)
|
|
|
+ }
|
|
|
|
|
|
- slog.Info("amdgpu memory", "gpu", gpuID, "total", format.HumanBytes2(totalMemory))
|
|
|
- slog.Info("amdgpu memory", "gpu", gpuID, "available", format.HumanBytes2(totalMemory-usedMemory))
|
|
|
+ slog.Debug("amdgpu memory", "gpu", gpuID, "total", format.HumanBytes2(totalMemory))
|
|
|
+ slog.Debug("amdgpu memory", "gpu", gpuID, "available", format.HumanBytes2(totalMemory-usedMemory))
|
|
|
gpuInfo := GpuInfo{
|
|
|
Library: "rocm",
|
|
|
memInfo: memInfo{
|
|
|
TotalMemory: totalMemory,
|
|
|
FreeMemory: (totalMemory - usedMemory),
|
|
|
},
|
|
|
- ID: fmt.Sprintf("%d", gpuID),
|
|
|
- // Name: not exposed in sysfs directly, would require pci device id lookup
|
|
|
- Major: int(major),
|
|
|
- Minor: int(minor),
|
|
|
- Patch: int(patch),
|
|
|
+ ID: fmt.Sprintf("%d", gpuID),
|
|
|
+ Name: name,
|
|
|
+ Compute: fmt.Sprintf("gfx%d%x%x", major, minor, patch),
|
|
|
MinimumMemory: rocmMinimumMemory,
|
|
|
+ DriverMajor: driverMajor,
|
|
|
+ DriverMinor: driverMinor,
|
|
|
}
|
|
|
|
|
|
// If the user wants to filter to a subset of devices, filter out if we aren't a match
|
|
@@ -266,7 +291,7 @@ func AMDGetGPUInfo() []GpuInfo {
|
|
|
}
|
|
|
slog.Debug("rocm supported GPUs", "types", supported)
|
|
|
}
|
|
|
- gfx := fmt.Sprintf("gfx%d%d%x", gpuInfo.Major, gpuInfo.Minor, gpuInfo.Patch)
|
|
|
+ gfx := gpuInfo.Compute
|
|
|
if !slices.Contains[[]string, string](supported, gfx) {
|
|
|
slog.Warn("amdgpu is not supported", "gpu", gpuInfo.ID, "gpu_type", gfx, "library", libDir, "supported_types", supported)
|
|
|
// TODO - consider discrete markdown just for ROCM troubleshooting?
|
|
@@ -276,7 +301,7 @@ func AMDGetGPUInfo() []GpuInfo {
|
|
|
slog.Info("amdgpu is supported", "gpu", gpuInfo.ID, "gpu_type", gfx)
|
|
|
}
|
|
|
} else {
|
|
|
- slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
|
|
|
+ slog.Info("skipping rocm gfx compatibility check", "HSA_OVERRIDE_GFX_VERSION", gfxOverride)
|
|
|
}
|
|
|
|
|
|
// The GPU has passed all the verification steps and is supported
|
|
@@ -322,19 +347,34 @@ func AMDValidateLibDir() (string, error) {
|
|
|
return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
|
|
|
}
|
|
|
|
|
|
-func AMDDriverVersion() (string, error) {
|
|
|
- _, err := os.Stat(DriverVersionFile)
|
|
|
+func AMDDriverVersion() (driverMajor, driverMinor int, err error) {
|
|
|
+ _, err = os.Stat(DriverVersionFile)
|
|
|
if err != nil {
|
|
|
- return "", fmt.Errorf("amdgpu version file missing: %s %w", DriverVersionFile, err)
|
|
|
+ return 0, 0, fmt.Errorf("amdgpu version file missing: %s %w", DriverVersionFile, err)
|
|
|
}
|
|
|
fp, err := os.Open(DriverVersionFile)
|
|
|
if err != nil {
|
|
|
- return "", err
|
|
|
+ return 0, 0, err
|
|
|
}
|
|
|
defer fp.Close()
|
|
|
verString, err := io.ReadAll(fp)
|
|
|
if err != nil {
|
|
|
- return "", err
|
|
|
+ return 0, 0, err
|
|
|
+ }
|
|
|
+
|
|
|
+ pattern := `\A(\d+)\.(\d+).*`
|
|
|
+ regex := regexp.MustCompile(pattern)
|
|
|
+ match := regex.FindStringSubmatch(string(verString))
|
|
|
+ if len(match) < 2 {
|
|
|
+ return 0, 0, fmt.Errorf("malformed version string %s", string(verString))
|
|
|
+ }
|
|
|
+ driverMajor, err = strconv.Atoi(match[1])
|
|
|
+ if err != nil {
|
|
|
+ return 0, 0, err
|
|
|
+ }
|
|
|
+ driverMinor, err = strconv.Atoi(match[2])
|
|
|
+ if err != nil {
|
|
|
+ return 0, 0, err
|
|
|
}
|
|
|
- return strings.TrimSpace(string(verString)), nil
|
|
|
+ return driverMajor, driverMinor, nil
|
|
|
}
|