Explorar o código

Use DRM driver for VRAM info for amd

The amdgpu drivers free VRAM reporting omits some other apps, so leverage the
upstream DRM driver which keeps better tabs on things
Daniel Hiltgen hai 11 meses
pai
achega
b32ebb4f29
Modificáronse 1 ficheiros con 105 adicións e 58 borrados
  1. 105 58
      gpu/amd_linux.go

+ 105 - 58
gpu/amd_linux.go

@@ -25,7 +25,16 @@ const (
 
 
 	// Prefix with the node dir
 	// Prefix with the node dir
 	GPUTotalMemoryFileGlob = "mem_banks/*/properties" // size_in_bytes line
 	GPUTotalMemoryFileGlob = "mem_banks/*/properties" // size_in_bytes line
-	GPUUsedMemoryFileGlob  = "mem_banks/*/used_memory"
+
+	// Direct Rendering Manager sysfs location
+	DRMDeviceDirGlob   = "/sys/class/drm/card[0-9]/device"
+	DRMTotalMemoryFile = "mem_info_vram_total"
+	DRMUsedMemoryFile  = "mem_info_vram_used"
+
+	// In hex; properties file is in decimal
+	DRMUniqueIDFile = "unique_id"
+	DRMVendorFile   = "vendor"
+	DRMDeviceFile   = "device"
 )
 )
 
 
 var (
 var (
@@ -90,7 +99,7 @@ func AMDGetGPUInfo() []GpuInfo {
 		scanner := bufio.NewScanner(fp)
 		scanner := bufio.NewScanner(fp)
 		isCPU := false
 		isCPU := false
 		var major, minor, patch uint64
 		var major, minor, patch uint64
-		var vendor, device uint64
+		var vendor, device, uniqueID uint64
 		for scanner.Scan() {
 		for scanner.Scan() {
 			line := strings.TrimSpace(scanner.Text())
 			line := strings.TrimSpace(scanner.Text())
 			// Note: we could also use "cpu_cores_count X" where X is greater than zero to detect CPUs
 			// Note: we could also use "cpu_cores_count X" where X is greater than zero to detect CPUs
@@ -121,30 +130,43 @@ func AMDGetGPUInfo() []GpuInfo {
 			} else if strings.HasPrefix(line, "vendor_id") {
 			} else if strings.HasPrefix(line, "vendor_id") {
 				ver := strings.Fields(line)
 				ver := strings.Fields(line)
 				if len(ver) != 2 {
 				if len(ver) != 2 {
-					slog.Debug("malformed vendor_id", "vendor_id", line)
+					slog.Debug("malformed", "vendor_id", line)
 					continue
 					continue
 				}
 				}
-				vendor, err = strconv.ParseUint(ver[1], 10, 32)
+				vendor, err = strconv.ParseUint(ver[1], 10, 64)
 				if err != nil {
 				if err != nil {
-					slog.Debug("malformed vendor_id" + line)
+					slog.Debug("malformed", "vendor_id", line, "error", err)
 				}
 				}
 			} else if strings.HasPrefix(line, "device_id") {
 			} else if strings.HasPrefix(line, "device_id") {
 				ver := strings.Fields(line)
 				ver := strings.Fields(line)
 				if len(ver) != 2 {
 				if len(ver) != 2 {
-					slog.Debug("malformed device_id", "device_id", line)
+					slog.Debug("malformed", "device_id", line)
 					continue
 					continue
 				}
 				}
-				device, err = strconv.ParseUint(ver[1], 10, 32)
+				device, err = strconv.ParseUint(ver[1], 10, 64)
 				if err != nil {
 				if err != nil {
-					slog.Debug("malformed device_id" + line)
+					slog.Debug("malformed", "device_id", line, "error", err)
+				}
+			} else if strings.HasPrefix(line, "unique_id") {
+				ver := strings.Fields(line)
+				if len(ver) != 2 {
+					slog.Debug("malformed", "unique_id", line)
+					continue
+				}
+				uniqueID, err = strconv.ParseUint(ver[1], 10, 64)
+				if err != nil {
+					slog.Debug("malformed", "unique_id", line, "error", err)
 				}
 				}
 			}
 			}
-
 			// TODO - any other properties we want to extract and record?
 			// TODO - any other properties we want to extract and record?
 			// vendor_id + device_id -> pci lookup for "Name"
 			// vendor_id + device_id -> pci lookup for "Name"
 			// Other metrics that may help us understand relative performance between multiple GPUs
 			// Other metrics that may help us understand relative performance between multiple GPUs
 		}
 		}
 
 
+		// Note: while ./mem_banks/*/used_memory exists, it doesn't appear to take other VRAM consumers
+		// into consideration, so we instead map the device over to the DRM driver sysfs nodes which
+		// do reliably report VRAM usage.
+
 		if isCPU {
 		if isCPU {
 			cpuCount++
 			cpuCount++
 			continue
 			continue
@@ -167,65 +189,90 @@ func AMDGetGPUInfo() []GpuInfo {
 		// Look up the memory for the current node
 		// Look up the memory for the current node
 		totalMemory := uint64(0)
 		totalMemory := uint64(0)
 		usedMemory := uint64(0)
 		usedMemory := uint64(0)
-		propGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(nodeID), GPUTotalMemoryFileGlob)
-		propFiles, err := filepath.Glob(propGlob)
-		if err != nil {
-			slog.Warn("error looking up total GPU memory", "glob", propGlob, "error", err)
+		mapping := []struct {
+			id       uint64
+			filename string
+		}{
+			{vendor, DRMVendorFile},
+			{device, DRMDeviceFile},
+			{uniqueID, DRMUniqueIDFile}, // Not all devices will report this
 		}
 		}
-		// 1 or more memory banks - sum the values of all of them
-		for _, propFile := range propFiles {
-			fp, err := os.Open(propFile)
-			if err != nil {
-				slog.Warn("failed to open sysfs node", "file", propFile, "erroir", err)
+		slog.Debug("mapping amdgpu to drm sysfs nodes", "amdgpu", match, "vendor", vendor, "device", device, "unique_id", uniqueID)
+		// Map over to DRM location to find the total/free memory
+		drmMatches, _ := filepath.Glob(DRMDeviceDirGlob)
+		for _, devDir := range drmMatches {
+			matched := true
+			for _, m := range mapping {
+				if m.id == 0 {
+					continue
+				}
+				filename := filepath.Join(devDir, m.filename)
+				fp, err := os.Open(filename)
+				if err != nil {
+					slog.Debug("failed to open sysfs node", "file", filename, "error", err)
+					matched = false
+					break
+				}
+				defer fp.Close()
+				buf, err := io.ReadAll(fp)
+				if err != nil {
+					slog.Debug("failed to read sysfs node", "file", filename, "error", err)
+					matched = false
+					break
+				}
+				cmp, err := strconv.ParseUint(strings.TrimPrefix(strings.TrimSpace(string(buf)), "0x"), 16, 64)
+				if err != nil {
+					slog.Debug("failed to parse sysfs node", "file", filename, "error", err)
+					matched = false
+					break
+				}
+				if cmp != m.id {
+					matched = false
+					break
+				}
+			}
+			if !matched {
 				continue
 				continue
 			}
 			}
-			defer fp.Close()
-			scanner := bufio.NewScanner(fp)
-			for scanner.Scan() {
-				line := strings.TrimSpace(scanner.Text())
-				if strings.HasPrefix(line, "size_in_bytes") {
-					ver := strings.Fields(line)
-					if len(ver) != 2 {
-						slog.Warn("malformed " + line)
-						continue
-					}
-					bankSizeInBytes, err := strconv.ParseUint(ver[1], 10, 64)
-					if err != nil {
-						slog.Warn("malformed int " + line)
-						continue
-					}
-					totalMemory += bankSizeInBytes
-				}
+
+			// Found the matching DRM directory
+			slog.Debug("matched", "amdgpu", match, "drm", devDir)
+			totalFile := filepath.Join(devDir, DRMTotalMemoryFile)
+			totalFp, err := os.Open(totalFile)
+			if err != nil {
+				slog.Debug("failed to open sysfs node", "file", totalFile, "error", err)
+				break
 			}
 			}
-		}
-		if totalMemory == 0 {
-			slog.Warn("amdgpu reports zero total memory", "gpu", gpuID)
-			continue
-		}
-		usedGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(nodeID), GPUUsedMemoryFileGlob)
-		usedFiles, err := filepath.Glob(usedGlob)
-		if err != nil {
-			slog.Warn("error looking up used GPU memory", "glob", usedGlob, "error", err)
-			continue
-		}
-		for _, usedFile := range usedFiles {
-			fp, err := os.Open(usedFile)
+			defer totalFp.Close()
+			buf, err := io.ReadAll(totalFp)
 			if err != nil {
 			if err != nil {
-				slog.Warn("failed to open sysfs node", "file", usedFile, "error", err)
-				continue
+				slog.Debug("failed to read sysfs node", "file", totalFile, "error", err)
+				break
 			}
 			}
-			defer fp.Close()
-			data, err := io.ReadAll(fp)
+			totalMemory, err = strconv.ParseUint(strings.TrimSpace(string(buf)), 10, 64)
 			if err != nil {
 			if err != nil {
-				slog.Warn("failed to read sysfs node", "file", usedFile, "error", err)
-				continue
+				slog.Debug("failed to parse sysfs node", "file", totalFile, "error", err)
+				break
+			}
+
+			usedFile := filepath.Join(devDir, DRMUsedMemoryFile)
+			usedFp, err := os.Open(usedFile)
+			if err != nil {
+				slog.Debug("failed to open sysfs node", "file", usedFile, "error", err)
+				break
 			}
 			}
-			used, err := strconv.ParseUint(strings.TrimSpace(string(data)), 10, 64)
+			defer totalFp.Close()
+			buf, err = io.ReadAll(usedFp)
 			if err != nil {
 			if err != nil {
-				slog.Warn("malformed used memory", "data", string(data), "error", err)
-				continue
+				slog.Debug("failed to read sysfs node", "file", usedFile, "error", err)
+				break
+			}
+			usedMemory, err = strconv.ParseUint(strings.TrimSpace(string(buf)), 10, 64)
+			if err != nil {
+				slog.Debug("failed to parse sysfs node", "file", usedFile, "error", err)
+				break
 			}
 			}
-			usedMemory += used
+			break
 		}
 		}
 
 
 		// iGPU detection, remove this check once we can support an iGPU variant of the rocm library
 		// iGPU detection, remove this check once we can support an iGPU variant of the rocm library