Browse Source

Merge pull request #4517 from dhiltgen/gpu_incremental

Enhanced GPU discovery and multi-gpu support with concurrency
Daniel Hiltgen 10 months ago
parent
commit
45cacbaf05

+ 12 - 0
envconfig/config.go

@@ -53,6 +53,8 @@ var (
 	NumParallel int
 	// Set via OLLAMA_RUNNERS_DIR in the environment
 	RunnersDir string
+	// Set via OLLAMA_SCHED_SPREAD in the environment
+	SchedSpread bool
 	// Set via OLLAMA_TMPDIR in the environment
 	TmpDir string
 )
@@ -79,6 +81,7 @@ func AsMap() map[string]EnvVar {
 		"OLLAMA_NUM_PARALLEL":      {"OLLAMA_NUM_PARALLEL", NumParallel, "Maximum number of parallel requests (default 1)"},
 		"OLLAMA_ORIGINS":           {"OLLAMA_ORIGINS", AllowOrigins, "A comma separated list of allowed origins"},
 		"OLLAMA_RUNNERS_DIR":       {"OLLAMA_RUNNERS_DIR", RunnersDir, "Location for runners"},
+		"OLLAMA_SCHED_SPREAD":      {"OLLAMA_SCHED_SPREAD", SchedSpread, "Always schedule model across all GPUs"},
 		"OLLAMA_TMPDIR":            {"OLLAMA_TMPDIR", TmpDir, "Location for temporary files"},
 	}
 }
@@ -191,6 +194,15 @@ func LoadConfig() {
 		NoHistory = true
 	}
 
+	if spread := clean("OLLAMA_SCHED_SPREAD"); spread != "" {
+		s, err := strconv.ParseBool(spread)
+		if err == nil {
+			SchedSpread = s
+		} else {
+			SchedSpread = true
+		}
+	}
+
 	if noprune := clean("OLLAMA_NOPRUNE"); noprune != "" {
 		NoPrune = true
 	}

+ 131 - 75
gpu/amd_linux.go

