فهرست منبع

Merge pull request #5580 from dhiltgen/cuda_overhead

Detect CUDA OS overhead
Daniel Hiltgen 9 ماه پیش
والد
کامیت
8ea500441d
2فایلهای تغییر یافته به همراه29 افزوده شده و 1 حذف شده
  1. 27 0
      gpu/gpu.go
  2. 2 1
      gpu/types.go

+ 27 - 0
gpu/gpu.go

@@ -274,6 +274,28 @@ func GetGPUInfo() GpuInfoList {
 				gpuInfo.DriverMajor = driverMajor
 				gpuInfo.DriverMinor = driverMinor
 
+				// query the management library as well so we can record any skew between the two
+				// which represents overhead on the GPU we must set aside on subsequent updates
+				if cHandles.nvml != nil {
+					C.nvml_get_free(*cHandles.nvml, C.int(gpuInfo.index), &memInfo.free, &memInfo.total, &memInfo.used)
+					if memInfo.err != nil {
+						slog.Warn("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
+						C.free(unsafe.Pointer(memInfo.err))
+					} else {
+						if memInfo.free != 0 && uint64(memInfo.free) > gpuInfo.FreeMemory {
+							gpuInfo.OSOverhead = uint64(memInfo.free) - gpuInfo.FreeMemory
+							slog.Info("detected OS VRAM overhead",
+								"id", gpuInfo.ID,
+								"library", gpuInfo.Library,
+								"compute", gpuInfo.Compute,
+								"driver", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor),
+								"name", gpuInfo.Name,
+								"overhead", format.HumanBytes2(gpuInfo.OSOverhead),
+							)
+						}
+					}
+				}
+
 				// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
 				cudaGPUs = append(cudaGPUs, gpuInfo)
 			}
@@ -374,9 +396,14 @@ func GetGPUInfo() GpuInfoList {
 				slog.Warn("error looking up nvidia GPU memory")
 				continue
 			}
+			if cHandles.nvml != nil && gpu.OSOverhead > 0 {
+				// When using the management library update based on recorded overhead
+				memInfo.free -= C.uint64_t(gpu.OSOverhead)
+			}
 			slog.Debug("updating cuda memory data",
 				"gpu", gpu.ID,
 				"name", gpu.Name,
+				"overhead", format.HumanBytes2(gpu.OSOverhead),
 				slog.Group(
 					"before",
 					"total", format.HumanBytes2(gpu.TotalMemory),

+ 2 - 1
gpu/types.go

@@ -52,7 +52,8 @@ type CPUInfo struct {
 
 type CudaGPUInfo struct {
 	GpuInfo
-	index int //nolint:unused,nolintlint
+	OSOverhead uint64 // Memory overhead between the driver library and management library
+	index      int    //nolint:unused,nolintlint
 }
 type CudaGPUInfoList []CudaGPUInfo