소스 검색

Refine GPU discovery to bootstrap once

Now that we call the GPU discovery routines many times to
update memory, this splits initial discovery from free memory
updating.
Daniel Hiltgen 1 년 전
부모
커밋
43ed358f9a
9개의 변경된 파일372개의 추가작업 그리고 138개의 파일을 삭제
  1. 56 31
      gpu/amd_linux.go
  2. 46 16
      gpu/amd_windows.go
  3. 7 8
      gpu/cpu_common.go
  4. 163 77
      gpu/gpu.go
  5. 1 1
      gpu/gpu_info_cudart.c
  6. 2 1
      gpu/gpu_info_cudart.h
  7. 38 3
      gpu/gpu_info_nvcuda.c
  8. 2 1
      gpu/gpu_info_nvcuda.h
  9. 57 0
      gpu/types.go

+ 56 - 31
gpu/amd_linux.go

@@ -44,8 +44,8 @@ var (
 )
 
 // Gather GPU information from the amdgpu driver if any supported GPUs are detected
-func AMDGetGPUInfo() []GpuInfo {
-	resp := []GpuInfo{}
+func AMDGetGPUInfo() []RocmGPUInfo {
+	resp := []RocmGPUInfo{}
 	if !AMDDetected() {
 		return resp
 	}
@@ -178,7 +178,7 @@ func AMDGetGPUInfo() []GpuInfo {
 		// Shouldn't happen, but just in case...
 		if gpuID < 0 {
 			slog.Error("unexpected amdgpu sysfs data resulted in negative GPU ID, please set OLLAMA_DEBUG=1 and report an issue")
-			return []GpuInfo{}
+			return []RocmGPUInfo{}
 		}
 
 		if int(major) < RocmComputeMin {
@@ -189,6 +189,7 @@ func AMDGetGPUInfo() []GpuInfo {
 		// Look up the memory for the current node
 		totalMemory := uint64(0)
 		usedMemory := uint64(0)
+		var usedFile string
 		mapping := []struct {
 			id       uint64
 			filename string
@@ -255,22 +256,10 @@ func AMDGetGPUInfo() []GpuInfo {
 				break
 			}
 
-			usedFile := filepath.Join(devDir, DRMUsedMemoryFile)
-			usedFp, err := os.Open(usedFile)
+			usedFile = filepath.Join(devDir, DRMUsedMemoryFile)
+			usedMemory, err = getFreeMemory(usedFile)
 			if err != nil {
-				slog.Debug("failed to open sysfs node", "file", usedFile, "error", err)
-				break
-			}
-			defer totalFp.Close()
-			buf, err = io.ReadAll(usedFp)
-			if err != nil {
-				slog.Debug("failed to read sysfs node", "file", usedFile, "error", err)
-				break
-			}
-			usedMemory, err = strconv.ParseUint(strings.TrimSpace(string(buf)), 10, 64)
-			if err != nil {
-				slog.Debug("failed to parse sysfs node", "file", usedFile, "error", err)
-				break
+				slog.Debug("failed to update used memory", "error", err)
 			}
 			break
 		}
@@ -288,18 +277,21 @@ func AMDGetGPUInfo() []GpuInfo {
 
 		slog.Debug("amdgpu memory", "gpu", gpuID, "total", format.HumanBytes2(totalMemory))
 		slog.Debug("amdgpu memory", "gpu", gpuID, "available", format.HumanBytes2(totalMemory-usedMemory))
-		gpuInfo := GpuInfo{
-			Library: "rocm",
-			memInfo: memInfo{
-				TotalMemory: totalMemory,
-				FreeMemory:  (totalMemory - usedMemory),
+		gpuInfo := RocmGPUInfo{
+			GpuInfo: GpuInfo{
+				Library: "rocm",
+				memInfo: memInfo{
+					TotalMemory: totalMemory,
+					FreeMemory:  (totalMemory - usedMemory),
+				},
+				ID:            fmt.Sprintf("%d", gpuID),
+				Name:          name,
+				Compute:       fmt.Sprintf("gfx%d%x%x", major, minor, patch),
+				MinimumMemory: rocmMinimumMemory,
+				DriverMajor:   driverMajor,
+				DriverMinor:   driverMinor,
 			},
-			ID:            fmt.Sprintf("%d", gpuID),
-			Name:          name,
-			Compute:       fmt.Sprintf("gfx%d%x%x", major, minor, patch),
-			MinimumMemory: rocmMinimumMemory,
-			DriverMajor:   driverMajor,
-			DriverMinor:   driverMinor,
+			usedFilepath: usedFile,
 		}
 
 		// If the user wants to filter to a subset of devices, filter out if we aren't a match
@@ -323,7 +315,7 @@ func AMDGetGPUInfo() []GpuInfo {
 			libDir, err = AMDValidateLibDir()
 			if err != nil {
 				slog.Warn("unable to verify rocm library, will use cpu", "error", err)
-				return []GpuInfo{}
+				return []RocmGPUInfo{}
 			}
 		}
 		gpuInfo.DependencyPath = libDir
@@ -334,7 +326,7 @@ func AMDGetGPUInfo() []GpuInfo {
 				supported, err = GetSupportedGFX(libDir)
 				if err != nil {
 					slog.Warn("failed to lookup supported GFX types, falling back to CPU mode", "error", err)
-					return []GpuInfo{}
+					return []RocmGPUInfo{}
 				}
 				slog.Debug("rocm supported GPUs", "types", supported)
 			}
@@ -425,3 +417,36 @@ func AMDDriverVersion() (driverMajor, driverMinor int, err error) {
 	}
 	return driverMajor, driverMinor, nil
 }
+
+func (gpus RocmGPUInfoList) RefreshFreeMemory() error {
+	if len(gpus) == 0 {
+		return nil
+	}
+	for i := range gpus {
+		usedMemory, err := getFreeMemory(gpus[i].usedFilepath)
+		if err != nil {
+			return err
+		}
+		slog.Debug("updating rocm free memory", "gpu", gpus[i].ID, "name", gpus[i].Name, "before", format.HumanBytes2(gpus[i].FreeMemory), "now", format.HumanBytes2(gpus[i].TotalMemory-usedMemory))
+		gpus[i].FreeMemory = gpus[i].TotalMemory - usedMemory
+	}
+	return nil
+}
+
+func getFreeMemory(usedFile string) (uint64, error) {
+	usedFp, err := os.Open(usedFile)
+	if err != nil {
+		return 0, fmt.Errorf("failed to open sysfs node %s %w", usedFile, err)
+	}
+	defer usedFp.Close()
+	buf, err := io.ReadAll(usedFp)
+	if err != nil {
+		return 0, fmt.Errorf("failed to read sysfs node %s %w", usedFile, err)
+	}
+	usedMemory, err := strconv.ParseUint(strings.TrimSpace(string(buf)), 10, 64)
+	if err != nil {
+		slog.Debug("failed to parse sysfs node", "file", usedFile, "error", err)
+		return 0, fmt.Errorf("failed to parse sysfs node %s %w", usedFile, err)
+	}
+	return usedMemory, nil
+}

+ 46 - 16
gpu/amd_windows.go

@@ -24,8 +24,8 @@ var (
 	RocmStandardLocations = []string{"C:\\Program Files\\AMD\\ROCm\\5.7\\bin"} // TODO glob?
 )
 
-func AMDGetGPUInfo() []GpuInfo {
-	resp := []GpuInfo{}
+func AMDGetGPUInfo() []RocmGPUInfo {
+	resp := []RocmGPUInfo{}
 	hl, err := NewHipLib()
 	if err != nil {
 		slog.Debug(err.Error())
@@ -117,21 +117,24 @@ func AMDGetGPUInfo() []GpuInfo {
 		// v5.7 only reports VRAM used by this process, so it's completely wrong and unusable
 		slog.Debug("amdgpu memory", "gpu", i, "total", format.HumanBytes2(totalMemory))
 		slog.Debug("amdgpu memory", "gpu", i, "available", format.HumanBytes2(freeMemory))
-		gpuInfo := GpuInfo{
-			Library: "rocm",
-			memInfo: memInfo{
-				TotalMemory: totalMemory,
-				FreeMemory:  freeMemory,
+		gpuInfo := RocmGPUInfo{
+			GpuInfo: GpuInfo{
+				Library: "rocm",
+				memInfo: memInfo{
+					TotalMemory: totalMemory,
+					FreeMemory:  freeMemory,
+				},
+				ID:             fmt.Sprintf("%d", i), // TODO this is probably wrong if we specify visible devices
+				DependencyPath: libDir,
+				MinimumMemory:  rocmMinimumMemory,
+				Name:           name,
+				Compute:        gfx,
+
+				// TODO - this information isn't accurate on windows, so don't report it until we find the right way to retrieve
+				// DriverMajor:    driverMajor,
+				// DriverMinor:    driverMinor,
 			},
-			ID:             fmt.Sprintf("%d", i), // TODO this is probably wrong if we specify visible devices
-			DependencyPath: libDir,
-			MinimumMemory:  rocmMinimumMemory,
-			Name:           name,
-			Compute:        gfx,
-
-			// TODO - this information isn't accurate on windows, so don't report it until we find the right way to retrieve
-			// DriverMajor:    driverMajor,
-			// DriverMinor:    driverMinor,
+			index: i,
 		}
 
 		resp = append(resp, gpuInfo)
@@ -159,3 +162,30 @@ func AMDValidateLibDir() (string, error) {
 	slog.Warn("amdgpu detected, but no compatible rocm library found.  Please install ROCm")
 	return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
 }
+
+func (gpus RocmGPUInfoList) RefreshFreeMemory() error {
+	if len(gpus) == 0 {
+		return nil
+	}
+	hl, err := NewHipLib()
+	if err != nil {
+		slog.Debug(err.Error())
+		return nil
+	}
+	defer hl.Release()
+
+	for i := range gpus {
+		err := hl.HipSetDevice(gpus[i].index)
+		if err != nil {
+			return err
+		}
+		freeMemory, _, err := hl.HipMemGetInfo()
+		if err != nil {
+			slog.Warn("get mem info", "id", i, "error", err)
+			continue
+		}
+		slog.Debug("updating rocm free memory", "gpu", gpus[i].ID, "name", gpus[i].Name, "before", format.HumanBytes2(gpus[i].FreeMemory), "now", format.HumanBytes2(freeMemory))
+		gpus[i].FreeMemory = freeMemory
+	}
+	return nil
+}

+ 7 - 8
gpu/cpu_common.go

@@ -1,21 +1,20 @@
 package gpu
 
 import (
-	"log/slog"
-
 	"golang.org/x/sys/cpu"
 )
 
 func GetCPUVariant() string {
+	return getCPUCapability().ToVariant()
+}
+
+func getCPUCapability() CPUCapability {
 	if cpu.X86.HasAVX2 {
-		slog.Debug("CPU has AVX2")
-		return "avx2"
+		return CPUCapabilityAVX2
 	}
 	if cpu.X86.HasAVX {
-		slog.Debug("CPU has AVX")
-		return "avx"
+		return CPUCapabilityAVX
 	}
-	slog.Debug("CPU does not have vector extensions")
 	// else LCD
-	return ""
+	return CPUCapabilityBase
 }

+ 163 - 77
gpu/gpu.go

@@ -21,8 +21,8 @@ import (
 	"sync"
 	"unsafe"
 
-	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/envconfig"
+	"github.com/ollama/ollama/format"
 )
 
 type handles struct {
@@ -37,7 +37,18 @@ const (
 	rocmMinimumMemory = 457 * format.MebiByte
 )
 
-var gpuMutex sync.Mutex
+var (
+	gpuMutex      sync.Mutex
+	bootstrapped  bool
+	cpuCapability CPUCapability
+	cpus          []CPUInfo
+	cudaGPUs      []CudaGPUInfo
+	nvcudaLibPath string
+	cudartLibPath string
+	oneapiLibPath string
+	rocmGPUs      []RocmGPUInfo
+	oneapiGPUs    []OneapiGPUInfo
+)
 
 // With our current CUDA compile flags, older than 5.0 will not work properly
 var CudaComputeMin = [2]C.int{5, 0}
@@ -96,11 +107,22 @@ var OneapiLinuxGlobs = []string{
 var CudaTegra string = os.Getenv("JETSON_JETPACK")
 
 // Note: gpuMutex must already be held
-func initGPUHandles() *handles {
+func initCudaHandles() *handles {
 
 	// TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing
 
 	gpuHandles := &handles{}
+	// Short Circuit if we already know which library to use
+	if nvcudaLibPath != "" {
+		gpuHandles.deviceCount, gpuHandles.nvcuda, _ = LoadNVCUDAMgmt([]string{nvcudaLibPath})
+		return gpuHandles
+	}
+	if cudartLibPath != "" {
+		gpuHandles.deviceCount, gpuHandles.cudart, _ = LoadCUDARTMgmt([]string{cudartLibPath})
+		return gpuHandles
+	}
+
+	slog.Debug("searching for GPU discovery libraries for NVIDIA")
 	var cudartMgmtName string
 	var cudartMgmtPatterns []string
 	var nvcudaMgmtName string
@@ -136,7 +158,6 @@ func initGPUHandles() *handles {
 		return gpuHandles
 	}
 
-	slog.Debug("Detecting GPUs")
 	nvcudaLibPaths := FindGPULibs(nvcudaMgmtName, nvcudaMgmtPatterns)
 	if len(nvcudaLibPaths) > 0 {
 		deviceCount, nvcuda, libPath := LoadNVCUDAMgmt(nvcudaLibPaths)
@@ -144,6 +165,7 @@ func initGPUHandles() *handles {
 			slog.Debug("detected GPUs", "count", deviceCount, "library", libPath)
 			gpuHandles.nvcuda = nvcuda
 			gpuHandles.deviceCount = deviceCount
+			nvcudaLibPath = libPath
 			return gpuHandles
 		}
 	}
@@ -155,6 +177,7 @@ func initGPUHandles() *handles {
 			slog.Debug("detected GPUs", "library", libPath, "count", deviceCount)
 			gpuHandles.cudart = cudart
 			gpuHandles.deviceCount = deviceCount
+			cudartLibPath = libPath
 			return gpuHandles
 		}
 	}
@@ -166,6 +189,7 @@ func initGPUHandles() *handles {
 			slog.Debug("detected Intel GPUs", "library", libPath, "count", deviceCount)
 			gpuHandles.oneapi = oneapi
 			gpuHandles.deviceCount = deviceCount
+			oneapiLibPath = libPath
 			return gpuHandles
 		}
 	}
@@ -178,9 +202,12 @@ func GetGPUInfo() GpuInfoList {
 	// GPUs so we can report warnings if we see Nvidia/AMD but fail to load the libraries
 	gpuMutex.Lock()
 	defer gpuMutex.Unlock()
-
-	gpuHandles := initGPUHandles()
+	needRefresh := true
+	var gpuHandles *handles
 	defer func() {
+		if gpuHandles == nil {
+			return
+		}
 		if gpuHandles.cudart != nil {
 			C.cudart_release(*gpuHandles.cudart)
 		}
@@ -189,97 +216,156 @@ func GetGPUInfo() GpuInfoList {
 		}
 	}()
 
-	// All our GPU builds on x86 have AVX enabled, so fallback to CPU if we don't detect at least AVX
-	cpuVariant := GetCPUVariant()
-	if cpuVariant == "" && runtime.GOARCH == "amd64" {
-		slog.Warn("CPU does not have AVX or AVX2, disabling GPU support.")
-	}
+	if !bootstrapped {
+		slog.Debug("Detecting GPUs")
+		needRefresh = false
+		cpuCapability = getCPUCapability()
+		var memInfo C.mem_info_t
+		C.cpu_check_ram(&memInfo)
+		if memInfo.err != nil {
+			slog.Info("error looking up CPU memory", "error", C.GoString(memInfo.err))
+			C.free(unsafe.Pointer(memInfo.err))
+			return []GpuInfo{}
+		}
+		cpuInfo := CPUInfo{
+			GpuInfo: GpuInfo{
+				Library: "cpu",
+				Variant: cpuCapability.ToVariant(),
+			},
+		}
+		cpuInfo.TotalMemory = uint64(memInfo.total)
+		cpuInfo.FreeMemory = uint64(memInfo.free)
+		cpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
+		cpus = []CPUInfo{cpuInfo}
+
+		// Fallback to CPU mode if we're lacking required vector extensions on x86
+		if cpuCapability < GPURunnerCPUCapability && runtime.GOARCH == "amd64" {
+			slog.Warn("CPU does not have minimum vector extensions, GPU inference disabled", "required", GPURunnerCPUCapability.ToString(), "detected", cpuCapability.ToString())
+			bootstrapped = true
+			// No need to do any GPU discovery, since we can't run on them
+			return GpuInfoList{cpus[0].GpuInfo}
+		}
 
-	// On windows we bundle the nvidia library one level above the runner dir
-	depPath := ""
-	if runtime.GOOS == "windows" && envconfig.RunnersDir != "" {
-		depPath = filepath.Dir(envconfig.RunnersDir)
-	}
+		// TODO - implement
 
-	var memInfo C.mem_info_t
-	resp := []GpuInfo{}
+		// TODO refine the discovery to only gather total memory
 
-	// NVIDIA first
-	for i := range gpuHandles.deviceCount {
-		// TODO once we support CPU compilation variants of GPU libraries refine this...
-		if cpuVariant == "" && runtime.GOARCH == "amd64" {
-			continue
+		// On windows we bundle the nvidia library one level above the runner dir
+		depPath := ""
+		if runtime.GOOS == "windows" && envconfig.RunnersDir != "" {
+			depPath = filepath.Dir(envconfig.RunnersDir)
 		}
-		if gpuHandles.cudart != nil || gpuHandles.nvcuda != nil {
-			gpuInfo := GpuInfo{
-				Library: "cuda",
+
+		// Load ALL libraries
+		gpuHandles = initCudaHandles()
+
+		// TODO needs a refactoring pass to init oneapi handles
+
+		// NVIDIA
+		for i := range gpuHandles.deviceCount {
+			if gpuHandles.cudart != nil || gpuHandles.nvcuda != nil {
+				gpuInfo := CudaGPUInfo{
+					GpuInfo: GpuInfo{
+						Library: "cuda",
+					},
+					index: i,
+				}
+				var driverMajor int
+				var driverMinor int
+				if gpuHandles.cudart != nil {
+					C.cudart_bootstrap(*gpuHandles.cudart, C.int(i), &memInfo)
+				} else {
+					C.nvcuda_bootstrap(*gpuHandles.nvcuda, C.int(i), &memInfo)
+					driverMajor = int(gpuHandles.nvcuda.driver_major)
+					driverMinor = int(gpuHandles.nvcuda.driver_minor)
+				}
+				if memInfo.err != nil {
+					slog.Info("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
+					C.free(unsafe.Pointer(memInfo.err))
+					continue
+				}
+				if memInfo.major < CudaComputeMin[0] || (memInfo.major == CudaComputeMin[0] && memInfo.minor < CudaComputeMin[1]) {
+					slog.Info(fmt.Sprintf("[%d] CUDA GPU is too old. Compute Capability detected: %d.%d", i, memInfo.major, memInfo.minor))
+					continue
+				}
+				gpuInfo.TotalMemory = uint64(memInfo.total)
+				gpuInfo.FreeMemory = uint64(memInfo.free)
+				gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
+				gpuInfo.Compute = fmt.Sprintf("%d.%d", memInfo.major, memInfo.minor)
+				gpuInfo.MinimumMemory = cudaMinimumMemory
+				gpuInfo.DependencyPath = depPath
+				gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
+				gpuInfo.DriverMajor = int(driverMajor)
+				gpuInfo.DriverMinor = int(driverMinor)
+
+				// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
+				cudaGPUs = append(cudaGPUs, gpuInfo)
 			}
-			var driverMajor int
-			var driverMinor int
+			if gpuHandles.oneapi != nil {
+				gpuInfo := OneapiGPUInfo{
+					GpuInfo: GpuInfo{
+						Library: "oneapi",
+					},
+					index: i,
+				}
+				// TODO - split bootstrapping from updating free memory
+				C.oneapi_check_vram(*gpuHandles.oneapi, &memInfo)
+				var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend.
+				memInfo.free = C.uint64_t(totalFreeMem)
+				gpuInfo.TotalMemory = uint64(memInfo.total)
+				gpuInfo.FreeMemory = uint64(memInfo.free)
+				gpuInfo.ID = strconv.Itoa(i)
+				oneapiGPUs = append(oneapiGPUs, gpuInfo)
+			}
+		}
+
+		rocmGPUs = AMDGetGPUInfo()
+		bootstrapped = true
+	}
+
+	// For detected GPUs, load library if not loaded
+
+	// Refresh free memory usage
+	if needRefresh {
+		// TODO - CPU system memory tracking/refresh
+		var memInfo C.mem_info_t
+		if gpuHandles == nil && len(cudaGPUs) > 0 {
+			gpuHandles = initCudaHandles()
+		}
+		for i, gpu := range cudaGPUs {
 			if gpuHandles.cudart != nil {
-				C.cudart_check_vram(*gpuHandles.cudart, C.int(i), &memInfo)
+				C.cudart_bootstrap(*gpuHandles.cudart, C.int(gpu.index), &memInfo)
 			} else {
-				C.nvcuda_check_vram(*gpuHandles.nvcuda, C.int(i), &memInfo)
-				driverMajor = int(gpuHandles.nvcuda.driver_major)
-				driverMinor = int(gpuHandles.nvcuda.driver_minor)
+				C.nvcuda_get_free(*gpuHandles.nvcuda, C.int(gpu.index), &memInfo.free)
 			}
 			if memInfo.err != nil {
-				slog.Info("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
+				slog.Warn("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
 				C.free(unsafe.Pointer(memInfo.err))
 				continue
 			}
-			if memInfo.major < CudaComputeMin[0] || (memInfo.major == CudaComputeMin[0] && memInfo.minor < CudaComputeMin[1]) {
-				slog.Info(fmt.Sprintf("[%d] CUDA GPU is too old. Compute Capability detected: %d.%d", i, memInfo.major, memInfo.minor))
+			if memInfo.free == 0 {
+				slog.Warn("error looking up nvidia GPU memory")
 				continue
 			}
-			gpuInfo.TotalMemory = uint64(memInfo.total)
-			gpuInfo.FreeMemory = uint64(memInfo.free)
-			gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
-			gpuInfo.Compute = fmt.Sprintf("%d.%d", memInfo.major, memInfo.minor)
-			gpuInfo.MinimumMemory = cudaMinimumMemory
-			gpuInfo.DependencyPath = depPath
-			gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
-			gpuInfo.DriverMajor = driverMajor
-			gpuInfo.DriverMinor = driverMinor
-
-			// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
-			resp = append(resp, gpuInfo)
+			slog.Debug("updating cuda free memory", "gpu", gpu.ID, "name", gpu.Name, "before", format.HumanBytes2(gpu.FreeMemory), "now", format.HumanBytes2(uint64(memInfo.free)))
+			cudaGPUs[i].FreeMemory = uint64(memInfo.free)
 		}
-		if gpuHandles.oneapi != nil {
-			gpuInfo := GpuInfo{
-				Library: "oneapi",
-			}
-			C.oneapi_check_vram(*gpuHandles.oneapi, &memInfo)
-			var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend.
-			memInfo.free = C.uint64_t(totalFreeMem)
-			gpuInfo.TotalMemory = uint64(memInfo.total)
-			gpuInfo.FreeMemory = uint64(memInfo.free)
-			gpuInfo.ID = strconv.Itoa(i)
-			resp = append(resp, gpuInfo)
+		err := RocmGPUInfoList(rocmGPUs).RefreshFreeMemory()
+		if err != nil {
+			slog.Debug("problem refreshing ROCm free memory", "error", err)
 		}
 	}
 
-	// Then AMD
-	resp = append(resp, AMDGetGPUInfo()...)
-
+	resp := []GpuInfo{}
+	for _, gpu := range cudaGPUs {
+		resp = append(resp, gpu.GpuInfo)
+	}
+	for _, gpu := range rocmGPUs {
+		resp = append(resp, gpu.GpuInfo)
+	}
 	if len(resp) == 0 {
-		C.cpu_check_ram(&memInfo)
-		if memInfo.err != nil {
-			slog.Info("error looking up CPU memory", "error", C.GoString(memInfo.err))
-			C.free(unsafe.Pointer(memInfo.err))
-			return resp
-		}
-		gpuInfo := GpuInfo{
-			Library: "cpu",
-			Variant: cpuVariant,
-		}
-		gpuInfo.TotalMemory = uint64(memInfo.total)
-		gpuInfo.FreeMemory = uint64(memInfo.free)
-		gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
-
-		resp = append(resp, gpuInfo)
+		resp = append(resp, cpus[0].GpuInfo)
 	}
-
 	return resp
 }
 

+ 1 - 1
gpu/gpu_info_cudart.c

@@ -94,7 +94,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
 }
 
 
-void cudart_check_vram(cudart_handle_t h, int i, mem_info_t *resp) {
+void cudart_bootstrap(cudart_handle_t h, int i, mem_info_t *resp) {
   resp->err = NULL;
   cudartMemory_t memInfo = {0,0,0};
   cudartReturn_t ret;

+ 2 - 1
gpu/gpu_info_cudart.h

@@ -140,7 +140,8 @@ typedef struct cudart_init_resp {
 } cudart_init_resp_t;
 
 void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp);
-void cudart_check_vram(cudart_handle_t ch, int device_id, mem_info_t *resp);
+void cudart_bootstrap(cudart_handle_t ch, int device_id, mem_info_t *resp);
+// TODO - if we keep this library longer term, add cudart_get_free
 void cudart_release(cudart_handle_t ch);
 
 #endif  // __GPU_INFO_CUDART_H__

+ 38 - 3
gpu/gpu_info_nvcuda.c

@@ -96,7 +96,7 @@ void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) {
 }
 
 const int buflen = 256;
-void nvcuda_check_vram(nvcuda_handle_t h, int i, mem_info_t *resp) {
+void nvcuda_bootstrap(nvcuda_handle_t h, int i, mem_info_t *resp) {
   resp->err = NULL;
   nvcudaMemory_t memInfo = {0,0};
   CUresult ret;
@@ -168,7 +168,7 @@ void nvcuda_check_vram(nvcuda_handle_t h, int i, mem_info_t *resp) {
   // To get memory we have to set (and release) a context
   ret = (*h.cuCtxCreate_v3)(&ctx, NULL, 0, 0, device);
   if (ret != CUDA_SUCCESS) {
-    snprintf(buf, buflen, "nvcuda failed to get primary device context %d", ret);
+    snprintf(buf, buflen, "nvcuda failed to get device context %d", ret);
     resp->err = strdup(buf);
     return;
   }
@@ -193,7 +193,42 @@ void nvcuda_check_vram(nvcuda_handle_t h, int i, mem_info_t *resp) {
 
   ret = (*h.cuCtxDestroy)(ctx);
   if (ret != CUDA_SUCCESS) {
-    LOG(1, "nvcuda failed to release primary device context %d", ret);
+    LOG(1, "nvcuda failed to release device context %d", ret);
+  }
+}
+
+void nvcuda_get_free(nvcuda_handle_t h, int i, uint64_t *free) {
+  CUresult ret;
+  CUcontext ctx = NULL;
+  CUdevice device = -1;
+  *free = 0;
+  uint64_t total = 0;
+
+  ret = (*h.cuDeviceGet)(&device, i);
+  if (ret != CUDA_SUCCESS) {
+    LOG(1, "nvcuda device failed to initialize");
+    return;
+  }
+
+
+  // To get memory we have to set (and release) a context
+  ret = (*h.cuCtxCreate_v3)(&ctx, NULL, 0, 0, device);
+  if (ret != CUDA_SUCCESS) {
+    LOG(1, "nvcuda failed to get device context %d", ret);
+    return;
+  }
+
+  ret = (*h.cuMemGetInfo_v2)(free, &total);
+  if (ret != CUDA_SUCCESS) {
+    LOG(1, "nvcuda device memory info lookup failure %d", ret);
+    // Best effort on failure...
+    (*h.cuCtxDestroy)(ctx);
+    return;
+  }
+
+  ret = (*h.cuCtxDestroy)(ctx);
+  if (ret != CUDA_SUCCESS) {
+    LOG(1, "nvcuda failed to release device context %d", ret);
   }
 }
 

+ 2 - 1
gpu/gpu_info_nvcuda.h

@@ -67,7 +67,8 @@ typedef struct nvcuda_init_resp {
 } nvcuda_init_resp_t;
 
 void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp);
-void nvcuda_check_vram(nvcuda_handle_t ch, int device_id, mem_info_t *resp);
+void nvcuda_bootstrap(nvcuda_handle_t ch, int device_id, mem_info_t *resp);
+void nvcuda_get_free(nvcuda_handle_t ch,  int device_id, uint64_t *free);
 void nvcuda_release(nvcuda_handle_t ch);
 
 #endif  // __GPU_INFO_NVCUDA_H__

+ 57 - 0
gpu/types.go

@@ -38,6 +38,29 @@ type GpuInfo struct {
 	// TODO other performance capability info to help in scheduling decisions
 }
 
+type CPUInfo struct {
+	GpuInfo
+}
+
+type CudaGPUInfo struct {
+	GpuInfo
+	index int // device index
+}
+type CudaGPUInfoList []CudaGPUInfo
+
+type RocmGPUInfo struct {
+	GpuInfo
+	usedFilepath string // linux
+	index        int    // device index on windows
+}
+type RocmGPUInfoList []RocmGPUInfo
+
+type OneapiGPUInfo struct {
+	GpuInfo
+	index int // device index
+}
+type OneapiGPUInfoList []OneapiGPUInfo
+
 type GpuInfoList []GpuInfo
 
 // Split up the set of gpu info's by Library and variant
@@ -86,3 +109,37 @@ type ByFreeMemory []GpuInfo
 func (a ByFreeMemory) Len() int           { return len(a) }
 func (a ByFreeMemory) Swap(i, j int)      { a[i], a[j] = a[j], a[i] }
 func (a ByFreeMemory) Less(i, j int) bool { return a[i].FreeMemory < a[j].FreeMemory }
+
+type CPUCapability uint32
+
+// Override at build time when building base GPU runners
+var GPURunnerCPUCapability = CPUCapabilityAVX
+
+const (
+	CPUCapabilityBase CPUCapability = iota
+	CPUCapabilityAVX
+	CPUCapabilityAVX2
+	// TODO AVX512
+)
+
+func (c CPUCapability) ToString() string {
+	switch c {
+	case CPUCapabilityAVX:
+		return "AVX"
+	case CPUCapabilityAVX2:
+		return "AVX2"
+	default:
+		return "no vector extensions"
+	}
+}
+
+func (c CPUCapability) ToVariant() string {
+	switch c {
+	case CPUCapabilityAVX:
+		return "avx"
+	case CPUCapabilityAVX2:
+		return "avx2"
+	default:
+		return ""
+	}
+}