@@ -25,7 +25,16 @@ const (
 
 	// Prefix with the node dir
 	GPUTotalMemoryFileGlob = "mem_banks/*/properties" // size_in_bytes line
-	GPUUsedMemoryFileGlob  = "mem_banks/*/used_memory"
+
+	// Direct Rendering Manager sysfs location
+	DRMDeviceDirGlob   = "/sys/class/drm/card*/device"
+	DRMTotalMemoryFile = "mem_info_vram_total"
+	DRMUsedMemoryFile  = "mem_info_vram_used"
+
+	// In hex; properties file is in decimal
+	DRMUniqueIDFile = "unique_id"
+	DRMVendorFile   = "vendor"
+	DRMDeviceFile   = "device"
 )
 
 var (
@@ -35,8 +44,8 @@ var (
 )
 
 // Gather GPU information from the amdgpu driver if any supported GPUs are detected
-func AMDGetGPUInfo() []GpuInfo {
-	resp := []GpuInfo{}
+func AMDGetGPUInfo() []RocmGPUInfo {
+	resp := []RocmGPUInfo{}
 	if !AMDDetected() {
 		return resp
 	}
@@ -90,7 +99,7 @@ func AMDGetGPUInfo() []GpuInfo {
 		scanner := bufio.NewScanner(fp)
 		isCPU := false
 		var major, minor, patch uint64
-		var vendor, device uint64
+		var vendor, device, uniqueID uint64
 		for scanner.Scan() {
 			line := strings.TrimSpace(scanner.Text())
 			// Note: we could also use "cpu_cores_count X" where X is greater than zero to detect CPUs
@@ -121,30 +130,43 @@ func AMDGetGPUInfo() []GpuInfo {
 			} else if strings.HasPrefix(line, "vendor_id") {
 				ver := strings.Fields(line)
 				if len(ver) != 2 {
-					slog.Debug("malformed vendor_id", "vendor_id", line)
+					slog.Debug("malformed", "vendor_id", line)
 					continue
 				}
-				vendor, err = strconv.ParseUint(ver[1], 10, 32)
+				vendor, err = strconv.ParseUint(ver[1], 10, 64)
 				if err != nil {
-					slog.Debug("malformed vendor_id" + line)
+					slog.Debug("malformed", "vendor_id", line, "error", err)
 				}
 			} else if strings.HasPrefix(line, "device_id") {
 				ver := strings.Fields(line)
 				if len(ver) != 2 {
-					slog.Debug("malformed device_id", "device_id", line)
+					slog.Debug("malformed", "device_id", line)
+					continue
+				}
+				device, err = strconv.ParseUint(ver[1], 10, 64)
+				if err != nil {
+					slog.Debug("malformed", "device_id", line, "error", err)
+				}
+			} else if strings.HasPrefix(line, "unique_id") {
+				ver := strings.Fields(line)
+				if len(ver) != 2 {
+					slog.Debug("malformed", "unique_id", line)
 					continue
 				}
-				device, err = strconv.ParseUint(ver[1], 10, 32)
+				uniqueID, err = strconv.ParseUint(ver[1], 10, 64)
 				if err != nil {
-					slog.Debug("malformed device_id" + line)
+					slog.Debug("malformed", "unique_id", line, "error", err)
 				}
 			}
-
 			// TODO - any other properties we want to extract and record?
 			// vendor_id + device_id -> pci lookup for "Name"
 			// Other metrics that may help us understand relative performance between multiple GPUs
 		}
 
+		// Note: while ./mem_banks/*/used_memory exists, it doesn't appear to take other VRAM consumers
+		// into consideration, so we instead map the device over to the DRM driver sysfs nodes which
+		// do reliably report VRAM usage.
+
 		if isCPU {
 			cpuCount++
 			continue
@@ -156,7 +178,7 @@ func AMDGetGPUInfo() []GpuInfo {
 		// Shouldn't happen, but just in case...
 		if gpuID < 0 {
 			slog.Error("unexpected amdgpu sysfs data resulted in negative GPU ID, please set OLLAMA_DEBUG=1 and report an issue")
-			return []GpuInfo{}
+			return nil
 		}
 
 		if int(major) < RocmComputeMin {
@@ -167,65 +189,68 @@ func AMDGetGPUInfo() []GpuInfo {
 		// Look up the memory for the current node
 		totalMemory := uint64(0)
 		usedMemory := uint64(0)
-		propGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(nodeID), GPUTotalMemoryFileGlob)
-		propFiles, err := filepath.Glob(propGlob)
-		if err != nil {
-			slog.Warn("error looking up total GPU memory", "glob", propGlob, "error", err)
+		var usedFile string
+		mapping := []struct {
+			id       uint64
+			filename string
+		}{
+			{vendor, DRMVendorFile},
+			{device, DRMDeviceFile},
+			{uniqueID, DRMUniqueIDFile}, // Not all devices will report this
 		}
-		// 1 or more memory banks - sum the values of all of them
-		for _, propFile := range propFiles {
-			fp, err := os.Open(propFile)
-			if err != nil {
-				slog.Warn("failed to open sysfs node", "file", propFile, "erroir", err)
-				continue
-			}
-			defer fp.Close()
-			scanner := bufio.NewScanner(fp)
-			for scanner.Scan() {
-				line := strings.TrimSpace(scanner.Text())
-				if strings.HasPrefix(line, "size_in_bytes") {
-					ver := strings.Fields(line)
-					if len(ver) != 2 {
-						slog.Warn("malformed " + line)
-						continue
-					}
-					bankSizeInBytes, err := strconv.ParseUint(ver[1], 10, 64)
-					if err != nil {
-						slog.Warn("malformed int " + line)
-						continue
-					}
-					totalMemory += bankSizeInBytes
+		slog.Debug("mapping amdgpu to drm sysfs nodes", "amdgpu", match, "vendor", vendor, "device", device, "unique_id", uniqueID)
+		// Map over to DRM location to find the total/free memory
+		drmMatches, _ := filepath.Glob(DRMDeviceDirGlob)
+		for _, devDir := range drmMatches {
+			matched := true
+			for _, m := range mapping {
+				if m.id == 0 {
+					// Null ID means it didn't populate, so we can't use it to match
+					continue
+				}
+				filename := filepath.Join(devDir, m.filename)
+				buf, err := os.ReadFile(filename)
+				if err != nil {
+					slog.Debug("failed to read sysfs node", "file", filename, "error", err)
+					matched = false
+					break
+				}
+				// values here are in hex, strip off the lead 0x and parse so we can compare the numeric (decimal) values in amdgpu
+				cmp, err := strconv.ParseUint(strings.TrimPrefix(strings.TrimSpace(string(buf)), "0x"), 16, 64)
+				if err != nil {
+					slog.Debug("failed to parse sysfs node", "file", filename, "error", err)
+					matched = false
+					break
+				}
+				if cmp != m.id {
+					matched = false
+					break
 				}
 			}
-		}
-		if totalMemory == 0 {
-			slog.Warn("amdgpu reports zero total memory", "gpu", gpuID)
-			continue
-		}
-		usedGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(nodeID), GPUUsedMemoryFileGlob)
-		usedFiles, err := filepath.Glob(usedGlob)
-		if err != nil {
-			slog.Warn("error looking up used GPU memory", "glob", usedGlob, "error", err)
-			continue
-		}
-		for _, usedFile := range usedFiles {
-			fp, err := os.Open(usedFile)
-			if err != nil {
-				slog.Warn("failed to open sysfs node", "file", usedFile, "error", err)
+			if !matched {
 				continue
 			}
-			defer fp.Close()
-			data, err := io.ReadAll(fp)
+
+			// Found the matching DRM directory
+			slog.Debug("matched", "amdgpu", match, "drm", devDir)
+			totalFile := filepath.Join(devDir, DRMTotalMemoryFile)
+			buf, err := os.ReadFile(totalFile)
 			if err != nil {
-				slog.Warn("failed to read sysfs node", "file", usedFile, "error", err)
-				continue
+				slog.Debug("failed to read sysfs node", "file", totalFile, "error", err)
+				break
 			}
-			used, err := strconv.ParseUint(strings.TrimSpace(string(data)), 10, 64)
+			totalMemory, err = strconv.ParseUint(strings.TrimSpace(string(buf)), 10, 64)
 			if err != nil {
-				slog.Warn("malformed used memory", "data", string(data), "error", err)
-				continue
+				slog.Debug("failed to parse sysfs node", "file", totalFile, "error", err)
+				break
+			}
+
+			usedFile = filepath.Join(devDir, DRMUsedMemoryFile)
+			usedMemory, err = getFreeMemory(usedFile)
+			if err != nil {
+				slog.Debug("failed to update used memory", "error", err)
 			}
-			usedMemory += used
+			break
 		}
 
 		// iGPU detection, remove this check once we can support an iGPU variant of the rocm library
@@ -241,18 +266,21 @@ func AMDGetGPUInfo() []GpuInfo {
 
 		slog.Debug("amdgpu memory", "gpu", gpuID, "total", format.HumanBytes2(totalMemory))
 		slog.Debug("amdgpu memory", "gpu", gpuID, "available", format.HumanBytes2(totalMemory-usedMemory))
-		gpuInfo := GpuInfo{
-			Library: "rocm",
-			memInfo: memInfo{
-				TotalMemory: totalMemory,
-				FreeMemory:  (totalMemory - usedMemory),
+		gpuInfo := RocmGPUInfo{
+			GpuInfo: GpuInfo{
+				Library: "rocm",
+				memInfo: memInfo{
+					TotalMemory: totalMemory,
+					FreeMemory:  (totalMemory - usedMemory),
+				},
+				ID:            strconv.Itoa(gpuID),
+				Name:          name,
+				Compute:       fmt.Sprintf("gfx%d%x%x", major, minor, patch),
+				MinimumMemory: rocmMinimumMemory,
+				DriverMajor:   driverMajor,
+				DriverMinor:   driverMinor,
 			},
-			ID:            fmt.Sprintf("%d", gpuID),
-			Name:          name,
-			Compute:       fmt.Sprintf("gfx%d%x%x", major, minor, patch),
-			MinimumMemory: rocmMinimumMemory,
-			DriverMajor:   driverMajor,
-			DriverMinor:   driverMinor,
+			usedFilepath: usedFile,
 		}
 
 		// If the user wants to filter to a subset of devices, filter out if we aren't a match
@@ -276,7 +304,7 @@ func AMDGetGPUInfo() []GpuInfo {
 			libDir, err = AMDValidateLibDir()
 			if err != nil {
 				slog.Warn("unable to verify rocm library, will use cpu", "error", err)
-				return []GpuInfo{}
+				return nil
 			}
 		}
 		gpuInfo.DependencyPath = libDir
@@ -287,7 +315,7 @@ func AMDGetGPUInfo() []GpuInfo {
 				supported, err = GetSupportedGFX(libDir)
 				if err != nil {
 					slog.Warn("failed to lookup supported GFX types, falling back to CPU mode", "error", err)
-					return []GpuInfo{}
+					return nil
 				}
 				slog.Debug("rocm supported GPUs", "types", supported)
 			}
@@ -378,3 +406,31 @@ func AMDDriverVersion() (driverMajor, driverMinor int, err error) {
 	}
 	return driverMajor, driverMinor, nil
 }
+
+func (gpus RocmGPUInfoList) RefreshFreeMemory() error {
+	if len(gpus) == 0 {
+		return nil
+	}
+	for i := range gpus {
+		usedMemory, err := getFreeMemory(gpus[i].usedFilepath)
+		if err != nil {
+			return err
+		}
+		slog.Debug("updating rocm free memory", "gpu", gpus[i].ID, "name", gpus[i].Name, "before", format.HumanBytes2(gpus[i].FreeMemory), "now", format.HumanBytes2(gpus[i].TotalMemory-usedMemory))
+		gpus[i].FreeMemory = gpus[i].TotalMemory - usedMemory
+	}
+	return nil
+}
+
+func getFreeMemory(usedFile string) (uint64, error) {
+	buf, err := os.ReadFile(usedFile)
+	if err != nil {
+		return 0, fmt.Errorf("failed to read sysfs node %s %w", usedFile, err)
+	}
+	usedMemory, err := strconv.ParseUint(strings.TrimSpace(string(buf)), 10, 64)
+	if err != nil {
+		slog.Debug("failed to parse sysfs node", "file", usedFile, "error", err)
+		return 0, fmt.Errorf("failed to parse sysfs node %s %w", usedFile, err)
+	}
+	return usedMemory, nil
+}

+ 47 - 16
gpu/amd_windows.go

@@ -7,6 +7,7 @@ import (
 	"os"
 	"path/filepath"
 	"slices"
+	"strconv"
 	"strings"
 
 	"github.com/ollama/ollama/format"
@@ -24,8 +25,8 @@ var (
 	RocmStandardLocations = []string{"C:\\Program Files\\AMD\\ROCm\\5.7\\bin"} // TODO glob?
 )
 
-func AMDGetGPUInfo() []GpuInfo {
-	resp := []GpuInfo{}
+func AMDGetGPUInfo() []RocmGPUInfo {
+	resp := []RocmGPUInfo{}
 	hl, err := NewHipLib()
 	if err != nil {
 		slog.Debug(err.Error())
@@ -117,21 +118,24 @@ func AMDGetGPUInfo() []GpuInfo {
 		// v5.7 only reports VRAM used by this process, so it's completely wrong and unusable
 		slog.Debug("amdgpu memory", "gpu", i, "total", format.HumanBytes2(totalMemory))
 		slog.Debug("amdgpu memory", "gpu", i, "available", format.HumanBytes2(freeMemory))
-		gpuInfo := GpuInfo{
-			Library: "rocm",
-			memInfo: memInfo{
-				TotalMemory: totalMemory,
-				FreeMemory:  freeMemory,
+		gpuInfo := RocmGPUInfo{
+			GpuInfo: GpuInfo{
+				Library: "rocm",
+				memInfo: memInfo{
+					TotalMemory: totalMemory,
+					FreeMemory:  freeMemory,
+				},
+				ID:             strconv.Itoa(i), // TODO this is probably wrong if we specify visible devices
+				DependencyPath: libDir,
+				MinimumMemory:  rocmMinimumMemory,
+				Name:           name,
+				Compute:        gfx,
+
+				// TODO - this information isn't accurate on windows, so don't report it until we find the right way to retrieve
+				// DriverMajor:    driverMajor,
+				// DriverMinor:    driverMinor,
 			},
-			ID:             fmt.Sprintf("%d", i), // TODO this is probably wrong if we specify visible devices
-			DependencyPath: libDir,
-			MinimumMemory:  rocmMinimumMemory,
-			Name:           name,
-			Compute:        gfx,
-
-			// TODO - this information isn't accurate on windows, so don't report it until we find the right way to retrieve
-			// DriverMajor:    driverMajor,
-			// DriverMinor:    driverMinor,
+			index: i,
 		}
 
 		resp = append(resp, gpuInfo)
@@ -159,3 +163,30 @@ func AMDValidateLibDir() (string, error) {
 	slog.Warn("amdgpu detected, but no compatible rocm library found.  Please install ROCm")
 	return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
 }
+
+func (gpus RocmGPUInfoList) RefreshFreeMemory() error {
+	if len(gpus) == 0 {
+		return nil
+	}
+	hl, err := NewHipLib()
+	if err != nil {
+		slog.Debug(err.Error())
+		return nil
+	}
+	defer hl.Release()
+
+	for i := range gpus {
+		err := hl.HipSetDevice(gpus[i].index)
+		if err != nil {
+			return err
+		}
+		freeMemory, _, err := hl.HipMemGetInfo()
+		if err != nil {
+			slog.Warn("get mem info", "id", i, "error", err)
+			continue
+		}
+		slog.Debug("updating rocm free memory", "gpu", gpus[i].ID, "name", gpus[i].Name, "before", format.HumanBytes2(gpus[i].FreeMemory), "now", format.HumanBytes2(freeMemory))
+		gpus[i].FreeMemory = freeMemory
+	}
+	return nil
+}

+ 4 - 9
gpu/cpu_common.go

@@ -1,21 +1,16 @@
 package gpu
 
 import (
-	"log/slog"
-
 	"golang.org/x/sys/cpu"
 )
 
-func GetCPUVariant() string {
+func GetCPUCapability() CPUCapability {
 	if cpu.X86.HasAVX2 {
-		slog.Debug("CPU has AVX2")
-		return "avx2"
+		return CPUCapabilityAVX2
 	}
 	if cpu.X86.HasAVX {
-		slog.Debug("CPU has AVX")
-		return "avx"
+		return CPUCapabilityAVX
 	}
-	slog.Debug("CPU does not have vector extensions")
 	// else LCD
-	return ""
+	return CPUCapabilityNone
 }

+ 334 - 160
gpu/gpu.go

@@ -24,19 +24,37 @@ import (
 	"github.com/ollama/ollama/format"
 )
 
-type handles struct {
+type cudaHandles struct {
 	deviceCount int
 	cudart      *C.cudart_handle_t
 	nvcuda      *C.nvcuda_handle_t
+	nvml        *C.nvml_handle_t
+}
+
+type oneapiHandles struct {
 	oneapi      *C.oneapi_handle_t
+	deviceCount int
 }
 
 const (
 	cudaMinimumMemory = 457 * format.MebiByte
 	rocmMinimumMemory = 457 * format.MebiByte
+	// TODO OneAPI minimum memory
 )
 
-var gpuMutex sync.Mutex
+var (
+	gpuMutex      sync.Mutex
+	bootstrapped  bool
+	cpuCapability CPUCapability
+	cpus          []CPUInfo
+	cudaGPUs      []CudaGPUInfo
+	nvcudaLibPath string
+	cudartLibPath string
+	oneapiLibPath string
+	nvmlLibPath   string
+	rocmGPUs      []RocmGPUInfo
+	oneapiGPUs    []OneapiGPUInfo
+)
 
 // With our current CUDA compile flags, older than 5.0 will not work properly
 var CudaComputeMin = [2]C.int{5, 0}
@@ -46,113 +64,113 @@ var RocmComputeMin = 9
 // TODO find a better way to detect iGPU instead of minimum memory
 const IGPUMemLimit = 1 * format.GibiByte // 512G is what they typically report, so anything less than 1G must be iGPU
 
-var CudartLinuxGlobs = []string{
-	"/usr/local/cuda/lib64/libcudart.so*",
-	"/usr/lib/x86_64-linux-gnu/nvidia/current/libcudart.so*",
-	"/usr/lib/x86_64-linux-gnu/libcudart.so*",
-	"/usr/lib/wsl/lib/libcudart.so*",
-	"/usr/lib/wsl/drivers/*/libcudart.so*",
-	"/opt/cuda/lib64/libcudart.so*",
-	"/usr/local/cuda*/targets/aarch64-linux/lib/libcudart.so*",
-	"/usr/lib/aarch64-linux-gnu/nvidia/current/libcudart.so*",
-	"/usr/lib/aarch64-linux-gnu/libcudart.so*",
-	"/usr/local/cuda/lib*/libcudart.so*",
-	"/usr/lib*/libcudart.so*",
-	"/usr/local/lib*/libcudart.so*",
-}
-
-var CudartWindowsGlobs = []string{
-	"c:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v*\\bin\\cudart64_*.dll",
-}
-
-var NvcudaLinuxGlobs = []string{
-	"/usr/local/cuda*/targets/*/lib/libcuda.so*",
-	"/usr/lib/*-linux-gnu/nvidia/current/libcuda.so*",
-	"/usr/lib/*-linux-gnu/libcuda.so*",
-	"/usr/lib/wsl/lib/libcuda.so*",
-	"/usr/lib/wsl/drivers/*/libcuda.so*",
-	"/opt/cuda/lib*/libcuda.so*",
-	"/usr/local/cuda/lib*/libcuda.so*",
-	"/usr/lib*/libcuda.so*",
-	"/usr/local/lib*/libcuda.so*",
-}
-
-var NvcudaWindowsGlobs = []string{
-	"c:\\windows\\system*\\nvcuda.dll",
-}
-
-var OneapiWindowsGlobs = []string{
-	"c:\\Windows\\System32\\DriverStore\\FileRepository\\*\\ze_intel_gpu64.dll",
-}
-
-var OneapiLinuxGlobs = []string{
-	"/usr/lib/x86_64-linux-gnu/libze_intel_gpu.so*",
-	"/usr/lib*/libze_intel_gpu.so*",
-}
-
 // Jetson devices have JETSON_JETPACK="x.y.z" factory set to the Jetpack version installed.
 // Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices.
 var CudaTegra string = os.Getenv("JETSON_JETPACK")
 
 // Note: gpuMutex must already be held
-func initGPUHandles() *handles {
+func initCudaHandles() *cudaHandles {
 
 	// TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing
 
-	gpuHandles := &handles{}
-	var cudartMgmtName string
+	cHandles := &cudaHandles{}
+	// Short Circuit if we already know which library to use
+	if nvmlLibPath != "" {
+		cHandles.nvml, _ = LoadNVMLMgmt([]string{nvmlLibPath})
+		return cHandles
+	}
+	if nvcudaLibPath != "" {
+		cHandles.deviceCount, cHandles.nvcuda, _ = LoadNVCUDAMgmt([]string{nvcudaLibPath})
+		return cHandles
+	}
+	if cudartLibPath != "" {
+		cHandles.deviceCount, cHandles.cudart, _ = LoadCUDARTMgmt([]string{cudartLibPath})
+		return cHandles
+	}
+
+	slog.Debug("searching for GPU discovery libraries for NVIDIA")
 	var cudartMgmtPatterns []string
-	var nvcudaMgmtName string
-	var nvcudaMgmtPatterns []string
 
-	tmpDir, _ := PayloadsDir()
-	switch runtime.GOOS {
-	case "windows":
-		cudartMgmtName = "cudart64_*.dll"
+	// Aligned with driver, we can't carry as payloads
+	nvcudaMgmtPatterns := NvcudaGlobs
+
+	if runtime.GOOS == "windows" {
 		localAppData := os.Getenv("LOCALAPPDATA")
-		cudartMgmtPatterns = []string{filepath.Join(localAppData, "Programs", "Ollama", cudartMgmtName)}
-		cudartMgmtPatterns = append(cudartMgmtPatterns, CudartWindowsGlobs...)
-		// Aligned with driver, we can't carry as payloads
-		nvcudaMgmtName = "nvcuda.dll"
-		nvcudaMgmtPatterns = NvcudaWindowsGlobs
-	case "linux":
-		cudartMgmtName = "libcudart.so*"
-		if tmpDir != "" {
-			// TODO - add "payloads" for subprocess
-			cudartMgmtPatterns = []string{filepath.Join(tmpDir, "cuda*", cudartMgmtName)}
+		cudartMgmtPatterns = []string{filepath.Join(localAppData, "Programs", "Ollama", CudartMgmtName)}
+	}
+	tmpDir, _ := PayloadsDir()
+	if tmpDir != "" {
+		// TODO - add "payloads" for subprocess
+		cudartMgmtPatterns = []string{filepath.Join(tmpDir, "cuda*", CudartMgmtName)}
+	}
+	cudartMgmtPatterns = append(cudartMgmtPatterns, CudartGlobs...)
+
+	if len(NvmlGlobs) > 0 {
+		nvmlLibPaths := FindGPULibs(NvmlMgmtName, NvmlGlobs)
+		if len(nvmlLibPaths) > 0 {
+			nvml, libPath := LoadNVMLMgmt(nvmlLibPaths)
+			if nvml != nil {
+				slog.Debug("nvidia-ml loaded", "library", libPath)
+				cHandles.nvml = nvml
+				nvmlLibPath = libPath
+			}
 		}
-		cudartMgmtPatterns = append(cudartMgmtPatterns, CudartLinuxGlobs...)
-		// Aligned with driver, we can't carry as payloads
-		nvcudaMgmtName = "libcuda.so*"
-		nvcudaMgmtPatterns = NvcudaLinuxGlobs
-	default:
-		return gpuHandles
 	}
 
-	slog.Debug("Detecting GPUs")
-	nvcudaLibPaths := FindGPULibs(nvcudaMgmtName, nvcudaMgmtPatterns)
+	nvcudaLibPaths := FindGPULibs(NvcudaMgmtName, nvcudaMgmtPatterns)
 	if len(nvcudaLibPaths) > 0 {
 		deviceCount, nvcuda, libPath := LoadNVCUDAMgmt(nvcudaLibPaths)
 		if nvcuda != nil {
 			slog.Debug("detected GPUs", "count", deviceCount, "library", libPath)
-			gpuHandles.nvcuda = nvcuda
-			gpuHandles.deviceCount = deviceCount
-			return gpuHandles
+			cHandles.nvcuda = nvcuda
+			cHandles.deviceCount = deviceCount
+			nvcudaLibPath = libPath
+			return cHandles
 		}
 	}
 
-	cudartLibPaths := FindGPULibs(cudartMgmtName, cudartMgmtPatterns)
+	cudartLibPaths := FindGPULibs(CudartMgmtName, cudartMgmtPatterns)
 	if len(cudartLibPaths) > 0 {
 		deviceCount, cudart, libPath := LoadCUDARTMgmt(cudartLibPaths)
 		if cudart != nil {
 			slog.Debug("detected GPUs", "library", libPath, "count", deviceCount)
-			gpuHandles.cudart = cudart
-			gpuHandles.deviceCount = deviceCount
-			return gpuHandles
+			cHandles.cudart = cudart
+			cHandles.deviceCount = deviceCount
+			cudartLibPath = libPath
+			return cHandles
 		}
 	}
 
-	return gpuHandles
+	return cHandles
+}
+
+// Note: gpuMutex must already be held
+func initOneAPIHandles() *oneapiHandles {
+	oHandles := &oneapiHandles{}
+
+	// Short Circuit if we already know which library to use
+	if oneapiLibPath != "" {
+		oHandles.deviceCount, oHandles.oneapi, _ = LoadOneapiMgmt([]string{oneapiLibPath})
+		return oHandles
+	}
+
+	oneapiLibPaths := FindGPULibs(OneapiMgmtName, OneapiGlobs)
+	if len(oneapiLibPaths) > 0 {
+		oHandles.deviceCount, oHandles.oneapi, oneapiLibPath = LoadOneapiMgmt(oneapiLibPaths)
+	}
+
+	return oHandles
+}
+
+func GetCPUInfo() GpuInfoList {
+	gpuMutex.Lock()
+	if !bootstrapped {
+		gpuMutex.Unlock()
+		GetGPUInfo()
+	} else {
+		gpuMutex.Unlock()
+	}
+	return GpuInfoList{cpus[0].GpuInfo}
 }
 
 func GetGPUInfo() GpuInfoList {
@@ -160,110 +178,245 @@ func GetGPUInfo() GpuInfoList {
 	// GPUs so we can report warnings if we see Nvidia/AMD but fail to load the libraries
 	gpuMutex.Lock()
 	defer gpuMutex.Unlock()
-
-	gpuHandles := initGPUHandles()
+	needRefresh := true
+	var cHandles *cudaHandles
+	var oHandles *oneapiHandles
 	defer func() {
-		if gpuHandles.cudart != nil {
-			C.cudart_release(*gpuHandles.cudart)
+		if cHandles != nil {
+			if cHandles.cudart != nil {
+				C.cudart_release(*cHandles.cudart)
+			}
+			if cHandles.nvcuda != nil {
+				C.nvcuda_release(*cHandles.nvcuda)
+			}
+			if cHandles.nvml != nil {
+				C.nvml_release(*cHandles.nvml)
+			}
 		}
-		if gpuHandles.nvcuda != nil {
-			C.nvcuda_release(*gpuHandles.nvcuda)
+		if oHandles != nil {
+			if oHandles.oneapi != nil {
+				// TODO - is this needed?
+				C.oneapi_release(*oHandles.oneapi)
+			}
 		}
 	}()
 
-	// All our GPU builds on x86 have AVX enabled, so fallback to CPU if we don't detect at least AVX
-	cpuVariant := GetCPUVariant()
-	if cpuVariant == "" && runtime.GOARCH == "amd64" {
-		slog.Warn("CPU does not have AVX or AVX2, disabling GPU support.")
-	}
+	if !bootstrapped {
+		slog.Debug("Detecting GPUs")
+		needRefresh = false
+		cpuCapability = GetCPUCapability()
+		var memInfo C.mem_info_t
 
-	// On windows we bundle the nvidia library one level above the runner dir
-	depPath := ""
-	if runtime.GOOS == "windows" && envconfig.RunnersDir != "" {
-		depPath = filepath.Dir(envconfig.RunnersDir)
-	}
+		mem, err := GetCPUMem()
+		if err != nil {
+			slog.Warn("error looking up system memory", "error", err)
+		}
+		cpus = []CPUInfo{CPUInfo{
+			GpuInfo: GpuInfo{
+				memInfo: mem,
+				Library: "cpu",
+				Variant: cpuCapability,
+				ID:      "0",
+			},
+		}}
+
+		// Fallback to CPU mode if we're lacking required vector extensions on x86
+		if cpuCapability < GPURunnerCPUCapability && runtime.GOARCH == "amd64" {
+			slog.Warn("CPU does not have minimum vector extensions, GPU inference disabled", "required", GPURunnerCPUCapability, "detected", cpuCapability)
+			bootstrapped = true
+			// No need to do any GPU discovery, since we can't run on them
+			return GpuInfoList{cpus[0].GpuInfo}
+		}
 
-	var memInfo C.mem_info_t
-	resp := []GpuInfo{}
+		// On windows we bundle the nvidia library one level above the runner dir
+		depPath := ""
+		if runtime.GOOS == "windows" && envconfig.RunnersDir != "" {
+			depPath = filepath.Dir(envconfig.RunnersDir)
+		}
 
-	// NVIDIA first
-	for i := range gpuHandles.deviceCount {
-		// TODO once we support CPU compilation variants of GPU libraries refine this...
-		if cpuVariant == "" && runtime.GOARCH == "amd64" {
-			continue
+		// Load ALL libraries
+		cHandles = initCudaHandles()
+
+		// NVIDIA
+		for i := range cHandles.deviceCount {
+			if cHandles.cudart != nil || cHandles.nvcuda != nil {
+				gpuInfo := CudaGPUInfo{
+					GpuInfo: GpuInfo{
+						Library: "cuda",
+					},
+					index: i,
+				}
+				var driverMajor int
+				var driverMinor int
+				if cHandles.cudart != nil {
+					C.cudart_bootstrap(*cHandles.cudart, C.int(i), &memInfo)
+				} else {
+					C.nvcuda_bootstrap(*cHandles.nvcuda, C.int(i), &memInfo)
+					driverMajor = int(cHandles.nvcuda.driver_major)
+					driverMinor = int(cHandles.nvcuda.driver_minor)
+				}
+				if memInfo.err != nil {
+					slog.Info("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
+					C.free(unsafe.Pointer(memInfo.err))
+					continue
+				}
+				if memInfo.major < CudaComputeMin[0] || (memInfo.major == CudaComputeMin[0] && memInfo.minor < CudaComputeMin[1]) {
+					slog.Info(fmt.Sprintf("[%d] CUDA GPU is too old. Compute Capability detected: %d.%d", i, memInfo.major, memInfo.minor))
+					continue
+				}
+				gpuInfo.TotalMemory = uint64(memInfo.total)
+				gpuInfo.FreeMemory = uint64(memInfo.free)
+				gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
+				gpuInfo.Compute = fmt.Sprintf("%d.%d", memInfo.major, memInfo.minor)
+				gpuInfo.MinimumMemory = cudaMinimumMemory
+				gpuInfo.DependencyPath = depPath
+				gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
+				gpuInfo.DriverMajor = driverMajor
+				gpuInfo.DriverMinor = driverMinor
+
+				// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
+				cudaGPUs = append(cudaGPUs, gpuInfo)
+			}
 		}
-		if gpuHandles.cudart != nil || gpuHandles.nvcuda != nil {
-			gpuInfo := GpuInfo{
-				Library: "cuda",
+
+		// Intel
+		oHandles = initOneAPIHandles()
+		for d := 0; oHandles.oneapi != nil && d < int(oHandles.oneapi.num_drivers); d++ {
+			if oHandles.oneapi == nil {
+				// shouldn't happen
+				slog.Warn("nil oneapi handle with driver count", "count", int(oHandles.oneapi.num_drivers))
+				continue
+			}
+			devCount := C.oneapi_get_device_count(*oHandles.oneapi, C.int(d))
+			for i := range devCount {
+				gpuInfo := OneapiGPUInfo{
+					GpuInfo: GpuInfo{
+						Library: "oneapi",
+					},
+					driverIndex: d,
+					gpuIndex:    int(i),
+				}
+				// TODO - split bootstrapping from updating free memory
+				C.oneapi_check_vram(*oHandles.oneapi, C.int(d), i, &memInfo)
+				// TODO - convert this to MinimumMemory based on testing...
+				var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend.
+				memInfo.free = C.uint64_t(totalFreeMem)
+				gpuInfo.TotalMemory = uint64(memInfo.total)
+				gpuInfo.FreeMemory = uint64(memInfo.free)
+				gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
+				gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
+				// TODO dependency path?
+				oneapiGPUs = append(oneapiGPUs, gpuInfo)
 			}
-			var driverMajor int
-			var driverMinor int
-			if gpuHandles.cudart != nil {
-				C.cudart_check_vram(*gpuHandles.cudart, C.int(i), &memInfo)
+		}
+
+		rocmGPUs = AMDGetGPUInfo()
+		bootstrapped = true
+	}
+
+	// For detected GPUs, load library if not loaded
+
+	// Refresh free memory usage
+	if needRefresh {
+		mem, err := GetCPUMem()
+		if err != nil {
+			slog.Warn("error looking up system memory", "error", err)
+		} else {
+			slog.Debug("updating system memory data",
+				slog.Group(
+					"before",
+					"total", format.HumanBytes2(cpus[0].TotalMemory),
+					"free", format.HumanBytes2(cpus[0].FreeMemory),
+				),
+				slog.Group(
+					"now",
+					"total", format.HumanBytes2(mem.TotalMemory),
+					"free", format.HumanBytes2(mem.FreeMemory),
+				),
+			)
+			cpus[0].FreeMemory = mem.FreeMemory
+		}
+
+		var memInfo C.mem_info_t
+		if cHandles == nil && len(cudaGPUs) > 0 {
+			cHandles = initCudaHandles()
+		}
+		for i, gpu := range cudaGPUs {
+			if cHandles.nvml != nil {
+				C.nvml_get_free(*cHandles.nvml, C.int(gpu.index), &memInfo.free, &memInfo.total, &memInfo.used)
+			} else if cHandles.cudart != nil {
+				C.cudart_bootstrap(*cHandles.cudart, C.int(gpu.index), &memInfo)
+			} else if cHandles.nvcuda != nil {
+				C.nvcuda_get_free(*cHandles.nvcuda, C.int(gpu.index), &memInfo.free, &memInfo.total)
+				memInfo.used = memInfo.total - memInfo.free
 			} else {
-				C.nvcuda_check_vram(*gpuHandles.nvcuda, C.int(i), &memInfo)
-				driverMajor = int(gpuHandles.nvcuda.driver_major)
-				driverMinor = int(gpuHandles.nvcuda.driver_minor)
+				// shouldn't happen
+				slog.Warn("no valid cuda library loaded to refresh vram usage")
+				break
 			}
 			if memInfo.err != nil {
-				slog.Info("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
+				slog.Warn("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
 				C.free(unsafe.Pointer(memInfo.err))
 				continue
 			}
-			if memInfo.major < CudaComputeMin[0] || (memInfo.major == CudaComputeMin[0] && memInfo.minor < CudaComputeMin[1]) {
-				slog.Info(fmt.Sprintf("[%d] CUDA GPU is too old. Compute Capability detected: %d.%d", i, memInfo.major, memInfo.minor))
+			if memInfo.free == 0 {
+				slog.Warn("error looking up nvidia GPU memory")
 				continue
 			}
-			gpuInfo.TotalMemory = uint64(memInfo.total)
-			gpuInfo.FreeMemory = uint64(memInfo.free)
-			gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
-			gpuInfo.Compute = fmt.Sprintf("%d.%d", memInfo.major, memInfo.minor)
-			gpuInfo.MinimumMemory = cudaMinimumMemory
-			gpuInfo.DependencyPath = depPath
-			gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
-			gpuInfo.DriverMajor = driverMajor
-			gpuInfo.DriverMinor = driverMinor
-
-			// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
-			resp = append(resp, gpuInfo)
+			slog.Debug("updating cuda memory data",
+				"gpu", gpu.ID,
+				"name", gpu.Name,
+				slog.Group(
+					"before",
+					"total", format.HumanBytes2(gpu.TotalMemory),
+					"free", format.HumanBytes2(gpu.FreeMemory),
+				),
+				slog.Group(
+					"now",
+					"total", format.HumanBytes2(uint64(memInfo.total)),
+					"free", format.HumanBytes2(uint64(memInfo.free)),
+					"used", format.HumanBytes2(uint64(memInfo.used)),
+				),
+			)
+			cudaGPUs[i].FreeMemory = uint64(memInfo.free)
 		}
-	}
-
-	// Then AMD
-	resp = append(resp, AMDGetGPUInfo()...)
 
-	if len(resp) == 0 {
-		C.cpu_check_ram(&memInfo)
-		if memInfo.err != nil {
-			slog.Info("error looking up CPU memory", "error", C.GoString(memInfo.err))
-			C.free(unsafe.Pointer(memInfo.err))
-			return resp
+		if oHandles == nil && len(oneapiGPUs) > 0 {
+			oHandles = initOneAPIHandles()
 		}
-		gpuInfo := GpuInfo{
-			Library: "cpu",
-			Variant: cpuVariant,
+		for i, gpu := range oneapiGPUs {
+			if oHandles.oneapi == nil {
+				// shouldn't happen
+				slog.Warn("nil oneapi handle with device count", "count", oHandles.deviceCount)
+				continue
+			}
+			C.oneapi_check_vram(*oHandles.oneapi, C.int(gpu.driverIndex), C.int(gpu.gpuIndex), &memInfo)
+			// TODO - convert this to MinimumMemory based on testing...
+			var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend.
+			memInfo.free = C.uint64_t(totalFreeMem)
+			oneapiGPUs[i].FreeMemory = uint64(memInfo.free)
 		}
-		gpuInfo.TotalMemory = uint64(memInfo.total)
-		gpuInfo.FreeMemory = uint64(memInfo.free)
-		gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
 
-		resp = append(resp, gpuInfo)
+		err = RocmGPUInfoList(rocmGPUs).RefreshFreeMemory()
+		if err != nil {
+			slog.Debug("problem refreshing ROCm free memory", "error", err)
+		}
 	}
 
-	return resp
-}
-
-func GetCPUMem() (memInfo, error) {
-	var ret memInfo
-	var info C.mem_info_t
-	C.cpu_check_ram(&info)
-	if info.err != nil {
-		defer C.free(unsafe.Pointer(info.err))
-		return ret, fmt.Errorf(C.GoString(info.err))
+	resp := []GpuInfo{}
+	for _, gpu := range cudaGPUs {
+		resp = append(resp, gpu.GpuInfo)
+	}
+	for _, gpu := range rocmGPUs {
+		resp = append(resp, gpu.GpuInfo)
+	}
+	for _, gpu := range oneapiGPUs {
+		resp = append(resp, gpu.GpuInfo)
+	}
+	if len(resp) == 0 {
+		resp = append(resp, cpus[0].GpuInfo)
 	}
-	ret.FreeMemory = uint64(info.free)
-	ret.TotalMemory = uint64(info.total)
-	return ret, nil
+	return resp
 }
 
 func FindGPULibs(baseLibName string, defaultPatterns []string) []string {
@@ -362,8 +515,26 @@ func LoadNVCUDAMgmt(nvcudaLibPaths []string) (int, *C.nvcuda_handle_t, string) {
 	return 0, nil, ""
 }
 
+func LoadNVMLMgmt(nvmlLibPaths []string) (*C.nvml_handle_t, string) {
+	var resp C.nvml_init_resp_t
+	resp.ch.verbose = getVerboseState()
+	for _, libPath := range nvmlLibPaths {
+		lib := C.CString(libPath)
+		defer C.free(unsafe.Pointer(lib))
+		C.nvml_init(lib, &resp)
+		if resp.err != nil {
+			slog.Info(fmt.Sprintf("Unable to load NVML management library %s: %s", libPath, C.GoString(resp.err)))
+			C.free(unsafe.Pointer(resp.err))
+		} else {
+			return &resp.ch, libPath
+		}
+	}
+	return nil, ""
+}
+
 func LoadOneapiMgmt(oneapiLibPaths []string) (int, *C.oneapi_handle_t, string) {
 	var resp C.oneapi_init_resp_t
+	num_devices := 0
 	resp.oh.verbose = getVerboseState()
 	for _, libPath := range oneapiLibPaths {
 		lib := C.CString(libPath)
@@ -373,7 +544,10 @@ func LoadOneapiMgmt(oneapiLibPaths []string) (int, *C.oneapi_handle_t, string) {
 			slog.Debug("Unable to load oneAPI management library", "library", libPath, "error", C.GoString(resp.err))
 			C.free(unsafe.Pointer(resp.err))
 		} else {
-			return int(resp.num_devices), &resp.oh, libPath
+			for i := range resp.oh.num_drivers {
+				num_devices += int(C.oneapi_get_device_count(resp.oh, C.int(i)))
+			}
+			return num_devices, &resp.oh, libPath
 		}
 	}
 	return 0, nil, ""

+ 12 - 1
gpu/gpu_darwin.go

@@ -24,7 +24,7 @@ func GetGPUInfo() GpuInfoList {
 		return []GpuInfo{
 			{
 				Library: "cpu",
-				Variant: GetCPUVariant(),
+				Variant: GetCPUCapability(),
 				memInfo: mem,
 			},
 		}
@@ -42,6 +42,17 @@ func GetGPUInfo() GpuInfoList {
 	return []GpuInfo{info}
 }
 
+func GetCPUInfo() GpuInfoList {
+	mem, _ := GetCPUMem()
+	return []GpuInfo{
+		{
+			Library: "cpu",
+			Variant: GetCPUCapability(),
+			memInfo: mem,
+		},
+	}
+}
+
 func GetCPUMem() (memInfo, error) {
 	return memInfo{
 		TotalMemory: uint64(C.getPhysicalMemory()),

+ 2 - 0
gpu/gpu_info.h

@@ -47,6 +47,7 @@ typedef struct mem_info {
   char gpu_name[GPU_NAME_LEN];
   uint64_t total;
   uint64_t free;
+  uint64_t used;
 
   // Compute Capability
   int major; 
@@ -62,6 +63,7 @@ void cpu_check_ram(mem_info_t *resp);
 
 #include "gpu_info_cudart.h"
 #include "gpu_info_nvcuda.h"
+#include "gpu_info_nvml.h"
 #include "gpu_info_oneapi.h"
 
 #endif  // __GPU_INFO_H__

+ 0 - 45
gpu/gpu_info_cpu.c

@@ -1,45 +0,0 @@
-#include "gpu_info.h"
-// Fallbacks for CPU mode
-
-#ifdef _WIN32
-#include <sysinfoapi.h>
-void cpu_check_ram(mem_info_t *resp) {
-  resp->err = NULL;
-  MEMORYSTATUSEX info;
-  info.dwLength = sizeof(info);
-  if (GlobalMemoryStatusEx(&info) != 0) {
-    resp->total = info.ullTotalPhys;
-    resp->free = info.ullAvailPhys;
-    snprintf(&resp->gpu_id[0], GPU_ID_LEN, "0");
-  } else {
-    resp->err = LOAD_ERR();
-  }
-  return;
-}
-
-#elif __linux__
-#include <errno.h>
-#include <string.h>
-#include <sys/sysinfo.h>
-void cpu_check_ram(mem_info_t *resp) {
-  struct sysinfo info;
-  resp->err = NULL;
-  if (sysinfo(&info) != 0) {
-    resp->err = strdup(strerror(errno));
-  } else {
-    resp->total = info.totalram * info.mem_unit;
-    resp->free = info.freeram * info.mem_unit;
-    snprintf(&resp->gpu_id[0], GPU_ID_LEN, "0");
-  }
-  return;
-}
-
-#elif __APPLE__
-// TODO consider an Apple implementation that does something useful
-// mem_info_t cpu_check_ram() {
-//   mem_info_t resp = {0, 0, NULL};
-//   return resp;
-// }
-#else
-#error "Unsupported platform"
-#endif

+ 3 - 1
gpu/gpu_info_cudart.c

@@ -94,7 +94,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
 }
 
 
-void cudart_check_vram(cudart_handle_t h, int i, mem_info_t *resp) {
+void cudart_bootstrap(cudart_handle_t h, int i, mem_info_t *resp) {
   resp->err = NULL;
   cudartMemory_t memInfo = {0,0,0};
   cudartReturn_t ret;
@@ -166,9 +166,11 @@ void cudart_check_vram(cudart_handle_t h, int i, mem_info_t *resp) {
 
   resp->total = memInfo.total;
   resp->free = memInfo.free;
+  resp->used = memInfo.used;
 
   LOG(h.verbose, "[%s] CUDA totalMem %lu\n", resp->gpu_id, resp->total);
   LOG(h.verbose, "[%s] CUDA freeMem %lu\n", resp->gpu_id, resp->free);
+  LOG(h.verbose, "[%s] CUDA usedMem %lu\n", resp->gpu_id, resp->used);
   LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor);
 }
 

+ 2 - 1
gpu/gpu_info_cudart.h

@@ -140,7 +140,8 @@ typedef struct cudart_init_resp {
 } cudart_init_resp_t;
 
 void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp);
-void cudart_check_vram(cudart_handle_t ch, int device_id, mem_info_t *resp);
+void cudart_bootstrap(cudart_handle_t ch, int device_id, mem_info_t *resp);
+// TODO - if we keep this library longer term, add cudart_get_free
 void cudart_release(cudart_handle_t ch);
 
 #endif  // __GPU_INFO_CUDART_H__

+ 38 - 3
gpu/gpu_info_nvcuda.c

@@ -96,7 +96,7 @@ void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) {
 }
 
 const int buflen = 256;
-void nvcuda_check_vram(nvcuda_handle_t h, int i, mem_info_t *resp) {
+void nvcuda_bootstrap(nvcuda_handle_t h, int i, mem_info_t *resp) {
   resp->err = NULL;
   nvcudaMemory_t memInfo = {0,0};
   CUresult ret;
@@ -168,7 +168,7 @@ void nvcuda_check_vram(nvcuda_handle_t h, int i, mem_info_t *resp) {
   // To get memory we have to set (and release) a context
   ret = (*h.cuCtxCreate_v3)(&ctx, NULL, 0, 0, device);
   if (ret != CUDA_SUCCESS) {
-    snprintf(buf, buflen, "nvcuda failed to get primary device context %d", ret);
+    snprintf(buf, buflen, "nvcuda failed to get device context %d", ret);
     resp->err = strdup(buf);
     return;
   }
@@ -193,7 +193,42 @@ void nvcuda_check_vram(nvcuda_handle_t h, int i, mem_info_t *resp) {
 
   ret = (*h.cuCtxDestroy)(ctx);
   if (ret != CUDA_SUCCESS) {
-    LOG(1, "nvcuda failed to release primary device context %d", ret);
+    LOG(1, "nvcuda failed to release device context %d", ret);
+  }
+}
+
+void nvcuda_get_free(nvcuda_handle_t h, int i, uint64_t *free, uint64_t *total) {
+  CUresult ret;
+  CUcontext ctx = NULL;
+  CUdevice device = -1;
+  *free = 0;
+  *total = 0;
+
+  ret = (*h.cuDeviceGet)(&device, i);
+  if (ret != CUDA_SUCCESS) {
+    LOG(1, "nvcuda device failed to initialize");
+    return;
+  }
+
+
+  // To get memory we have to set (and release) a context
+  ret = (*h.cuCtxCreate_v3)(&ctx, NULL, 0, 0, device);
+  if (ret != CUDA_SUCCESS) {
+    LOG(1, "nvcuda failed to get device context %d", ret);
+    return;
+  }
+
+  ret = (*h.cuMemGetInfo_v2)(free, total);
+  if (ret != CUDA_SUCCESS) {
+    LOG(1, "nvcuda device memory info lookup failure %d", ret);
+    // Best effort on failure...
+    (*h.cuCtxDestroy)(ctx);
+    return;
+  }
+
+  ret = (*h.cuCtxDestroy)(ctx);
+  if (ret != CUDA_SUCCESS) {
+    LOG(1, "nvcuda failed to release device context %d", ret);
   }
 }
 

+ 2 - 1
gpu/gpu_info_nvcuda.h

@@ -67,7 +67,8 @@ typedef struct nvcuda_init_resp {
 } nvcuda_init_resp_t;
 
 void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp);
-void nvcuda_check_vram(nvcuda_handle_t ch, int device_id, mem_info_t *resp);
+void nvcuda_bootstrap(nvcuda_handle_t ch, int device_id, mem_info_t *resp);
+void nvcuda_get_free(nvcuda_handle_t ch,  int device_id, uint64_t *free, uint64_t *total);
 void nvcuda_release(nvcuda_handle_t ch);
 
 #endif  // __GPU_INFO_NVCUDA_H__

+ 104 - 0
gpu/gpu_info_nvml.c

@@ -0,0 +1,104 @@
+#ifndef __APPLE__  // TODO - maybe consider nvidia support on intel macs?
+
+#include <string.h>
+
+#include "gpu_info_nvml.h"
+
+void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp) {
+  nvmlReturn_t ret;
+  resp->err = NULL;
+  const int buflen = 256;
+  char buf[buflen + 1];
+  int i;
+
+  struct lookup {
+    char *s;
+    void **p;
+  } l[] = {
+      {"nvmlInit_v2", (void *)&resp->ch.nvmlInit_v2},
+      {"nvmlShutdown", (void *)&resp->ch.nvmlShutdown},
+      {"nvmlDeviceGetHandleByIndex", (void *)&resp->ch.nvmlDeviceGetHandleByIndex},
+      {"nvmlDeviceGetMemoryInfo", (void *)&resp->ch.nvmlDeviceGetMemoryInfo},
+      {NULL, NULL},
+  };
+
+  resp->ch.handle = LOAD_LIBRARY(nvml_lib_path, RTLD_LAZY);
+  if (!resp->ch.handle) {
+    char *msg = LOAD_ERR();
+    LOG(resp->ch.verbose, "library %s load err: %s\n", nvml_lib_path, msg);
+    snprintf(buf, buflen,
+             "Unable to load %s library to query for Nvidia GPUs: %s",
+             nvml_lib_path, msg);
+    free(msg);
+    resp->err = strdup(buf);
+    return;
+  }
+
+  // TODO once we've squashed the remaining corner cases remove this log
+  // LOG(resp->ch.verbose, "wiring nvidia management library functions in %s\n", nvml_lib_path);
+  
+  for (i = 0; l[i].s != NULL; i++) {
+    // TODO once we've squashed the remaining corner cases remove this log
+    // LOG(resp->ch.verbose, "dlsym: %s\n", l[i].s);
+
+    *l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
+    if (!l[i].p) {
+      resp->ch.handle = NULL;
+      char *msg = LOAD_ERR();
+      LOG(resp->ch.verbose, "dlerr: %s\n", msg);
+      UNLOAD_LIBRARY(resp->ch.handle);
+      snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s,
+               msg);
+      free(msg);
+      resp->err = strdup(buf);
+      return;
+    }
+  }
+
+  ret = (*resp->ch.nvmlInit_v2)();
+  if (ret != NVML_SUCCESS) {
+    LOG(resp->ch.verbose, "nvmlInit_v2 err: %d\n", ret);
+    UNLOAD_LIBRARY(resp->ch.handle);
+    resp->ch.handle = NULL;
+    snprintf(buf, buflen, "nvml vram init failure: %d", ret);
+    resp->err = strdup(buf);
+    return;
+  }
+}
+
+
+void nvml_get_free(nvml_handle_t h, int device_id, uint64_t *free, uint64_t *total, uint64_t *used) {
+    nvmlDevice_t device;
+    nvmlMemory_t memInfo = {0};
+    nvmlReturn_t ret;
+    ret = (*h.nvmlDeviceGetHandleByIndex)(device_id, &device);
+    if (ret != NVML_SUCCESS) {
+        LOG(1, "unable to get device handle %d: %d", device_id, ret);
+        *free = 0;
+        return;
+    }
+
+    ret = (*h.nvmlDeviceGetMemoryInfo)(device, &memInfo);
+    if (ret != NVML_SUCCESS) {
+        LOG(1, "device memory info lookup failure %d: %d", device_id, ret);
+        *free = 0;
+        return;
+    }
+    *free = memInfo.free;
+    *total = memInfo.total;
+    *used = memInfo.used;
+}
+
+
+void nvml_release(nvml_handle_t h) {
+  LOG(h.verbose, "releasing nvml library\n");
+  nvmlReturn_t ret;
+  ret = (*h.nvmlShutdown)();
+  if (ret != NVML_SUCCESS) {
+    LOG(1, "error during nvmlShutdown %d", ret);
+  }
+  UNLOAD_LIBRARY(h.handle);
+  h.handle = NULL;
+}
+
+#endif  // __APPLE__

+ 48 - 0
gpu/gpu_info_nvml.h

@@ -0,0 +1,48 @@
+#ifndef __APPLE__
+#ifndef __GPU_INFO_NVML_H__
+#define __GPU_INFO_NVML_H__
+#include "gpu_info.h"
+
+// Just enough typedef's to dlopen/dlsym for memory information
+typedef enum nvmlReturn_enum {
+  NVML_SUCCESS = 0,
+  // Other values omitted for now...
+} nvmlReturn_t;
+typedef void *nvmlDevice_t;  // Opaque is sufficient
+typedef struct nvmlMemory_st {
+  unsigned long long total;
+  unsigned long long free;
+  unsigned long long used;
+} nvmlMemory_t;
+
+typedef enum nvmlBrandType_enum
+{
+    NVML_BRAND_UNKNOWN          = 0,
+} nvmlBrandType_t;
+
+typedef struct nvml_handle {
+  void *handle;
+  uint16_t verbose;
+  nvmlReturn_t (*nvmlInit_v2)(void);
+  nvmlReturn_t (*nvmlShutdown)(void);
+  nvmlReturn_t (*nvmlDeviceGetHandleByIndex)(unsigned int, nvmlDevice_t *);
+  nvmlReturn_t (*nvmlDeviceGetMemoryInfo)(nvmlDevice_t, nvmlMemory_t *);
+} nvml_handle_t;
+
+typedef struct nvml_init_resp {
+  char *err;  // If err is non-null handle is invalid
+  nvml_handle_t ch;
+} nvml_init_resp_t;
+
+typedef struct nvml_compute_capability {
+  char *err;
+  int major;
+  int minor;
+} nvml_compute_capability_t;
+
+void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp);
+void nvml_get_free(nvml_handle_t ch,  int device_id, uint64_t *free, uint64_t *total, uint64_t *used);
+void nvml_release(nvml_handle_t ch);
+
+#endif  // __GPU_INFO_NVML_H__
+#endif  // __APPLE__

+ 166 - 123
gpu/gpu_info_oneapi.c

@@ -4,15 +4,17 @@
 
 #include <string.h>
 
-void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp)
-{
+void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp) {
   ze_result_t ret;
   resp->err = NULL;
+  resp->oh.devices = NULL;
+  resp->oh.num_devices = NULL;
+  resp->oh.drivers = NULL;
+  resp->oh.num_drivers = 0;
   const int buflen = 256;
   char buf[buflen + 1];
-  int i;
-  struct lookup
-  {
+  int i, d, count;
+  struct lookup {
     char *s;
     void **p;
   } l[] = {
@@ -28,8 +30,7 @@ void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp)
   };
 
   resp->oh.handle = LOAD_LIBRARY(oneapi_lib_path, RTLD_LAZY);
-  if (!resp->oh.handle)
-  {
+  if (!resp->oh.handle) {
     char *msg = LOAD_ERR();
     snprintf(buf, buflen,
              "Unable to load %s library to query for Intel GPUs: %s\n",
@@ -44,14 +45,12 @@ void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp)
       "wiring Level-Zero management library functions in %s\n",
       oneapi_lib_path);
 
-  for (i = 0; l[i].s != NULL; i++)
-  {
+  for (i = 0; l[i].s != NULL; i++) {
     // TODO once we've squashed the remaining corner cases remove this log
     LOG(resp->oh.verbose, "dlsym: %s\n", l[i].s);
 
     *l[i].p = LOAD_SYMBOL(resp->oh.handle, l[i].s);
-    if (!l[i].p)
-    {
+    if (!l[i].p) {
       resp->oh.handle = NULL;
       char *msg = LOAD_ERR();
       LOG(resp->oh.verbose, "dlerr: %s\n", msg);
@@ -64,22 +63,67 @@ void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp)
   }
 
   ret = (*resp->oh.zesInit)(0);
-  if (ret != ZE_RESULT_SUCCESS)
-  {
-    LOG(resp->oh.verbose, "zesInit err: %d\n", ret);
-    UNLOAD_LIBRARY(resp->oh.handle);
-    resp->oh.handle = NULL;
-    snprintf(buf, buflen, "oneapi vram init failure: %d", ret);
+  if (ret != ZE_RESULT_SUCCESS) {
+    LOG(resp->oh.verbose, "zesInit err: %x\n", ret);
+    snprintf(buf, buflen, "oneapi vram init failure: %x", ret);
     resp->err = strdup(buf);
+    oneapi_release(resp->oh);
+    return;
   }
 
-  (*resp->oh.zesDriverGet)(&resp->num_devices, NULL);
+  count = 0;
+  ret = (*resp->oh.zesDriverGet)(&resp->oh.num_drivers, NULL);
+  if (ret != ZE_RESULT_SUCCESS) {
+    LOG(resp->oh.verbose, "zesDriverGet err: %x\n", ret);
+    snprintf(buf, buflen, "unable to get driver count: %x", ret);
+    resp->err = strdup(buf);
+    oneapi_release(resp->oh);
+    return;
+  }
+  LOG(resp->oh.verbose, "oneapi driver count: %d\n", resp->oh.num_drivers);
+  resp->oh.drivers = malloc(resp->oh.num_drivers * sizeof(zes_driver_handle_t));
+  resp->oh.num_devices = malloc(resp->oh.num_drivers * sizeof(uint32_t));
+  memset(&resp->oh.num_devices[0], 0, resp->oh.num_drivers * sizeof(uint32_t));
+  resp->oh.devices =
+      malloc(resp->oh.num_drivers * sizeof(zes_device_handle_t *));
+  ret = (*resp->oh.zesDriverGet)(&resp->oh.num_drivers, &resp->oh.drivers[0]);
+  if (ret != ZE_RESULT_SUCCESS) {
+    LOG(resp->oh.verbose, "zesDriverGet err: %x\n", ret);
+    snprintf(buf, buflen, "unable to get driver count: %x", ret);
+    resp->err = strdup(buf);
+    oneapi_release(resp->oh);
+    return;
+  }
+
+  for (d = 0; d < resp->oh.num_drivers; d++) {
+    ret = (*resp->oh.zesDeviceGet)(resp->oh.drivers[d],
+                                   &resp->oh.num_devices[d], NULL);
+    if (ret != ZE_RESULT_SUCCESS) {
+      LOG(resp->oh.verbose, "zesDeviceGet err: %x\n", ret);
+      snprintf(buf, buflen, "unable to get device count: %x", ret);
+      resp->err = strdup(buf);
+      oneapi_release(resp->oh);
+      return;
+    }
+    resp->oh.devices[d] =
+        malloc(resp->oh.num_devices[d] * sizeof(zes_device_handle_t));
+    ret = (*resp->oh.zesDeviceGet)(
+        resp->oh.drivers[d], &resp->oh.num_devices[d], resp->oh.devices[d]);
+    if (ret != ZE_RESULT_SUCCESS) {
+      LOG(resp->oh.verbose, "zesDeviceGet err: %x\n", ret);
+      snprintf(buf, buflen, "unable to get device count: %x", ret);
+      resp->err = strdup(buf);
+      oneapi_release(resp->oh);
+      return;
+    }
+    count += resp->oh.num_devices[d];
+  }
 
   return;
 }
 
-void oneapi_check_vram(oneapi_handle_t h, mem_info_t *resp)
-{
+void oneapi_check_vram(oneapi_handle_t h, int driver, int device,
+                       mem_info_t *resp) {
   ze_result_t ret;
   resp->err = NULL;
   uint64_t totalMem = 0;
@@ -88,127 +132,126 @@ void oneapi_check_vram(oneapi_handle_t h, mem_info_t *resp)
   char buf[buflen + 1];
   int i, d, m;
 
-  if (h.handle == NULL)
-  {
+  if (h.handle == NULL) {
     resp->err = strdup("Level-Zero handle not initialized");
     return;
   }
 
-  uint32_t driversCount = 0;
-  ret = (*h.zesDriverGet)(&driversCount, NULL);
-  if (ret != ZE_RESULT_SUCCESS)
-  {
-    snprintf(buf, buflen, "unable to get driver count: %d", ret);
-    resp->err = strdup(buf);
+  if (driver > h.num_drivers || device > h.num_devices[driver]) {
+    resp->err = strdup("driver of device index out of bounds");
     return;
   }
-  LOG(h.verbose, "discovered %d Level-Zero drivers\n", driversCount);
-
-  zes_driver_handle_t *allDrivers =
-      malloc(driversCount * sizeof(zes_driver_handle_t));
-  (*h.zesDriverGet)(&driversCount, allDrivers);
 
   resp->total = 0;
   resp->free = 0;
 
-  for (d = 0; d < driversCount; d++)
-  {
-    uint32_t deviceCount = 0;
-    ret = (*h.zesDeviceGet)(allDrivers[d], &deviceCount, NULL);
-    if (ret != ZE_RESULT_SUCCESS)
-    {
-      snprintf(buf, buflen, "unable to get device count: %d", ret);
+  zes_device_ext_properties_t ext_props;
+  ext_props.stype = ZES_STRUCTURE_TYPE_DEVICE_EXT_PROPERTIES;
+  ext_props.pNext = NULL;
+
+  zes_device_properties_t props;
+  props.stype = ZES_STRUCTURE_TYPE_DEVICE_PROPERTIES;
+  props.pNext = &ext_props;
+
+  ret = (*h.zesDeviceGetProperties)(h.devices[driver][device], &props);
+  if (ret != ZE_RESULT_SUCCESS) {
+    snprintf(buf, buflen, "unable to get device properties: %d", ret);
+    resp->err = strdup(buf);
+    return;
+  }
+
+  snprintf(&resp->gpu_name[0], GPU_NAME_LEN, props.modelName);
+
+  // TODO this needs to map to ONEAPI_DEVICE_SELECTOR syntax
+  // (this is probably wrong...)
+  // TODO - the driver isn't included - what if there are multiple drivers?
+  snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", device);
+
+  if (h.verbose) {
+    // When in verbose mode, report more information about
+    // the card we discover.
+    LOG(h.verbose, "[%d:%d] oneAPI device name: %s\n", driver, device,
+        props.modelName);
+    LOG(h.verbose, "[%d:%d] oneAPI brand: %s\n", driver, device,
+        props.brandName);
+    LOG(h.verbose, "[%d:%d] oneAPI vendor: %s\n", driver, device,
+        props.vendorName);
+    LOG(h.verbose, "[%d:%d] oneAPI S/N: %s\n", driver, device,
+        props.serialNumber);
+    LOG(h.verbose, "[%d:%d] oneAPI board number: %s\n", driver, device,
+        props.boardNumber);
+  }
+
+  // TODO
+  // Compute Capability equivalent in resp->major, resp->minor, resp->patch
+
+  uint32_t memCount = 0;
+  ret = (*h.zesDeviceEnumMemoryModules)(h.devices[driver][device], &memCount,
+                                        NULL);
+  if (ret != ZE_RESULT_SUCCESS) {
+    snprintf(buf, buflen, "unable to enumerate Level-Zero memory modules: %x",
+             ret);
+    resp->err = strdup(buf);
+    return;
+  }
+
+  LOG(h.verbose, "discovered %d Level-Zero memory modules\n", memCount);
+
+  zes_mem_handle_t *mems = malloc(memCount * sizeof(zes_mem_handle_t));
+  (*h.zesDeviceEnumMemoryModules)(h.devices[driver][device], &memCount, mems);
+
+  for (m = 0; m < memCount; m++) {
+    zes_mem_state_t state;
+    state.stype = ZES_STRUCTURE_TYPE_MEM_STATE;
+    state.pNext = NULL;
+    ret = (*h.zesMemoryGetState)(mems[m], &state);
+    if (ret != ZE_RESULT_SUCCESS) {
+      snprintf(buf, buflen, "unable to get memory state: %x", ret);
       resp->err = strdup(buf);
-      free(allDrivers);
+      free(mems);
       return;
     }
 
-    LOG(h.verbose, "discovered %d Level-Zero devices\n", deviceCount);
-
-    zes_device_handle_t *devices =
-        malloc(deviceCount * sizeof(zes_device_handle_t));
-    (*h.zesDeviceGet)(allDrivers[d], &deviceCount, devices);
-
-    for (i = 0; i < deviceCount; i++)
-    {
-      zes_device_ext_properties_t ext_props;
-      ext_props.stype = ZES_STRUCTURE_TYPE_DEVICE_EXT_PROPERTIES;
-      ext_props.pNext = NULL;
-
-      zes_device_properties_t props;
-      props.stype = ZES_STRUCTURE_TYPE_DEVICE_PROPERTIES;
-      props.pNext = &ext_props;
-
-      ret = (*h.zesDeviceGetProperties)(devices[i], &props);
-      if (ret != ZE_RESULT_SUCCESS)
-      {
-        snprintf(buf, buflen, "unable to get device properties: %d", ret);
-        resp->err = strdup(buf);
-        free(allDrivers);
-        free(devices);
-        return;
-      }
-
-      if (h.verbose)
-      {
-        // When in verbose mode, report more information about
-        // the card we discover.
-        LOG(h.verbose, "[%d] oneAPI device name: %s\n", i,
-            props.modelName);
-        LOG(h.verbose, "[%d] oneAPI brand: %s\n", i,
-            props.brandName);
-        LOG(h.verbose, "[%d] oneAPI vendor: %s\n", i,
-            props.vendorName);
-        LOG(h.verbose, "[%d] oneAPI S/N: %s\n", i,
-            props.serialNumber);
-        LOG(h.verbose, "[%d] oneAPI board number: %s\n", i,
-            props.boardNumber);
-      }
-
-      uint32_t memCount = 0;
-      ret = (*h.zesDeviceEnumMemoryModules)(devices[i], &memCount, NULL);
-      if (ret != ZE_RESULT_SUCCESS)
-      {
-        snprintf(buf, buflen,
-                 "unable to enumerate Level-Zero memory modules: %d", ret);
-        resp->err = strdup(buf);
-        free(allDrivers);
-        free(devices);
-        return;
-      }
-
-      LOG(h.verbose, "discovered %d Level-Zero memory modules\n", memCount);
-
-      zes_mem_handle_t *mems = malloc(memCount * sizeof(zes_mem_handle_t));
-      (*h.zesDeviceEnumMemoryModules)(devices[i], &memCount, mems);
-
-      for (m = 0; m < memCount; m++)
-      {
-        zes_mem_state_t state;
-        state.stype = ZES_STRUCTURE_TYPE_MEM_STATE;
-        state.pNext = NULL;
-        ret = (*h.zesMemoryGetState)(mems[m], &state);
-        if (ret != ZE_RESULT_SUCCESS)
-        {
-          snprintf(buf, buflen, "unable to get memory state: %d", ret);
-          resp->err = strdup(buf);
-          free(allDrivers);
-          free(devices);
-          free(mems);
-          return;
-        }
-
-        resp->total += state.size;
-        resp->free += state.free;
-      }
+    resp->total += state.size;
+    resp->free += state.free;
+  }
 
-      free(mems);
-    }
+  free(mems);
+}
 
-    free(devices);
+void oneapi_release(oneapi_handle_t h) {
+  int d;
+  LOG(h.verbose, "releasing oneapi library\n");
+  for (d = 0; d < h.num_drivers; d++) {
+    if (h.devices != NULL && h.devices[d] != NULL) {
+      free(h.devices[d]);
+    }
+  }
+  if (h.devices != NULL) {
+    free(h.devices);
+    h.devices = NULL;
   }
+  if (h.num_devices != NULL) {
+    free(h.num_devices);
+    h.num_devices = NULL;
+  }
+  if (h.drivers != NULL) {
+    free(h.drivers);
+    h.drivers = NULL;
+  }
+  h.num_drivers = 0;
+  UNLOAD_LIBRARY(h.handle);
+  h.handle = NULL;
+}
 
-  free(allDrivers);
+int oneapi_get_device_count(oneapi_handle_t h, int driver) {
+  if (h.handle == NULL || h.num_devices == NULL) {
+    return 0;
+  }
+  if (driver > h.num_drivers) {
+    return 0;
+  }
+  return (int)h.num_devices[driver];
 }
 
 #endif // __APPLE__

+ 34 - 42
gpu/gpu_info_oneapi.h

@@ -9,8 +9,7 @@
 #define ZE_BIT(_i) (1 << _i)
 
 // Just enough typedef's to dlopen/dlsym for memory information
-typedef enum ze_result_t
-{
+typedef enum ze_result_t {
   ZE_RESULT_SUCCESS = 0,
   // Other values omitted for now...
 } ze_result_t;
@@ -20,13 +19,11 @@ typedef struct _zes_driver_handle_t *zes_driver_handle_t;
 typedef struct _zes_device_handle_t *zes_device_handle_t;
 typedef struct _zes_mem_handle_t *zes_mem_handle_t;
 
-typedef enum _ze_structure_type_t
-{
+typedef enum _ze_structure_type_t {
   ZE_STRUCTURE_TYPE_FORCE_UINT32 = 0x7fffffff
 } ze_structure_type_t;
 
-typedef enum _zes_structure_type_t
-{
+typedef enum _zes_structure_type_t {
   ZES_STRUCTURE_TYPE_DEVICE_PROPERTIES = 0x1,
   ZES_STRUCTURE_TYPE_MEM_PROPERTIES = 0xb,
   ZES_STRUCTURE_TYPE_MEM_STATE = 0x1e,
@@ -34,35 +31,29 @@ typedef enum _zes_structure_type_t
   ZES_STRUCTURE_TYPE_FORCE_UINT32 = 0x7fffffff
 } zes_structure_type_t;
 
-typedef enum _zes_mem_type_t
-{
+typedef enum _zes_mem_type_t {
   ZES_MEM_TYPE_FORCE_UINT32 = 0x7fffffff
 } zes_mem_type_t;
 
-typedef enum _zes_mem_loc_t
-{
+typedef enum _zes_mem_loc_t {
   ZES_MEM_LOC_SYSTEM = 0,
   ZES_MEM_LOC_DEVICE = 1,
   ZES_MEM_LOC_FORCE_UINT32 = 0x7fffffff
 } zes_mem_loc_t;
 
-typedef enum _zes_mem_health_t
-{
+typedef enum _zes_mem_health_t {
   ZES_MEM_HEALTH_FORCE_UINT32 = 0x7fffffff
 } zes_mem_health_t;
 
-typedef struct _ze_device_uuid_t
-{
+typedef struct _ze_device_uuid_t {
   uint8_t id[ZE_MAX_DEVICE_UUID_SIZE];
 } ze_device_uuid_t;
 
-typedef struct _zes_uuid_t
-{
+typedef struct _zes_uuid_t {
   uint8_t id[ZE_MAX_DEVICE_UUID_SIZE];
 } zes_uuid_t;
 
-typedef enum _ze_device_type_t
-{
+typedef enum _ze_device_type_t {
   ZE_DEVICE_TYPE_GPU = 1,
   ZE_DEVICE_TYPE_CPU = 2,
   ZE_DEVICE_TYPE_FPGA = 3,
@@ -71,8 +62,7 @@ typedef enum _ze_device_type_t
   ZE_DEVICE_TYPE_FORCE_UINT32 = 0x7fffffff
 } ze_device_type_t;
 
-typedef enum _zes_device_type_t
-{
+typedef enum _zes_device_type_t {
   ZES_DEVICE_TYPE_GPU = 1,
   ZES_DEVICE_TYPE_CPU = 2,
   ZES_DEVICE_TYPE_FPGA = 3,
@@ -82,8 +72,7 @@ typedef enum _zes_device_type_t
 } zes_device_type_t;
 
 typedef uint32_t ze_device_property_flags_t;
-typedef enum _ze_device_property_flag_t
-{
+typedef enum _ze_device_property_flag_t {
   ZE_DEVICE_PROPERTY_FLAG_INTEGRATED = ZE_BIT(0),
   ZE_DEVICE_PROPERTY_FLAG_SUBDEVICE = ZE_BIT(1),
   ZE_DEVICE_PROPERTY_FLAG_ECC = ZE_BIT(2),
@@ -92,8 +81,7 @@ typedef enum _ze_device_property_flag_t
 } ze_device_property_flag_t;
 
 typedef uint32_t zes_device_property_flags_t;
-typedef enum _zes_device_property_flag_t
-{
+typedef enum _zes_device_property_flag_t {
   ZES_DEVICE_PROPERTY_FLAG_INTEGRATED = ZE_BIT(0),
   ZES_DEVICE_PROPERTY_FLAG_SUBDEVICE = ZE_BIT(1),
   ZES_DEVICE_PROPERTY_FLAG_ECC = ZE_BIT(2),
@@ -101,8 +89,7 @@ typedef enum _zes_device_property_flag_t
   ZES_DEVICE_PROPERTY_FLAG_FORCE_UINT32 = 0x7fffffff
 } zes_device_property_flag_t;
 
-typedef struct _ze_device_properties_t
-{
+typedef struct _ze_device_properties_t {
   ze_structure_type_t stype;
   void *pNext;
   ze_device_type_t type;
@@ -126,8 +113,7 @@ typedef struct _ze_device_properties_t
   char name[ZE_MAX_DEVICE_NAME];
 } ze_device_properties_t;
 
-typedef struct _zes_device_properties_t
-{
+typedef struct _zes_device_properties_t {
   zes_structure_type_t stype;
   void *pNext;
   ze_device_properties_t core;
@@ -140,8 +126,7 @@ typedef struct _zes_device_properties_t
   char driverVersion[ZES_STRING_PROPERTY_SIZE];
 } zes_device_properties_t;
 
-typedef struct _zes_device_ext_properties_t
-{
+typedef struct _zes_device_ext_properties_t {
   zes_structure_type_t stype;
   void *pNext;
   zes_uuid_t uuid;
@@ -149,8 +134,7 @@ typedef struct _zes_device_ext_properties_t
   zes_device_property_flags_t flags;
 } zes_device_ext_properties_t;
 
-typedef struct _zes_mem_properties_t
-{
+typedef struct _zes_mem_properties_t {
   zes_structure_type_t stype;
   void *pNext;
   zes_mem_type_t type;
@@ -162,8 +146,7 @@ typedef struct _zes_mem_properties_t
   int32_t numChannels;
 } zes_mem_properties_t;
 
-typedef struct _zes_mem_state_t
-{
+typedef struct _zes_mem_state_t {
   zes_structure_type_t stype;
   const void *pNext;
   zes_mem_health_t health;
@@ -171,10 +154,19 @@ typedef struct _zes_mem_state_t
   uint64_t size;
 } zes_mem_state_t;
 
-typedef struct oneapi_handle
-{
+typedef struct oneapi_handle {
   void *handle;
   uint16_t verbose;
+
+  uint32_t num_drivers;
+  zes_driver_handle_t *drivers;
+  uint32_t *num_devices;
+  zes_device_handle_t **devices;
+
+  // TODO Driver major, minor information
+  // int driver_major;
+  // int driver_minor;
+
   ze_result_t (*zesInit)(int);
   ze_result_t (*zesDriverGet)(uint32_t *pCount, zes_driver_handle_t *phDrivers);
   ze_result_t (*zesDeviceGet)(zes_driver_handle_t hDriver, uint32_t *pCount,
@@ -191,21 +183,21 @@ typedef struct oneapi_handle
 
 } oneapi_handle_t;
 
-typedef struct oneapi_init_resp
-{
+typedef struct oneapi_init_resp {
   char *err; // If err is non-null handle is invalid
-  int num_devices;
   oneapi_handle_t oh;
 } oneapi_init_resp_t;
 
-typedef struct oneapi_version_resp
-{
+typedef struct oneapi_version_resp {
   ze_result_t status;
   char *str; // Contains version or error string if status != 0
 } oneapi_version_resp_t;
 
 void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp);
-void oneapi_check_vram(oneapi_handle_t rh, mem_info_t *resp);
+void oneapi_check_vram(oneapi_handle_t h, int driver, int device,
+                       mem_info_t *resp);
+void oneapi_release(oneapi_handle_t h);
+int oneapi_get_device_count(oneapi_handle_t h, int driver);
 
 #endif // __GPU_INFO_INTEL_H__
 #endif // __APPLE__

+ 89 - 0
gpu/gpu_linux.go

@@ -0,0 +1,89 @@
+package gpu
+
+import (
+	"bufio"
+	"fmt"
+	"os"
+	"strings"
+
+	"github.com/ollama/ollama/format"
+)
+
+var CudartGlobs = []string{
+	"/usr/local/cuda/lib64/libcudart.so*",
+	"/usr/lib/x86_64-linux-gnu/nvidia/current/libcudart.so*",
+	"/usr/lib/x86_64-linux-gnu/libcudart.so*",
+	"/usr/lib/wsl/lib/libcudart.so*",
+	"/usr/lib/wsl/drivers/*/libcudart.so*",
+	"/opt/cuda/lib64/libcudart.so*",
+	"/usr/local/cuda*/targets/aarch64-linux/lib/libcudart.so*",
+	"/usr/lib/aarch64-linux-gnu/nvidia/current/libcudart.so*",
+	"/usr/lib/aarch64-linux-gnu/libcudart.so*",
+	"/usr/local/cuda/lib*/libcudart.so*",
+	"/usr/lib*/libcudart.so*",
+	"/usr/local/lib*/libcudart.so*",
+}
+
+var NvmlGlobs = []string{}
+
+var NvcudaGlobs = []string{
+	"/usr/local/cuda*/targets/*/lib/libcuda.so*",
+	"/usr/lib/*-linux-gnu/nvidia/current/libcuda.so*",
+	"/usr/lib/*-linux-gnu/libcuda.so*",
+	"/usr/lib/wsl/lib/libcuda.so*",
+	"/usr/lib/wsl/drivers/*/libcuda.so*",
+	"/opt/cuda/lib*/libcuda.so*",
+	"/usr/local/cuda/lib*/libcuda.so*",
+	"/usr/lib*/libcuda.so*",
+	"/usr/local/lib*/libcuda.so*",
+}
+
+var OneapiGlobs = []string{
+	"/usr/lib/x86_64-linux-gnu/libze_intel_gpu.so*",
+	"/usr/lib*/libze_intel_gpu.so*",
+}
+
+var CudartMgmtName = "libcudart.so*"
+var NvcudaMgmtName = "libcuda.so*"
+var NvmlMgmtName = "" // not currently wired on linux
+var OneapiMgmtName = "libze_intel_gpu.so"
+
+func GetCPUMem() (memInfo, error) {
+	var mem memInfo
+	var total, available, free, buffers, cached uint64
+	f, err := os.Open("/proc/meminfo")
+	if err != nil {
+		return mem, err
+	}
+	defer f.Close()
+	s := bufio.NewScanner(f)
+	for s.Scan() {
+		line := s.Text()
+		switch {
+		case strings.HasPrefix(line, "MemTotal:"):
+			_, err = fmt.Sscanf(line, "MemTotal:%d", &total)
+		case strings.HasPrefix(line, "MemAvailable:"):
+			_, err = fmt.Sscanf(line, "MemAvailable:%d", &available)
+		case strings.HasPrefix(line, "MemFree:"):
+			_, err = fmt.Sscanf(line, "MemFree:%d", &free)
+		case strings.HasPrefix(line, "Buffers:"):
+			_, err = fmt.Sscanf(line, "Buffers:%d", &buffers)
+		case strings.HasPrefix(line, "Cached:"):
+			_, err = fmt.Sscanf(line, "Cached:%d", &cached)
+		default:
+			continue
+		}
+		if err != nil {
+			return mem, err
+		}
+
+		if total > 0 && available > 0 {
+			mem.TotalMemory = total * format.KibiByte
+			mem.FreeMemory = available * format.KibiByte
+			return mem, nil
+		}
+	}
+	mem.TotalMemory = total * format.KibiByte
+	mem.FreeMemory = (free + buffers + cached) * format.KibiByte
+	return mem, nil
+}

+ 55 - 0
gpu/gpu_windows.go

@@ -0,0 +1,55 @@
+package gpu
+
+import (
+	"fmt"
+	"syscall"
+	"unsafe"
+)
+
+type MEMORYSTATUSEX struct {
+	length               uint32
+	MemoryLoad           uint32
+	TotalPhys            uint64
+	AvailPhys            uint64
+	TotalPageFile        uint64
+	AvailPageFile        uint64
+	TotalVirtual         uint64
+	AvailVirtual         uint64
+	AvailExtendedVirtual uint64
+}
+
+var (
+	k32                      = syscall.NewLazyDLL("kernel32.dll")
+	globalMemoryStatusExProc = k32.NewProc("GlobalMemoryStatusEx")
+	sizeofMemoryStatusEx     = uint32(unsafe.Sizeof(MEMORYSTATUSEX{}))
+)
+
+var CudartGlobs = []string{
+	"c:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v*\\bin\\cudart64_*.dll",
+}
+
+var NvmlGlobs = []string{
+	"c:\\Windows\\System32\\nvml.dll",
+}
+
+var NvcudaGlobs = []string{
+	"c:\\windows\\system*\\nvcuda.dll",
+}
+
+var OneapiGlobs = []string{
+	"c:\\Windows\\System32\\DriverStore\\FileRepository\\*\\ze_intel_gpu64.dll",
+}
+
+var CudartMgmtName = "cudart64_*.dll"
+var NvcudaMgmtName = "nvcuda.dll"
+var NvmlMgmtName = "nvml.dll"
+var OneapiMgmtName = "ze_intel_gpu64.dll"
+
+func GetCPUMem() (memInfo, error) {
+	memStatus := MEMORYSTATUSEX{length: sizeofMemoryStatusEx}
+	r1, _, err := globalMemoryStatusExProc.Call(uintptr(unsafe.Pointer(&memStatus)))
+	if r1 == 0 {
+		return memInfo{}, fmt.Errorf("GlobalMemoryStatusEx failed: %w", err)
+	}
+	return memInfo{TotalMemory: memStatus.TotalPhys, FreeMemory: memStatus.AvailPhys}, nil
+}

+ 50 - 3
gpu/types.go

@@ -18,7 +18,7 @@ type GpuInfo struct {
 	Library string `json:"library,omitempty"`
 
 	// Optional variant to select (e.g. versions, cpu feature flags)
-	Variant string `json:"variant,omitempty"`
+	Variant CPUCapability `json:"variant"`
 
 	// MinimumMemory represents the minimum memory required to use the GPU
 	MinimumMemory uint64 `json:"-"`
@@ -38,6 +38,30 @@ type GpuInfo struct {
 	// TODO other performance capability info to help in scheduling decisions
 }
 
+type CPUInfo struct {
+	GpuInfo
+}
+
+type CudaGPUInfo struct {
+	GpuInfo
+	index int //nolint:unused,nolintlint
+}
+type CudaGPUInfoList []CudaGPUInfo
+
+type RocmGPUInfo struct {
+	GpuInfo
+	usedFilepath string //nolint:unused,nolintlint
+	index        int    //nolint:unused,nolintlint
+}
+type RocmGPUInfoList []RocmGPUInfo
+
+type OneapiGPUInfo struct {
+	GpuInfo
+	driverIndex int //nolint:unused,nolintlint
+	gpuIndex    int //nolint:unused,nolintlint
+}
+type OneapiGPUInfoList []OneapiGPUInfo
+
 type GpuInfoList []GpuInfo
 
 // Split up the set of gpu info's by Library and variant
@@ -47,8 +71,8 @@ func (l GpuInfoList) ByLibrary() []GpuInfoList {
 	for _, info := range l {
 		found := false
 		requested := info.Library
-		if info.Variant != "" {
-			requested += "_" + info.Variant
+		if info.Variant != CPUCapabilityNone {
+			requested += "_" + info.Variant.String()
 		}
 		for i, lib := range libs {
 			if lib == requested {
@@ -86,3 +110,26 @@ type ByFreeMemory []GpuInfo
 func (a ByFreeMemory) Len() int           { return len(a) }
 func (a ByFreeMemory) Swap(i, j int)      { a[i], a[j] = a[j], a[i] }
 func (a ByFreeMemory) Less(i, j int) bool { return a[i].FreeMemory < a[j].FreeMemory }
+
+type CPUCapability uint32
+
+// Override at build time when building base GPU runners
+var GPURunnerCPUCapability = CPUCapabilityAVX
+
+const (
+	CPUCapabilityNone CPUCapability = iota
+	CPUCapabilityAVX
+	CPUCapabilityAVX2
+	// TODO AVX512
+)
+
+func (c CPUCapability) String() string {
+	switch c {
+	case CPUCapabilityAVX:
+		return "avx"
+	case CPUCapabilityAVX2:
+		return "avx2"
+	default:
+		return "no vector extensions"
+	}
+}

+ 59 - 17
integration/concurrency_test.go

@@ -19,17 +19,19 @@ func TestMultiModelConcurrency(t *testing.T) {
 	var (
 		req = [2]api.GenerateRequest{
 			{
-				Model:  "orca-mini",
-				Prompt: "why is the ocean blue?",
-				Stream: &stream,
+				Model:     "orca-mini",
+				Prompt:    "why is the ocean blue?",
+				Stream:    &stream,
+				KeepAlive: &api.Duration{Duration: 10 * time.Second},
 				Options: map[string]interface{}{
 					"seed":        42,
 					"temperature": 0.0,
 				},
 			}, {
-				Model:  "tinydolphin",
-				Prompt: "what is the origin of the us thanksgiving holiday?",
-				Stream: &stream,
+				Model:     "tinydolphin",
+				Prompt:    "what is the origin of the us thanksgiving holiday?",
+				Stream:    &stream,
+				KeepAlive: &api.Duration{Duration: 10 * time.Second},
 				Options: map[string]interface{}{
 					"seed":        42,
 					"temperature": 0.0,
@@ -38,42 +40,64 @@ func TestMultiModelConcurrency(t *testing.T) {
 		}
 		resp = [2][]string{
 			[]string{"sunlight"},
-			[]string{"england", "english", "massachusetts", "pilgrims"},
+			[]string{"england", "english", "massachusetts", "pilgrims", "british"},
 		}
 	)
 	var wg sync.WaitGroup
 	wg.Add(len(req))
-	ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*240)
 	defer cancel()
+
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
+
+	for i := 0; i < len(req); i++ {
+		require.NoError(t, PullIfMissing(ctx, client, req[i].Model))
+	}
+
 	for i := 0; i < len(req); i++ {
 		go func(i int) {
 			defer wg.Done()
-			GenerateTestHelper(ctx, t, req[i], resp[i])
+			DoGenerate(ctx, t, client, req[i], resp[i], 60*time.Second, 10*time.Second)
 		}(i)
 	}
 	wg.Wait()
 }
 
 func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
-	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) // GTX 750 2G card takes ~9 minutes
+	req, resp := GenerateRequests()
+	reqLimit := len(req)
+	iterLimit := 5
+
+	vram := os.Getenv("OLLAMA_MAX_VRAM")
+	if vram != "" {
+		max, err := strconv.ParseUint(vram, 10, 64)
+		require.NoError(t, err)
+		// Don't hammer on small VRAM cards...
+		if max < 4*1024*1024*1024 {
+			reqLimit = min(reqLimit, 2)
+			iterLimit = 2
+		}
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), 9*time.Minute)
 	defer cancel()
 	client, _, cleanup := InitServerConnection(ctx, t)
 	defer cleanup()
 
-	req, resp := GenerateRequests()
 	// Get the server running (if applicable) warm the model up with a single initial request
-	DoGenerate(ctx, t, client, req[0], resp[0], 60*time.Second, 5*time.Second)
+	DoGenerate(ctx, t, client, req[0], resp[0], 60*time.Second, 10*time.Second)
 
 	var wg sync.WaitGroup
-	wg.Add(len(req))
-	for i := 0; i < len(req); i++ {
+	wg.Add(reqLimit)
+	for i := 0; i < reqLimit; i++ {
 		go func(i int) {
 			defer wg.Done()
-			for j := 0; j < 5; j++ {
+			for j := 0; j < iterLimit; j++ {
 				slog.Info("Starting", "req", i, "iter", j)
-				// On slower GPUs it can take a while to process the 4 concurrent requests
+				// On slower GPUs it can take a while to process the concurrent requests
 				// so we allow a much longer initial timeout
-				DoGenerate(ctx, t, client, req[i], resp[i], 90*time.Second, 5*time.Second)
+				DoGenerate(ctx, t, client, req[i], resp[i], 120*time.Second, 20*time.Second)
 			}
 		}(i)
 	}
@@ -221,5 +245,23 @@ func TestMultiModelStress(t *testing.T) {
 			}
 		}(i)
 	}
+	go func() {
+		for {
+			time.Sleep(2 * time.Second)
+			select {
+			case <-ctx.Done():
+				return
+			default:
+				models, err := client.ListRunning(ctx)
+				if err != nil {
+					slog.Warn("failed to list running models", "error", err)
+					continue
+				}
+				for _, m := range models.Models {
+					slog.Info("loaded model snapshot", "model", m)
+				}
+			}
+		}
+	}()
 	wg.Wait()
 }

+ 2 - 1
integration/context_test.go

@@ -11,7 +11,8 @@ import (
 )
 
 func TestContextExhaustion(t *testing.T) {
-	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) // TODO maybe shorter?
+	// Longer needed for small footprint GPUs
+	ctx, cancel := context.WithTimeout(context.Background(), 6*time.Minute)
 	defer cancel()
 	// Set up the test data
 	req := api.GenerateRequest{

+ 5 - 1
integration/llm_image_test.go

@@ -32,7 +32,11 @@ func TestIntegrationMultimodal(t *testing.T) {
 	resp := "the ollam"
 	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
 	defer cancel()
-	GenerateTestHelper(ctx, t, req, []string{resp})
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
+	require.NoError(t, PullIfMissing(ctx, client, req.Model))
+	// llava models on CPU can be quite slow to start,
+	DoGenerate(ctx, t, client, req, []string{resp}, 120*time.Second, 30*time.Second)
 }
 
 const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb

+ 22 - 17
integration/utils_test.go

@@ -140,7 +140,7 @@ func PullIfMissing(ctx context.Context, client *api.Client, modelName string) er
 
 	showCtx, cancel := context.WithDeadlineCause(
 		ctx,
-		time.Now().Add(5*time.Second),
+		time.Now().Add(10*time.Second),
 		fmt.Errorf("show for existing model %s took too long", modelName),
 	)
 	defer cancel()
@@ -287,41 +287,46 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
 func GenerateRequests() ([]api.GenerateRequest, [][]string) {
 	return []api.GenerateRequest{
 			{
-				Model:  "orca-mini",
-				Prompt: "why is the ocean blue?",
-				Stream: &stream,
+				Model:     "orca-mini",
+				Prompt:    "why is the ocean blue?",
+				Stream:    &stream,
+				KeepAlive: &api.Duration{Duration: 10 * time.Second},
 				Options: map[string]interface{}{
 					"seed":        42,
 					"temperature": 0.0,
 				},
 			}, {
-				Model:  "orca-mini",
-				Prompt: "why is the color of dirt brown?",
-				Stream: &stream,
+				Model:     "orca-mini",
+				Prompt:    "why is the color of dirt brown?",
+				Stream:    &stream,
+				KeepAlive: &api.Duration{Duration: 10 * time.Second},
 				Options: map[string]interface{}{
 					"seed":        42,
 					"temperature": 0.0,
 				},
 			}, {
-				Model:  "orca-mini",
-				Prompt: "what is the origin of the us thanksgiving holiday?",
-				Stream: &stream,
+				Model:     "orca-mini",
+				Prompt:    "what is the origin of the us thanksgiving holiday?",
+				Stream:    &stream,
+				KeepAlive: &api.Duration{Duration: 10 * time.Second},
 				Options: map[string]interface{}{
 					"seed":        42,
 					"temperature": 0.0,
 				},
 			}, {
-				Model:  "orca-mini",
-				Prompt: "what is the origin of independence day?",
-				Stream: &stream,
+				Model:     "orca-mini",
+				Prompt:    "what is the origin of independence day?",
+				Stream:    &stream,
+				KeepAlive: &api.Duration{Duration: 10 * time.Second},
 				Options: map[string]interface{}{
 					"seed":        42,
 					"temperature": 0.0,
 				},
 			}, {
-				Model:  "orca-mini",
-				Prompt: "what is the composition of air?",
-				Stream: &stream,
+				Model:     "orca-mini",
+				Prompt:    "what is the composition of air?",
+				Stream:    &stream,
+				KeepAlive: &api.Duration{Duration: 10 * time.Second},
 				Options: map[string]interface{}{
 					"seed":        42,
 					"temperature": 0.0,
@@ -331,7 +336,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
 		[][]string{
 			[]string{"sunlight"},
 			[]string{"soil", "organic", "earth", "black", "tan"},
-			[]string{"england", "english", "massachusetts", "pilgrims"},
+			[]string{"england", "english", "massachusetts", "pilgrims", "british"},
 			[]string{"fourth", "july", "declaration", "independence"},
 			[]string{"nitrogen", "oxygen", "carbon", "dioxide"},
 		}

+ 7 - 7
llm/ext_server/server.cpp

@@ -2335,9 +2335,9 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, g
                 invalid_param = true;
                 break;
             }
-#ifndef GGML_USE_CUBLAS
-            fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. Setting the split mode has no effect.\n");
-#endif // GGML_USE_CUBLAS
+#ifndef GGML_USE_CUDA
+            fprintf(stderr, "warning: llama.cpp was compiled without CUDA. Setting the split mode has no effect.\n");
+#endif // GGML_USE_CUDA
         }
         else if (arg == "--tensor-split" || arg == "-ts")
         {
@@ -2346,7 +2346,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, g
                 invalid_param = true;
                 break;
             }
-#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_SYCL)
+#if defined(GGML_USE_CUDA) || defined(GGML_USE_SYCL)
             std::string arg_next = argv[i];
 
             // split string by , and /
@@ -2367,8 +2367,8 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, g
                 }
             }
 #else
-            LOG_WARNING("llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n", {});
-#endif // GGML_USE_CUBLAS
+            LOG_WARNING("llama.cpp was compiled without CUDA. It is not possible to set a tensor split.\n", {});
+#endif // GGML_USE_CUDA
         }
         else if (arg == "--main-gpu" || arg == "-mg")
         {
@@ -2377,7 +2377,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, g
                 invalid_param = true;
                 break;
             }
-#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_SYCL)
+#if defined(GGML_USE_CUDA) || defined(GGML_USE_SYCL)
             params.main_gpu = std::stoi(argv[i]);
 #else
             LOG_WARNING("llama.cpp was compiled without cuBLAS. It is not possible to set a main GPU.", {});

+ 1 - 0
llm/ggml.go

@@ -307,6 +307,7 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
 
 		partialOffload = 4 * batch * embedding
 		partialOffload += max(
+			// 4*batch*(4+6*embedding+context*(2*heads)+llm.KV().GQA()),
 			4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embedding/heads*headsKV),
 			4*batch*(embedding+vocab)+embedding*vocab*105/128,
 		)

+ 202 - 55
llm/memory.go

@@ -1,11 +1,11 @@
 package llm
 
 import (
-	"fmt"
 	"log/slog"
+	"strconv"
+	"strings"
 
 	"github.com/ollama/ollama/api"
-	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/gpu"
 )
@@ -16,7 +16,8 @@ func PredictServerFit(allGpus gpu.GpuInfoList, ggml *GGML, adapters, projectors
 	var estimatedVRAM uint64
 	for _, gpus := range allGpus.ByLibrary() {
 		var layerCount int
-		layerCount, estimatedVRAM, _ = EstimateGPULayers(gpus, ggml, projectors, opts)
+		estimate := EstimateGPULayers(gpus, ggml, projectors, opts)
+		layerCount, estimatedVRAM = estimate.Layers, estimate.VRAMSize
 		if opts.NumGPU < 0 {
 			if layerCount > 0 && layerCount >= int(ggml.KV().BlockCount()+1) {
 				return true, estimatedVRAM
@@ -30,24 +31,64 @@ func PredictServerFit(allGpus gpu.GpuInfoList, ggml *GGML, adapters, projectors
 	return false, estimatedVRAM
 }
 
+type MemoryEstimate struct {
+	// How many layers we predict we can load
+	Layers int
+
+	// The size of the graph which occupies the main GPU
+	Graph uint64
+
+	// How much VRAM will be allocated given the number of layers we predict
+	VRAMSize uint64
+
+	// The total size of the model if loaded into VRAM.  If all layers are loaded, VRAMSize == TotalSize
+	TotalSize uint64
+
+	// For multi-GPU scenarios, this provides the tensor split parameter
+	TensorSplit string
+
+	// For multi-GPU scenarios, this is the size in bytes per GPU
+	GPUSizes []uint64
+}
+
 // Given a model and one or more GPU targets, predict how many layers and bytes we can load, and the total size
 // The GPUs provided must all be the same Library
-func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts api.Options) (int, uint64, uint64) {
-	var memoryAvailable uint64
-	for _, info := range gpus {
-		memoryAvailable += info.FreeMemory
-	}
-	if envconfig.MaxVRAM > 0 {
-		memoryAvailable = envconfig.MaxVRAM
-	}
+func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts api.Options) MemoryEstimate {
+	// Graph size for a partial offload, applies to all GPUs
+	var graphPartialOffload uint64
+
+	// Graph size when all layers are offloaded, applies to all GPUs
+	var graphFullOffload uint64
+
+	// Final graph offload once we know full or partial
+	var graphOffload uint64
+
+	// Projectors loaded into GPU0 only
+	var projectorSize uint64
+
+	// Conditional output size on GPU 0
+	var memoryLayerOutput uint64
 
-	slog.Debug("evaluating", "library", gpus[0].Library, "gpu_count", len(gpus), "available", format.HumanBytes2(memoryAvailable))
+	// The sizes of a layer
+	var layerSize uint64
 
-	// TODO - this is probably wrong, first GPU vs secondaries will have different overheads
-	memoryMinimum := gpus[0].MinimumMemory
+	// The sum of all the layer sizes (just for logging)
+	var memoryWeights uint64
+
+	// True if all the layers are loaded
+	var fullyLoaded bool
+
+	// Overflow that didn't fit into the GPU
+	var overflow uint64
+
+	availableList := make([]string, len(gpus))
+	for i, gpu := range gpus {
+		availableList[i] = format.HumanBytes2(gpu.FreeMemory)
+	}
+	slog.Debug("evaluating", "library", gpus[0].Library, "gpu_count", len(gpus), "available", availableList)
 
 	for _, projector := range projectors {
-		memoryMinimum += projectorMemoryRequirements(projector)
+		projectorSize += projectorMemoryRequirements(projector)
 
 		// multimodal models require at least 2048 context
 		opts.NumCtx = max(opts.NumCtx, 2048)
@@ -56,79 +97,160 @@ func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts
 	layers := ggml.Tensors().Layers()
 	// add one layer worth of memory as a buffer
 	if blk0, ok := layers["blk.0"]; ok {
-		memoryMinimum += blk0.size()
+		layerSize = blk0.size()
+	} else {
+		slog.Warn("model missing blk.0 layer size")
 	}
 
 	// fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv
 	var kv uint64 = 2 * 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * ggml.KV().EmbeddingLength() / ggml.KV().HeadCount() * ggml.KV().HeadCountKV()
 
-	graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)))
+	// KV is proportional to the number of layers
+	layerSize += kv / ggml.KV().BlockCount()
+
+	graphPartialOffload, graphFullOffload = ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)))
 	if graphPartialOffload == 0 {
 		graphPartialOffload = ggml.KV().GQA() * kv / 6
 	}
-
 	if graphFullOffload == 0 {
 		graphFullOffload = graphPartialOffload
 	}
 
-	graphFullOffload *= uint64(len(gpus))
-	graphPartialOffload *= uint64(len(gpus))
-
 	// on metal there's no partial offload overhead
 	if gpus[0].Library == "metal" {
 		graphPartialOffload = graphFullOffload
+	} else if len(gpus) > 1 {
+		// multigpu should always use the partial graph size
+		graphFullOffload = graphPartialOffload
 	}
 
-	// memoryRequiredTotal represents the memory required for full GPU offloading (all layers)
-	memoryRequiredTotal := memoryMinimum + graphFullOffload
-
-	// memoryRequiredPartial represents the memory required for partial GPU offloading (n > 0, n < layers)
-	memoryRequiredPartial := memoryMinimum + graphPartialOffload
-
-	var memoryLayerOutput uint64
 	if layer, ok := layers["output_norm"]; ok {
 		memoryLayerOutput += layer.size()
 	}
-
 	if layer, ok := layers["output"]; ok {
 		memoryLayerOutput += layer.size()
 	} else if layer, ok := layers["token_embd"]; ok {
 		memoryLayerOutput += layer.size()
 	}
 
-	if gpus[0].Library == "metal" && opts.UseMMap {
-		// memory is preallocated for output tensors
-		memoryRequiredTotal += memoryLayerOutput
-		memoryRequiredPartial += memoryLayerOutput
-	}
+	// Output layer handled at the end if we have space
+	gpuZeroOverhead := projectorSize
 
+	// Reduce set of GPUs to only those that have sufficient space to fit overhead and at least one layer
 	var layerCount int
+	layerCounts := make([]int, len(gpus))
+	gpuAllocations := make([]uint64, len(gpus))
+	type gs struct {
+		i int
+		g *gpu.GpuInfo
+	}
+	gpusWithSpace := []gs{}
+	for i := range gpus {
+		var gzo uint64
+		if len(gpusWithSpace) == 0 {
+			gzo = gpuZeroOverhead
+		}
+		// Only include GPUs that can fit the graph, gpu minimum, the layer buffer and at least more layer
+		if gpus[i].FreeMemory < gzo+max(graphPartialOffload, graphFullOffload)+gpus[i].MinimumMemory+2*layerSize {
+			slog.Debug("gpu has too little memory to allocate any layers", "gpu", gpus[i])
+			continue
+		}
+		gpusWithSpace = append(gpusWithSpace, gs{i, &gpus[i]})
+		gpuAllocations[i] += gpus[i].MinimumMemory + layerSize // We hold off on graph until we know partial vs. full
+	}
+
+	var gpuZeroID int
+	if len(gpusWithSpace) > 0 {
+		gpuZeroID = gpusWithSpace[0].i
+		gpuAllocations[gpuZeroID] += gpuZeroOverhead
+	}
+
+	// For all the layers, find where they can fit on the GPU(s)
 	for i := range int(ggml.KV().BlockCount()) {
-		if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok {
-			memoryLayer := blk.size()
+		memoryWeights += layerSize
 
-			// KV is proportional to the number of layers
-			memoryLayer += kv / ggml.KV().BlockCount()
+		if opts.NumGPU >= 0 && layerCount >= opts.NumGPU {
+			// Stop allocating on GPU(s) once we hit the users target NumGPU
+			continue
+		}
 
-			memoryRequiredTotal += memoryLayer
-			if (opts.NumGPU >= 0 && layerCount+1 <= opts.NumGPU) || (opts.NumGPU < 0 && memoryAvailable > memoryRequiredPartial+memoryLayer) {
-				memoryRequiredPartial += memoryLayer
+		// distribute the layers across the GPU(s) that have space
+		for j := len(gpusWithSpace); j > 0; j-- {
+			g := gpusWithSpace[i%j]
+			used := gpuAllocations[g.i] + max(graphPartialOffload, graphFullOffload)
+			if g.g.FreeMemory > used+layerSize {
+				gpuAllocations[g.i] += layerSize
+				layerCounts[g.i]++
 				layerCount++
+				break
+			} else {
+				gpusWithSpace = append(gpusWithSpace[:i%j], gpusWithSpace[i%j+1:]...)
 			}
 		}
 	}
+	if layerCount >= int(ggml.KV().BlockCount()) {
+		fullyLoaded = true
+	} else {
+		for i := layerCount; i < int(ggml.KV().BlockCount()); i++ {
+			overflow += layerSize
+		}
+	}
+
+	// Determine if we need to consider output then find where it fits
+	if memoryLayerOutput > 0 && (opts.NumGPU < 0 || layerCount < opts.NumGPU) {
+		for j := len(gpusWithSpace); j > 0; j-- {
+			g := gpusWithSpace[layerCount%j]
+			used := gpuAllocations[g.i] + max(graphPartialOffload, graphFullOffload)
+			if g.g.FreeMemory > used+memoryLayerOutput {
+				gpuAllocations[g.i] += memoryLayerOutput
+				layerCounts[g.i]++
+				layerCount++
+				break
+			}
+		}
 
-	if gpus[0].Library != "metal" || !opts.UseMMap {
-		// memory was not preallocated for output tensors
-		memoryRequiredTotal += memoryLayerOutput
+		if layerCount < int(ggml.KV().BlockCount())+1 {
+			fullyLoaded = false
+			overflow += memoryLayerOutput
+		}
 	}
 
-	if (opts.NumGPU >= 0 && layerCount+1 <= opts.NumGPU) || (opts.NumGPU < 0 && memoryAvailable > memoryRequiredTotal) {
-		layerCount = int(ggml.KV().BlockCount()) + 1
-		memoryRequiredPartial = memoryRequiredTotal
+	// Add the applicable (full or partial) graph allocations
+	for i := range gpus {
+		if layerCounts[i] <= 0 {
+			continue
+		}
+		if fullyLoaded {
+			gpuAllocations[i] += graphFullOffload
+		} else {
+			gpuAllocations[i] += graphPartialOffload
+		}
+	}
+	if fullyLoaded {
+		graphOffload = graphFullOffload
+	} else {
+		graphOffload = graphPartialOffload
 	}
 
-	memoryWeights := memoryRequiredTotal - memoryMinimum - graphFullOffload - kv
+	// Summaries for the log
+	var memoryRequiredPartial, memoryRequiredTotal uint64
+	for i := range gpuAllocations {
+		memoryRequiredPartial += gpuAllocations[i]
+	}
+	memoryRequiredTotal = memoryRequiredPartial + overflow
+
+	tensorSplit := ""
+	if len(gpus) > 1 {
+		splits := make([]string, len(gpus))
+		for i, count := range layerCounts {
+			splits[i] = strconv.Itoa(count)
+		}
+		tensorSplit = strings.Join(splits, ",")
+	}
+	allocationsList := []string{}
+	for _, a := range gpuAllocations {
+		allocationsList = append(allocationsList, format.HumanBytes2(a))
+	}
 
 	slog.Info(
 		"offload to gpu",
@@ -136,13 +258,17 @@ func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts
 			"layers",
 			// requested number of layers to offload
 			"requested", opts.NumGPU,
+			// The number of layers the model has (including output)
+			"model", int(ggml.KV().BlockCount())+1,
 			// estimated number of layers that can be offloaded
-			"real", layerCount,
+			"offload", layerCount,
+			// multi-gpu split for tesnors
+			"split", tensorSplit,
 		),
 		slog.Group(
 			"memory",
-			// memory available for offloading
-			"available", format.HumanBytes2(memoryAvailable),
+			// memory available by GPU for offloading
+			"available", availableList,
 			slog.Group(
 				"required",
 				// memory required for full offloading
@@ -151,6 +277,8 @@ func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts
 				"partial", format.HumanBytes2(memoryRequiredPartial),
 				// memory of KV cache
 				"kv", format.HumanBytes2(kv),
+				// Allocations across the GPUs
+				"allocations", allocationsList,
 			),
 			slog.Group(
 				"weights",
@@ -171,12 +299,31 @@ func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts
 		),
 	)
 	if gpus[0].Library == "cpu" {
-		return 0, 0, memoryRequiredTotal
+		return MemoryEstimate{
+			Layers:    0,
+			Graph:     0,
+			VRAMSize:  0,
+			TotalSize: memoryRequiredTotal,
+			GPUSizes:  []uint64{},
+		}
 	}
-	if memoryRequiredPartial > memoryAvailable {
+	if layerCount == 0 {
 		slog.Debug("insufficient VRAM to load any model layers")
-		return 0, 0, memoryRequiredTotal
+		return MemoryEstimate{
+			Layers:    0,
+			Graph:     0,
+			VRAMSize:  0,
+			TotalSize: memoryRequiredTotal,
+			GPUSizes:  []uint64{},
+		}
 	}
 
-	return layerCount, memoryRequiredPartial, memoryRequiredTotal
+	return MemoryEstimate{
+		Layers:      layerCount,
+		Graph:       graphOffload,
+		VRAMSize:    memoryRequiredPartial,
+		TotalSize:   memoryRequiredTotal,
+		TensorSplit: tensorSplit,
+		GPUSizes:    gpuAllocations,
+	}
 }

+ 127 - 0
llm/memory_test.go

@@ -0,0 +1,127 @@
+package llm
+
+import (
+	"bytes"
+	"encoding/binary"
+	"fmt"
+	"os"
+	"testing"
+
+	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/envconfig"
+	"github.com/ollama/ollama/gpu"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestEstimateGPULayers(t *testing.T) {
+	envconfig.Debug = true
+	modelName := "dummy"
+	f, err := os.CreateTemp(t.TempDir(), modelName)
+	require.NoError(t, err)
+	defer f.Close()
+	gguf := NewGGUFV3(binary.LittleEndian)
+	inputLayerCount := 5
+	tensors := []Tensor{
+		{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
+		{Name: "blk.1.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
+		{Name: "blk.2.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
+		{Name: "blk.3.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
+		{Name: "blk.4.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
+		{Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
+	}
+	assert.Len(t, tensors, inputLayerCount+1)
+	err = gguf.Encode(f, KV{
+		"general.architecture":          "llama",
+		"general.name":                  "name",
+		"llama.context_length":          uint32(32),
+		"llama.embedding_length":        uint32(4096),
+		"llama.block_count":             uint32(inputLayerCount),
+		"llama.attention.head_count":    uint32(32),
+		"llama.attention.head_count_kv": uint32(32),
+		"tokenizer.ggml.tokens":         []string{" "},
+		"tokenizer.ggml.scores":         []float32{0},
+		"tokenizer.ggml.token_type":     []int32{0},
+	}, tensors)
+	require.NoError(t, err)
+
+	ggml, err := LoadModel(f.Name())
+	require.NoError(t, err)
+
+	// Simple CPU scenario
+	gpus := []gpu.GpuInfo{
+		{
+			Library: "cpu",
+		},
+	}
+	projectors := []string{}
+	opts := api.DefaultOptions()
+	t.Run("cpu", func(t *testing.T) {
+		estimate := EstimateGPULayers(gpus, ggml, projectors, opts)
+		assert.Equal(t, 0, estimate.Layers)
+		assert.Equal(t, uint64(0), estimate.Graph)
+	})
+
+	// derived from the dummy ggml file above
+	graphPartialOffload := uint64(202377216)
+	graphFullOffload := uint64(171968512)
+	layerSize := uint64(33554436)
+	projectorSize := uint64(0)
+	memoryLayerOutput := uint64(4)
+
+	// Dual CUDA scenario with assymetry
+	gpuMinimumMemory := uint64(2048)
+	gpus = []gpu.GpuInfo{
+		{
+			Library:       "cuda",
+			MinimumMemory: gpuMinimumMemory,
+		},
+		{
+			Library:       "cuda",
+			MinimumMemory: gpuMinimumMemory,
+		},
+	}
+	// Nested array: GPU0 layer space, GPU1 layer space, expected gpu0, expected gpu1
+	for i, s := range []struct {
+		layer0, layer1   uint64
+		expect0, expect1 uint64
+	}{
+		{1, 1, 1, 1},
+		{2, 1, 2, 1},
+		{2, 2, 2, 2},
+		{1, 2, 1, 2},
+		{3, 3, 3, 3},
+		{4, 4, 3, 3},
+		{6, 6, 3, 3},
+		{0, 3, 0, 3},
+	} {
+		t.Run(fmt.Sprintf("%v", s), func(t *testing.T) {
+			gpus[0].FreeMemory = 0
+			gpus[1].FreeMemory = 0
+			gpus[0].FreeMemory += projectorSize
+			if s.layer0 > 0 {
+				gpus[0].FreeMemory += memoryLayerOutput
+			} else {
+				gpus[1].FreeMemory += memoryLayerOutput
+			}
+			gpus[0].FreeMemory += gpuMinimumMemory + layerSize + s.layer0*layerSize + 1
+			gpus[1].FreeMemory += gpuMinimumMemory + layerSize + s.layer1*layerSize + 1
+			gpus[0].FreeMemory += max(graphFullOffload, graphPartialOffload)
+			gpus[1].FreeMemory += max(graphFullOffload, graphPartialOffload)
+			estimate := EstimateGPULayers(gpus, ggml, projectors, opts)
+			assert.Equal(t, int(s.expect0+s.expect1), estimate.Layers, "scenario %d: %v", i, s)
+			assert.Equal(t, fmt.Sprintf("%d,%d", s.expect0, s.expect1), estimate.TensorSplit, "scenario %d: %v", i, s)
+			var layerSums uint64
+			for _, b := range estimate.GPUSizes {
+				layerSums += b
+			}
+			if estimate.Layers < inputLayerCount+1 {
+				assert.Less(t, estimate.VRAMSize, estimate.TotalSize, "scenario %d: %v %+v", i, s, estimate)
+				assert.Equal(t, estimate.VRAMSize, layerSums, "scenario %d: %v %+v", i, s, estimate)
+			} else {
+				assert.Equal(t, estimate.VRAMSize, estimate.TotalSize, "scenario %d: %v %+v", i, s, estimate)
+				assert.Equal(t, estimate.TotalSize, layerSums, "scenario %d: %v %+v", i, s, estimate)
+			}
+		})
+	}
+}

+ 8 - 8
llm/payload.go

@@ -82,8 +82,8 @@ func serversForGpu(info gpu.GpuInfo) []string {
 	// glob workDir for files that start with ollama_
 	availableServers := availableServers()
 	requested := info.Library
-	if info.Variant != "" {
-		requested += "_" + info.Variant
+	if info.Variant != gpu.CPUCapabilityNone {
+		requested += "_" + info.Variant.String()
 	}
 
 	servers := []string{}
@@ -117,14 +117,14 @@ func serversForGpu(info gpu.GpuInfo) []string {
 
 	// Load up the best CPU variant if not primary requested
 	if info.Library != "cpu" {
-		variant := gpu.GetCPUVariant()
+		variant := gpu.GetCPUCapability()
 		// If no variant, then we fall back to default
 		// If we have a variant, try that if we find an exact match
 		// Attempting to run the wrong CPU instructions will panic the
 		// process
-		if variant != "" {
+		if variant != gpu.CPUCapabilityNone {
 			for cmp := range availableServers {
-				if cmp == "cpu_"+variant {
+				if cmp == "cpu_"+variant.String() {
 					servers = append(servers, cmp)
 					break
 				}
@@ -146,11 +146,11 @@ func serverForCpu() string {
 	if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
 		return "metal"
 	}
-	variant := gpu.GetCPUVariant()
+	variant := gpu.GetCPUCapability()
 	availableServers := availableServers()
-	if variant != "" {
+	if variant != gpu.CPUCapabilityNone {
 		for cmp := range availableServers {
-			if cmp == "cpu_"+variant {
+			if cmp == "cpu_"+variant.String() {
 				return cmp
 			}
 		}

+ 50 - 36
llm/server.go

@@ -37,8 +37,9 @@ type LlamaServer interface {
 	Tokenize(ctx context.Context, content string) ([]int, error)
 	Detokenize(ctx context.Context, tokens []int) (string, error)
 	Close() error
-	EstimatedVRAM() uint64
+	EstimatedVRAM() uint64 // Total VRAM across all GPUs
 	EstimatedTotal() uint64
+	EstimatedVRAMByGPU(gpuID string) uint64
 }
 
 // llmServer is an instance of the llama.cpp server
@@ -49,13 +50,12 @@ type llmServer struct {
 	status  *StatusWriter
 	options api.Options
 
-	// TODO - this should be broken down by GPU
-	estimatedVRAM  uint64 // Estimated usage of VRAM by the loaded model
-	estimatedTotal uint64 // Total size of model
-	totalLayers    uint64
-	gpuCount       int
-	loadDuration   time.Duration // Record how long it took the model to load
-	loadProgress   float32
+	estimate    MemoryEstimate
+	totalLayers uint64
+	// gpuCount     int
+	gpus         gpu.GpuInfoList // Recorded just before the model loaded, free space will be incorrect
+	loadDuration time.Duration   // Record how long it took the model to load
+	loadProgress float32
 
 	sem *semaphore.Weighted
 }
@@ -80,16 +80,16 @@ func LoadModel(model string) (*GGML, error) {
 func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, projectors []string, opts api.Options) (LlamaServer, error) {
 	var err error
 	var cpuRunner string
-	var estimatedVRAM uint64
-	var estimatedTotal uint64
+	var estimate MemoryEstimate
 	var systemMemory uint64
-	gpuCount := len(gpus)
-	if (len(gpus) == 1 && gpus[0].Library == "cpu") || opts.NumGPU == 0 {
-		// TODO evaluate system memory to see if we should block the load, or force an unload of another CPU runner
 
+	// If the user wants zero GPU layers, reset the gpu list to be CPU/system ram info
+	if opts.NumGPU == 0 {
+		gpus = gpu.GetCPUInfo()
+	}
+	if len(gpus) == 1 && gpus[0].Library == "cpu" {
 		cpuRunner = serverForCpu()
-		gpuCount = 0
-		_, _, estimatedTotal = EstimateGPULayers(gpus, ggml, projectors, opts)
+		estimate = EstimateGPULayers(gpus, ggml, projectors, opts)
 	} else {
 		if gpus[0].Library == "metal" {
 			memInfo, err := gpu.GetCPUMem()
@@ -100,20 +100,19 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 				slog.Debug("system memory", "total", format.HumanBytes2(systemMemory))
 			}
 		}
-		var layers int
-		layers, estimatedVRAM, estimatedTotal = EstimateGPULayers(gpus, ggml, projectors, opts)
+		estimate = EstimateGPULayers(gpus, ggml, projectors, opts)
 
 		switch {
-		case gpus[0].Library == "metal" && estimatedVRAM > systemMemory:
+		case gpus[0].Library == "metal" && estimate.VRAMSize > systemMemory:
 			// disable partial offloading when model is greater than total system memory as this
 			// can lead to locking up the system
 			opts.NumGPU = 0
-		case gpus[0].Library != "metal" && layers == 0:
+		case gpus[0].Library != "metal" && estimate.Layers == 0:
 			// Don't bother loading into the GPU if no layers can fit
 			cpuRunner = serverForCpu()
-			gpuCount = 0
-		case opts.NumGPU < 0 && layers > 0 && gpus[0].Library != "cpu":
-			opts.NumGPU = layers
+			gpus = gpu.GetCPUInfo()
+		case opts.NumGPU < 0 && estimate.Layers > 0 && gpus[0].Library != "cpu":
+			opts.NumGPU = estimate.Layers
 		}
 	}
 
@@ -232,6 +231,14 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 
 	params = append(params, "--parallel", fmt.Sprintf("%d", numParallel))
 
+	if estimate.TensorSplit != "" {
+		params = append(params, "--tensor-split", estimate.TensorSplit)
+	}
+
+	if estimate.TensorSplit != "" {
+		params = append(params, "--tensor-split", estimate.TensorSplit)
+	}
+
 	for i := range len(servers) {
 		dir := availableServers[servers[i]]
 		if dir == "" {
@@ -242,8 +249,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 		}
 
 		if strings.HasPrefix(servers[i], "cpu") {
-			// TODO if we tried a gpu runner first, and it failed, record the error and bubble that back up
-			gpuCount = 0
+			gpus = gpu.GetCPUInfo()
 		}
 
 		// Find an availableServers  port, retry on each iteration in case the failure was a port conflict race
@@ -299,16 +305,15 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 		}
 
 		s := &llmServer{
-			port:           port,
-			cmd:            exec.Command(server, finalParams...),
-			status:         NewStatusWriter(os.Stderr),
-			options:        opts,
-			estimatedVRAM:  estimatedVRAM,
-			estimatedTotal: estimatedTotal,
-			sem:            semaphore.NewWeighted(int64(numParallel)),
-			totalLayers:    ggml.KV().BlockCount() + 1,
-			gpuCount:       gpuCount,
-			done:           make(chan error, 1),
+			port:        port,
+			cmd:         exec.Command(server, finalParams...),
+			status:      NewStatusWriter(os.Stderr),
+			options:     opts,
+			estimate:    estimate,
+			sem:         semaphore.NewWeighted(int64(numParallel)),
+			totalLayers: ggml.KV().BlockCount() + 1,
+			gpus:        gpus,
+			done:        make(chan error, 1),
 		}
 
 		s.cmd.Env = os.Environ()
@@ -1004,11 +1009,20 @@ func (s *llmServer) Close() error {
 }
 
 func (s *llmServer) EstimatedVRAM() uint64 {
-	return s.estimatedVRAM
+	return s.estimate.VRAMSize
 }
 
 func (s *llmServer) EstimatedTotal() uint64 {
-	return s.estimatedTotal
+	return s.estimate.TotalSize
+}
+
+func (s *llmServer) EstimatedVRAMByGPU(gpuID string) uint64 {
+	for i, gpu := range s.gpus {
+		if gpu.ID == gpuID {
+			return s.estimate.GPUSizes[i]
+		}
+	}
+	return 0
 }
 
 func parseDurationMs(ms float64) time.Duration {

+ 127 - 42
server/sched.go

@@ -7,7 +7,6 @@ import (
 	"log/slog"
 	"reflect"
 	"runtime"
-	"slices"
 	"sort"
 	"strings"
 	"sync"
@@ -27,6 +26,7 @@ type LlmRequest struct {
 	sessionDuration time.Duration
 	successCh       chan *runnerRef
 	errCh           chan error
+	schedAttempts   uint
 }
 
 type Scheduler struct {
@@ -38,9 +38,11 @@ type Scheduler struct {
 	loaded   map[string]*runnerRef
 	loadedMu sync.Mutex
 
-	loadFn      func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList)
-	newServerFn func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error)
-	getGpuFn    func() gpu.GpuInfoList
+	loadFn       func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList)
+	newServerFn  func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error)
+	getGpuFn     func() gpu.GpuInfoList
+	getCpuFn     func() gpu.GpuInfoList
+	reschedDelay time.Duration
 }
 
 var ErrMaxQueue = fmt.Errorf("server busy, please try again.  maximum pending requests exceeded")
@@ -54,6 +56,8 @@ func InitScheduler(ctx context.Context) *Scheduler {
 		loaded:        make(map[string]*runnerRef),
 		newServerFn:   llm.NewLlamaServer,
 		getGpuFn:      gpu.GetGPUInfo,
+		getCpuFn:      gpu.GetCPUInfo,
+		reschedDelay:  250 * time.Millisecond,
 	}
 	sched.loadFn = sched.load
 	return sched
@@ -105,6 +109,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
 			return
 		case pending := <-s.pendingReqCh:
 			// Block other requests until we get this pending request running
+			pending.schedAttempts++
 
 			if pending.ctx.Err() != nil {
 				slog.Debug("pending request cancelled or timed out, skipping scheduling")
@@ -131,7 +136,12 @@ func (s *Scheduler) processPending(ctx context.Context) {
 				} else {
 					// Either no models are loaded or below envconfig.MaxRunners
 					// Get a refreshed GPU list
-					gpus := s.getGpuFn()
+					var gpus gpu.GpuInfoList
+					if pending.opts.NumGPU == 0 {
+						gpus = s.getCpuFn()
+					} else {
+						gpus = s.getGpuFn()
+					}
 
 					// Load model for fitting
 					ggml, err := llm.LoadModel(pending.model.ModelPath)
@@ -140,16 +150,22 @@ func (s *Scheduler) processPending(ctx context.Context) {
 						break
 					}
 
-					// If we're CPU only mode, just limit by envconfig.MaxRunners above
-					// TODO handle system memory exhaustion
-					if (len(gpus) == 1 && gpus[0].Library == "cpu") || pending.opts.NumGPU == 0 {
-						slog.Debug("cpu mode with existing models, loading")
-						s.loadFn(pending, ggml, gpus)
-						break
-					}
-
-					// No models loaded. Load the model but prefer the best fit.
-					if loadedCount == 0 {
+					// Evaluate if the model will fit in the available system memory, or if we should unload a model first
+					if len(gpus) == 1 && gpus[0].Library == "cpu" {
+						if loadedCount == 0 {
+							slog.Debug("cpu mode with first model, loading")
+							s.loadFn(pending, ggml, gpus)
+							break
+						}
+						runnerToExpire = s.maybeFindCPURunnerToUnload(pending, ggml, gpus)
+						if runnerToExpire == nil {
+							slog.Debug("cpu mode with available system memory or first model, loading")
+							s.loadFn(pending, ggml, gpus)
+							break
+						}
+						// else we need to expire a runner
+					} else if loadedCount == 0 {
+						// No models loaded. Load the model but prefer the best fit.
 						slog.Debug("loading first model", "model", pending.model.ModelPath)
 						g := pickBestFitGPUs(pending, ggml, gpus)
 						if g != nil {
@@ -159,16 +175,44 @@ func (s *Scheduler) processPending(ctx context.Context) {
 						break
 					}
 
-					// More than one loaded model, so we have to see if the new one fits
-					// Update free memory from currently loaded models
-					s.updateFreeSpace(gpus)
-					gpus = pickBestFitGPUs(pending, ggml, gpus)
-					if gpus != nil {
-						slog.Debug("new model fits with existing models, loading")
-						s.loadFn(pending, ggml, gpus)
-						break
+					if runnerToExpire == nil {
+						// More than one loaded model, so we have to see if the
+						// new one fits
+						//
+						// We want to avoid loading on any GPUs that have other
+						// models still loading on them to avoid potential races
+						// with VRAM consumption ramping up during load
+						availGpus := s.filterGPUsWithoutLoadingModels(gpus)
+
+						// Update free memory from currently loaded models
+						s.updateFreeSpace(availGpus)
+						fitGpus := pickBestFitGPUs(pending, ggml, availGpus)
+						if fitGpus != nil {
+							slog.Debug("new model fits with existing models, loading")
+							s.loadFn(pending, ggml, fitGpus)
+							break
+						}
+
+						// We couldn't find a set of GPUs to fully load the new
+						// model. If no other models are loading (both GPU lists
+						// are the same) then we need to unload another model to
+						// make room
+						if len(availGpus) < len(gpus) {
+							// There are other requests pending, and this one
+							// needs more time, so put it on the back of the
+							// queue so that we might satisfy other pending
+							// requests that aren't blocked
+							go func() {
+								// Process in a go routine to avoid deadlocking
+								// the scheduler if our queue is full
+								slog.Debug("delaying scheduling while other models finish loading", "attempts", pending.schedAttempts, "model", pending.model.ModelPath)
+								time.Sleep(s.reschedDelay)
+								s.pendingReqCh <- pending
+							}()
+							break
+						}
+						runnerToExpire = s.findRunnerToUnload()
 					}
-					runnerToExpire = s.findRunnerToUnload()
 				}
 
 				if runnerToExpire == nil {
@@ -368,17 +412,9 @@ func (s *Scheduler) updateFreeSpace(allGpus gpu.GpuInfoList) {
 	s.loadedMu.Lock()
 	for _, r := range s.loaded {
 		r.refMu.Lock()
-		gpuIDs := make([]string, 0, len(r.gpus))
 		if r.llama != nil {
-			// TODO this should be broken down by GPU instead of assuming uniform spread
-			estimatedVRAMPerGPU := r.llama.EstimatedVRAM() / uint64(len(r.gpus))
-			for _, gpu := range r.gpus {
-				gpuIDs = append(gpuIDs, gpu.ID)
-			}
 			for _, gpu := range allGpus {
-				if slices.Contains(gpuIDs, gpu.ID) {
-					predMap[predKey{gpu.Library, gpu.ID}] += estimatedVRAMPerGPU
-				}
+				predMap[predKey{gpu.Library, gpu.ID}] += r.llama.EstimatedVRAMByGPU(gpu.ID)
 			}
 		} else {
 			slog.Warn("unexpected nil runner reference, memory prediction may be incorrect")
@@ -401,11 +437,36 @@ func (s *Scheduler) updateFreeSpace(allGpus gpu.GpuInfoList) {
 				// after we start our first runner, then we'll never acount for that, so picking the smallest free value seems prudent.
 				allGpus[i].FreeMemory = allGpus[i].TotalMemory - p
 			}
-			slog.Info("updated VRAM", "gpu", allGpus[i].ID, "library", allGpus[i].Library, "total", format.HumanBytes2(allGpus[i].TotalMemory), "available", format.HumanBytes2(allGpus[i].FreeMemory))
+			slog.Info("updated VRAM based on existing loaded models", "gpu", allGpus[i].ID, "library", allGpus[i].Library, "total", format.HumanBytes2(allGpus[i].TotalMemory), "available", format.HumanBytes2(allGpus[i].FreeMemory))
 		}
 	}
 }
 
+// While models are loading the VRAM consumption numbers will be indeterminate, so we have
+// to avoid scheduling another model on the same GPU(s) that haven't stabilized.
+// This routine returns the set of GPUs that do not have an active loading model.
+// If all GPUs have loading models, an empty list will be returned (not a single CPU entry)
+func (s *Scheduler) filterGPUsWithoutLoadingModels(allGpus gpu.GpuInfoList) gpu.GpuInfoList {
+	ret := append(gpu.GpuInfoList{}, allGpus...)
+	s.loadedMu.Lock()
+	defer s.loadedMu.Unlock()
+	for _, runner := range s.loaded {
+		if runner.loading {
+			slog.Debug("overlapping loads detected", "gpus", runner.gpus, "model", runner.modelPath)
+			for _, busyGPU := range runner.gpus {
+				for i := range ret {
+					if ret[i].ID == busyGPU.ID {
+						ret = append(ret[:i], ret[i+1:]...)
+						break
+					}
+				}
+			}
+		}
+	}
+	return ret
+}
+
+// TODO consolidate sched_types.go
 type runnerRef struct {
 	refMu sync.Mutex
 	// refCond   sync.Cond // Signaled on transition from 1 -> 0 refCount
@@ -487,8 +548,11 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
 func (runner *runnerRef) waitForVRAMRecovery() chan interface{} {
 	finished := make(chan interface{}, 1)
 
-	// CPU or Metal don't need checking, so no waiting required, windows can page VRAM, and the APIs we query tend to be optimistic on free space
-	if (len(runner.gpus) == 1 && (runner.gpus[0].Library == "cpu" || runner.gpus[0].Library == "metal")) || runtime.GOOS == "windows" {
+	// CPU or Metal don't need checking, so no waiting required
+	// windows can page VRAM, only cuda currently can report accurate used vram usage
+	if len(runner.gpus) == 0 ||
+		(len(runner.gpus) == 1 && (runner.gpus[0].Library == "cpu" || runner.gpus[0].Library == "metal")) ||
+		(runtime.GOOS == "windows" && runner.gpus[0].Library != "cuda") {
 		finished <- struct{}{}
 		return finished
 	}
@@ -508,7 +572,7 @@ func (runner *runnerRef) waitForVRAMRecovery() chan interface{} {
 		for {
 			<-ticker.C
 			if time.Now().After(expiresAt) {
-				slog.Warn("gpu VRAM usage didn't recover within timeout", "seconds", time.Since(start).Seconds())
+				slog.Warn("gpu VRAM usage didn't recover within timeout", "seconds", time.Since(start).Seconds(), "model", runner.modelPath)
 				finished <- struct{}{}
 			}
 
@@ -521,7 +585,7 @@ func (runner *runnerRef) waitForVRAMRecovery() chan interface{} {
 			}
 			// If we're within ~80% of the estimated memory usage recovered, bail out
 			if float32(freeMemoryNow-freeMemoryBefore) > float32(runner.estimatedVRAM)*0.8 {
-				slog.Debug(fmt.Sprintf("gpu VRAM free memory converged after %0.2f seconds", time.Since(start).Seconds()))
+				slog.Debug(fmt.Sprintf("gpu VRAM free memory converged after %0.2f seconds", time.Since(start).Seconds()), "model", runner.modelPath)
 				finished <- struct{}{}
 				return
 			}
@@ -558,10 +622,12 @@ func pickBestFitGPUs(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList) gpu.
 		sort.Sort(sort.Reverse(gpu.ByFreeMemory(sgl)))
 
 		// First attempt to fit the model into a single GPU
-		for _, g := range sgl {
-			if ok, estimatedVRAM = llm.PredictServerFit([]gpu.GpuInfo{g}, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
-				slog.Debug("new model will fit in available VRAM in single GPU, loading", "model", req.model.ModelPath, "gpu", g.ID, "available", g.FreeMemory, "required", format.HumanBytes2(estimatedVRAM))
-				return []gpu.GpuInfo{g}
+		if !envconfig.SchedSpread {
+			for _, g := range sgl {
+				if ok, estimatedVRAM = llm.PredictServerFit([]gpu.GpuInfo{g}, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
+					slog.Debug("new model will fit in available VRAM in single GPU, loading", "model", req.model.ModelPath, "gpu", g.ID, "available", g.FreeMemory, "required", format.HumanBytes2(estimatedVRAM))
+					return []gpu.GpuInfo{g}
+				}
 			}
 		}
 
@@ -586,6 +652,10 @@ func (s *Scheduler) findRunnerToUnload() *runnerRef {
 		runnerList = append(runnerList, r)
 	}
 	s.loadedMu.Unlock()
+	if len(runnerList) == 0 {
+		slog.Debug("no loaded runner to unload")
+		return nil
+	}
 
 	// In the future we can enhance the algorithm to be smarter about picking the optimal runner to unload
 	// e.g., if we have multiple options, will one make room for the request?
@@ -616,3 +686,18 @@ func (s *Scheduler) unloadAllRunners() {
 		}
 	}
 }
+
+// If other runners are loaded, make sure the pending request will fit in system memory
+// If not, pick a runner to unload, else return nil and the request can be loaded
+func (s *Scheduler) maybeFindCPURunnerToUnload(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList) *runnerRef {
+	slog.Debug("evaluating if CPU model load will fit in available system memory")
+	estimate := llm.EstimateGPULayers(gpus, ggml, req.model.ProjectorPaths, req.opts)
+	if estimate.TotalSize <= gpus[0].FreeMemory {
+		slog.Debug("cpu inference mode, model fits in available system memory", "model", format.HumanBytes2(estimate.TotalSize), "available", format.HumanBytes2(gpus[0].FreeMemory))
+		return nil
+	}
+
+	// TODO - optimization: try to find CPU only runners first, or partial offloads with enough in system memory to make room
+
+	return s.findRunnerToUnload()
+}

+ 71 - 28
server/sched_test.go

@@ -60,7 +60,7 @@ func TestLoad(t *testing.T) {
 	err := <-req.errCh
 	require.Contains(t, err.Error(), "this model may be incompatible")
 
-	server := &mockLlm{estimatedVRAM: 10}
+	server := &mockLlm{estimatedVRAM: 10, estimatedVRAMByGPU: map[string]uint64{}}
 	s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) {
 		return server, nil
 	}
@@ -129,6 +129,7 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
 		"tokenizer.ggml.token_type":     []int32{0},
 	}, []llm.Tensor{
 		{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
+		{Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
 	})
 	require.NoError(t, err)
 
@@ -145,7 +146,7 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
 		successCh:       make(chan *runnerRef, 1),
 		errCh:           make(chan error, 1),
 	}
-	scenario.srv = &mockLlm{estimatedVRAM: estimatedVRAM}
+	scenario.srv = &mockLlm{estimatedVRAM: estimatedVRAM, estimatedVRAMByGPU: map[string]uint64{"": estimatedVRAM}}
 	return scenario
 }
 
@@ -155,7 +156,7 @@ func TestRequests(t *testing.T) {
 
 	// Same model, same request
 	scenario1a := newScenario(t, ctx, "ollama-model-1", 10)
-	scenario1a.req.sessionDuration = 0
+	scenario1a.req.sessionDuration = 5 * time.Millisecond
 	scenario1b := newScenario(t, ctx, "ollama-model-1", 11)
 	scenario1b.req.model = scenario1a.req.model
 	scenario1b.ggml = scenario1a.ggml
@@ -166,6 +167,7 @@ func TestRequests(t *testing.T) {
 	tmpModel := *scenario1a.req.model
 	scenario2a.req.model = &tmpModel
 	scenario2a.ggml = scenario1a.ggml
+	scenario2a.req.sessionDuration = 5 * time.Millisecond
 
 	// Multiple loaded models
 	scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte)
@@ -181,6 +183,12 @@ func TestRequests(t *testing.T) {
 		g.FreeMemory = 12 * format.GigaByte
 		return []gpu.GpuInfo{g}
 	}
+	s.getCpuFn = func() gpu.GpuInfoList {
+		g := gpu.GpuInfo{Library: "cpu"}
+		g.TotalMemory = 32 * format.GigaByte
+		g.FreeMemory = 26 * format.GigaByte
+		return []gpu.GpuInfo{g}
+	}
 	s.newServerFn = scenario1a.newServer
 	slog.Info("scenario1a")
 	s.pendingReqCh <- scenario1a.req
@@ -309,7 +317,6 @@ func TestGetRunner(t *testing.T) {
 	ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
 	defer done()
 
-	// Same model, same request
 	scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
 	scenario1a.req.sessionDuration = 0
 	scenario1b := newScenario(t, ctx, "ollama-model-1b", 10)
@@ -419,7 +426,7 @@ func TestUseLoadedRunner(t *testing.T) {
 		sessionDuration: 2,
 	}
 	finished := make(chan *LlmRequest)
-	llm1 := &mockLlm{}
+	llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
 	r1 := &runnerRef{llama: llm1, sessionDuration: 1}
 	req.useLoadedRunner(r1, finished)
 	require.Equal(t, uint(1), r1.refCount)
@@ -452,8 +459,8 @@ func TestUpdateFreeSpace(t *testing.T) {
 	gpus[0].FreeMemory = 900
 	gpus[1].TotalMemory = 2000
 	gpus[1].FreeMemory = 1900
-	llm1 := &mockLlm{estimatedVRAM: 100}
-	llm2 := &mockLlm{estimatedVRAM: 200}
+	llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{"1": 50, "2": 50}}
+	llm2 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{"1": 125, "2": 75}}
 	r1 := &runnerRef{llama: llm1, gpus: gpus}
 	r2 := &runnerRef{llama: llm2, gpus: gpus}
 
@@ -464,8 +471,42 @@ func TestUpdateFreeSpace(t *testing.T) {
 	s.loadedMu.Unlock()
 
 	s.updateFreeSpace(gpus)
-	require.Equal(t, uint64(850), gpus[0].FreeMemory)
-	require.Equal(t, uint64(1850), gpus[1].FreeMemory)
+	require.Equal(t, uint64(1000-50-125), gpus[0].FreeMemory)
+	require.Equal(t, uint64(2000-50-75), gpus[1].FreeMemory)
+}
+
+func TestFilterGPUsWithoutLoadingModels(t *testing.T) {
+	ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
+	defer done()
+	gpus := gpu.GpuInfoList{
+		{
+			Library: "cuda",
+			ID:      "0",
+		},
+		{
+			Library: "cuda",
+			ID:      "1",
+		},
+	}
+	r1 := &runnerRef{gpus: gpu.GpuInfoList{gpus[0]}, loading: true}
+
+	s := InitScheduler(ctx)
+	s.loadedMu.Lock()
+	s.loaded["a"] = r1
+	s.loadedMu.Unlock()
+
+	tmp := s.filterGPUsWithoutLoadingModels(gpus)
+	require.Len(t, tmp, 1)
+	require.Equal(t, "1", tmp[0].ID)
+
+	r1.gpus = gpu.GpuInfoList{gpus[1]}
+	tmp = s.filterGPUsWithoutLoadingModels(gpus)
+	require.Len(t, tmp, 1)
+	require.Equal(t, "0", tmp[0].ID)
+
+	r1.gpus = gpu.GpuInfoList{}
+	tmp = s.filterGPUsWithoutLoadingModels(gpus)
+	require.Len(t, tmp, 2)
 }
 
 func TestFindRunnerToUnload(t *testing.T) {
@@ -492,7 +533,7 @@ func TestNeedsReload(t *testing.T) {
 	ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
 	defer done()
 
-	llm := &mockLlm{}
+	llm := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
 	do := api.DefaultOptions()
 	runner := &runnerRef{
 		model:   &Model{AdapterPaths: []string{"adapter1"}, ProjectorPaths: []string{"projector1"}},
@@ -535,8 +576,8 @@ func TestUnloadAllRunners(t *testing.T) {
 	ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
 	defer done()
 
-	llm1 := &mockLlm{}
-	llm2 := &mockLlm{}
+	llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
+	llm2 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
 	s := InitScheduler(ctx)
 	s.unloadAllRunners()
 
@@ -554,7 +595,7 @@ func TestUnloadAllRunners(t *testing.T) {
 }
 
 func TestUnload(t *testing.T) {
-	llm1 := &mockLlm{}
+	llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
 	r1 := &runnerRef{llama: llm1}
 	r2 := &runnerRef{model: &Model{AdapterPaths: []string{"A"}}}
 	r1.unload()
@@ -564,19 +605,20 @@ func TestUnload(t *testing.T) {
 }
 
 type mockLlm struct {
-	pingResp          error
-	waitResp          error
-	completionResp    error
-	embeddingResp     []float64
-	embeddingRespErr  error
-	tokenizeResp      []int
-	tokenizeRespErr   error
-	detokenizeResp    string
-	detonekizeRespErr error
-	closeResp         error
-	closeCalled       bool
-	estimatedVRAM     uint64
-	estimatedTotal    uint64
+	pingResp           error
+	waitResp           error
+	completionResp     error
+	embeddingResp      []float64
+	embeddingRespErr   error
+	tokenizeResp       []int
+	tokenizeRespErr    error
+	detokenizeResp     string
+	detonekizeRespErr  error
+	closeResp          error
+	closeCalled        bool
+	estimatedVRAM      uint64
+	estimatedTotal     uint64
+	estimatedVRAMByGPU map[string]uint64
 }
 
 func (s *mockLlm) Ping(ctx context.Context) error             { return s.pingResp }
@@ -597,5 +639,6 @@ func (s *mockLlm) Close() error {
 	s.closeCalled = true
 	return s.closeResp
 }
-func (s *mockLlm) EstimatedVRAM() uint64  { return s.estimatedVRAM }
-func (s *mockLlm) EstimatedTotal() uint64 { return s.estimatedTotal }
+func (s *mockLlm) EstimatedVRAM() uint64                  { return s.estimatedVRAM }
+func (s *mockLlm) EstimatedTotal() uint64                 { return s.estimatedTotal }
+func (s *mockLlm) EstimatedVRAMByGPU(gpuid string) uint64 { return s.estimatedVRAMByGPU[gpuid] }