
Detect CUDA OS Overhead

This adds logic to detect skew between the driver and
management library, which can be attributed to OS overhead,
and records it so we can adjust subsequent management
library free VRAM updates and avoid OOM scenarios.
Daniel Hiltgen 9 months ago
parent
commit
f6f759fc5f
2 files changed with 29 additions and 1 deletion
  1. +27 -0
      gpu/gpu.go
  2. +2 -1
      gpu/types.go

+ 27 - 0
gpu/gpu.go

@@ -274,6 +274,28 @@ func GetGPUInfo() GpuInfoList {
 				gpuInfo.DriverMajor = driverMajor
 				gpuInfo.DriverMinor = driverMinor
 
+				// query the management library as well so we can record any skew between the two
+				// which represents overhead on the GPU we must set aside on subsequent updates
+				if cHandles.nvml != nil {
+					C.nvml_get_free(*cHandles.nvml, C.int(gpuInfo.index), &memInfo.free, &memInfo.total, &memInfo.used)
+					if memInfo.err != nil {
+						slog.Warn("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
+						C.free(unsafe.Pointer(memInfo.err))
+					} else {
+						if memInfo.free != 0 && uint64(memInfo.free) > gpuInfo.FreeMemory {
+							gpuInfo.OSOverhead = uint64(memInfo.free) - gpuInfo.FreeMemory
+							slog.Info("detected OS VRAM overhead",
+								"id", gpuInfo.ID,
+								"library", gpuInfo.Library,
+								"compute", gpuInfo.Compute,
+								"driver", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor),
+								"name", gpuInfo.Name,
+								"overhead", format.HumanBytes2(gpuInfo.OSOverhead),
+							)
+						}
+					}
+				}
+
 				// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
 				cudaGPUs = append(cudaGPUs, gpuInfo)
 			}
@@ -374,9 +396,14 @@ func GetGPUInfo() GpuInfoList {
 				slog.Warn("error looking up nvidia GPU memory")
 				continue
 			}
+			if cHandles.nvml != nil && gpu.OSOverhead > 0 {
+				// When using the management library update based on recorded overhead
+				memInfo.free -= C.uint64_t(gpu.OSOverhead)
+			}
 			slog.Debug("updating cuda memory data",
 				"gpu", gpu.ID,
 				"name", gpu.Name,
+				"overhead", format.HumanBytes2(gpu.OSOverhead),
 				slog.Group(
 					"before",
 					"total", format.HumanBytes2(gpu.TotalMemory),

+ 2 - 1
gpu/types.go

@@ -52,7 +52,8 @@ type CPUInfo struct {
 
 type CudaGPUInfo struct {
 	GpuInfo
-	index int //nolint:unused,nolintlint
+	OSOverhead uint64 // Memory overhead between the driver library and management library
+	index      int    //nolint:unused,nolintlint
 }
 type CudaGPUInfoList []CudaGPUInfo
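
For reference, a minimal standalone Go sketch of the mechanism described in the commit message: the gap between the management library's (NVML) free-VRAM figure and the driver library's figure is recorded as OS overhead, and later NVML readings are reduced by that amount. The helper names and example numbers are hypothetical, not part of the actual change, and the underflow guard is an extra precaution the patch itself does not add.

    package main

    import "fmt"

    // detectOSOverhead returns how much larger the management-library (NVML)
    // free figure is than the driver-library figure; that gap is what gets
    // attributed to OS overhead. Illustrative helper, not the real code path.
    func detectOSOverhead(driverFree, nvmlFree uint64) uint64 {
    	if nvmlFree > driverFree {
    		return nvmlFree - driverFree
    	}
    	return 0
    }

    // adjustFree subtracts the recorded overhead from a later NVML free
    // reading, guarding against underflow.
    func adjustFree(nvmlFree, overhead uint64) uint64 {
    	if overhead >= nvmlFree {
    		return 0
    	}
    	return nvmlFree - overhead
    }

    func main() {
    	// Example numbers only: the driver reports 7.5 GiB free while NVML
    	// reports 8.0 GiB free, so 512 MiB is treated as OS overhead.
    	driverFree := uint64(7680 << 20) // 7.5 GiB from the CUDA driver library
    	nvmlFree := uint64(8192 << 20)   // 8.0 GiB from NVML

    	overhead := detectOSOverhead(driverFree, nvmlFree)
    	fmt.Println("detected OS VRAM overhead (bytes):", overhead)

    	// A subsequent NVML reading is then reduced by that overhead before use.
    	later := uint64(6144 << 20) // 6.0 GiB reported later by NVML
    	fmt.Println("adjusted free (bytes):", adjustFree(later, overhead))
    }

In the patch itself, the detection runs once during GetGPUInfo (first hunk of gpu/gpu.go) and the subtraction is applied on every later NVML-based refresh (second hunk), with the result stored in the new CudaGPUInfo.OSOverhead field.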