Merge pull request #4517 from dhiltgen/gpu_incremental

Enhanced GPU discovery and multi-gpu support with concurrency
Daniel Hiltgen 10 months ago
commit 45cacbaf05

+ 12 - 0
envconfig/config.go

@@ -53,6 +53,8 @@ var (
 	NumParallel int
 	// Set via OLLAMA_RUNNERS_DIR in the environment
 	RunnersDir string
+	// Set via OLLAMA_SCHED_SPREAD in the environment
+	SchedSpread bool
 	// Set via OLLAMA_TMPDIR in the environment
 	TmpDir string
 )
@@ -79,6 +81,7 @@ func AsMap() map[string]EnvVar {
 		"OLLAMA_NUM_PARALLEL":      {"OLLAMA_NUM_PARALLEL", NumParallel, "Maximum number of parallel requests (default 1)"},
 		"OLLAMA_NUM_PARALLEL":      {"OLLAMA_NUM_PARALLEL", NumParallel, "Maximum number of parallel requests (default 1)"},
 		"OLLAMA_ORIGINS":           {"OLLAMA_ORIGINS", AllowOrigins, "A comma separated list of allowed origins"},
 		"OLLAMA_ORIGINS":           {"OLLAMA_ORIGINS", AllowOrigins, "A comma separated list of allowed origins"},
 		"OLLAMA_RUNNERS_DIR":       {"OLLAMA_RUNNERS_DIR", RunnersDir, "Location for runners"},
 		"OLLAMA_RUNNERS_DIR":       {"OLLAMA_RUNNERS_DIR", RunnersDir, "Location for runners"},
+		"OLLAMA_SCHED_SPREAD":      {"OLLAMA_SCHED_SPREAD", SchedSpread, "Always schedule model across all GPUs"},
 		"OLLAMA_TMPDIR":            {"OLLAMA_TMPDIR", TmpDir, "Location for temporary files"},
 		"OLLAMA_TMPDIR":            {"OLLAMA_TMPDIR", TmpDir, "Location for temporary files"},
 	}
 	}
 }
 }
@@ -191,6 +194,15 @@ func LoadConfig() {
 		NoHistory = true
 	}
 
+	if spread := clean("OLLAMA_SCHED_SPREAD"); spread != "" {
+		s, err := strconv.ParseBool(spread)
+		if err == nil {
+			SchedSpread = s
+		} else {
+			SchedSpread = true
+		}
+	}
+
 	if noprune := clean("OLLAMA_NOPRUNE"); noprune != "" {
 		NoPrune = true
 	}
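
How the new OLLAMA_SCHED_SPREAD flag parses in practice: strconv.ParseBool accepts the usual boolean spellings, and any other non-empty value falls back to enabling spread scheduling. Below is a minimal standalone sketch mirroring the hunk above; the parseSpread helper is illustrative and not part of this PR.

package main

import (
	"fmt"
	"strconv"
)

// parseSpread mirrors the OLLAMA_SCHED_SPREAD handling shown in the diff above.
func parseSpread(raw string) bool {
	if raw == "" {
		return false // unset: default scheduling, no forced spread
	}
	if v, err := strconv.ParseBool(raw); err == nil {
		return v // "1", "t", "true", "0", "false", ...
	}
	return true // set but unparsable (e.g. "yes"): treated as enabled
}

func main() {
	for _, s := range []string{"", "0", "false", "1", "true", "yes"} {
		fmt.Printf("OLLAMA_SCHED_SPREAD=%q -> %v\n", s, parseSpread(s))
	}
}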

+ 131 - 75
gpu/amd_linux.go

@@ -25,7 +25,16 @@ const (
 
 	// Prefix with the node dir
 	GPUTotalMemoryFileGlob = "mem_banks/*/properties" // size_in_bytes line
-	GPUUsedMemoryFileGlob  = "mem_banks/*/used_memory"
+
+	// Direct Rendering Manager sysfs location
+	DRMDeviceDirGlob   = "/sys/class/drm/card*/device"
+	DRMTotalMemoryFile = "mem_info_vram_total"
+	DRMUsedMemoryFile  = "mem_info_vram_used"
+
+	// In hex; properties file is in decimal
+	DRMUniqueIDFile = "unique_id"
+	DRMVendorFile   = "vendor"
+	DRMDeviceFile   = "device"
 )
 
 var (
@@ -35,8 +44,8 @@ var (
 )
 
 // Gather GPU information from the amdgpu driver if any supported GPUs are detected
-func AMDGetGPUInfo() []GpuInfo {
-	resp := []GpuInfo{}
+func AMDGetGPUInfo() []RocmGPUInfo {
+	resp := []RocmGPUInfo{}
 	if !AMDDetected() {
 		return resp
 	}
@@ -90,7 +99,7 @@ func AMDGetGPUInfo() []GpuInfo {
 		scanner := bufio.NewScanner(fp)
 		isCPU := false
 		var major, minor, patch uint64
-		var vendor, device uint64
+		var vendor, device, uniqueID uint64
 		for scanner.Scan() {
 			line := strings.TrimSpace(scanner.Text())
 			// Note: we could also use "cpu_cores_count X" where X is greater than zero to detect CPUs
@@ -121,30 +130,43 @@ func AMDGetGPUInfo() []GpuInfo {
 			} else if strings.HasPrefix(line, "vendor_id") {
 			} else if strings.HasPrefix(line, "vendor_id") {
 				ver := strings.Fields(line)
 				ver := strings.Fields(line)
 				if len(ver) != 2 {
 				if len(ver) != 2 {
-					slog.Debug("malformed vendor_id", "vendor_id", line)
+					slog.Debug("malformed", "vendor_id", line)
 					continue
 					continue
 				}
 				}
-				vendor, err = strconv.ParseUint(ver[1], 10, 32)
+				vendor, err = strconv.ParseUint(ver[1], 10, 64)
 				if err != nil {
 				if err != nil {
-					slog.Debug("malformed vendor_id" + line)
+					slog.Debug("malformed", "vendor_id", line, "error", err)
 				}
 				}
 			} else if strings.HasPrefix(line, "device_id") {
 			} else if strings.HasPrefix(line, "device_id") {
 				ver := strings.Fields(line)
 				ver := strings.Fields(line)
 				if len(ver) != 2 {
 				if len(ver) != 2 {
-					slog.Debug("malformed device_id", "device_id", line)
+					slog.Debug("malformed", "device_id", line)
+					continue
+				}
+				device, err = strconv.ParseUint(ver[1], 10, 64)
+				if err != nil {
+					slog.Debug("malformed", "device_id", line, "error", err)
+				}
+			} else if strings.HasPrefix(line, "unique_id") {
+				ver := strings.Fields(line)
+				if len(ver) != 2 {
+					slog.Debug("malformed", "unique_id", line)
 					continue
 					continue
 				}
 				}
-				device, err = strconv.ParseUint(ver[1], 10, 32)
+				uniqueID, err = strconv.ParseUint(ver[1], 10, 64)
 				if err != nil {
 				if err != nil {
-					slog.Debug("malformed device_id" + line)
+					slog.Debug("malformed", "unique_id", line, "error", err)
 				}
 				}
 			}
 			}
-
 			// TODO - any other properties we want to extract and record?
 			// TODO - any other properties we want to extract and record?
 			// vendor_id + device_id -> pci lookup for "Name"
 			// vendor_id + device_id -> pci lookup for "Name"
 			// Other metrics that may help us understand relative performance between multiple GPUs
 			// Other metrics that may help us understand relative performance between multiple GPUs
 		}
 		}
 
 
+		// Note: while ./mem_banks/*/used_memory exists, it doesn't appear to take other VRAM consumers
+		// into consideration, so we instead map the device over to the DRM driver sysfs nodes which
+		// do reliably report VRAM usage.
+
 		if isCPU {
 		if isCPU {
 			cpuCount++
 			cpuCount++
 			continue
 			continue
@@ -156,7 +178,7 @@ func AMDGetGPUInfo() []GpuInfo {
 		// Shouldn't happen, but just in case...
 		if gpuID < 0 {
 			slog.Error("unexpected amdgpu sysfs data resulted in negative GPU ID, please set OLLAMA_DEBUG=1 and report an issue")
-			return []GpuInfo{}
+			return nil
 		}
 
 		if int(major) < RocmComputeMin {
@@ -167,65 +189,68 @@ func AMDGetGPUInfo() []GpuInfo {
 		// Look up the memory for the current node
 		totalMemory := uint64(0)
 		usedMemory := uint64(0)
-		propGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(nodeID), GPUTotalMemoryFileGlob)
-		propFiles, err := filepath.Glob(propGlob)
-		if err != nil {
-			slog.Warn("error looking up total GPU memory", "glob", propGlob, "error", err)
+		var usedFile string
+		mapping := []struct {
+			id       uint64
+			filename string
+		}{
+			{vendor, DRMVendorFile},
+			{device, DRMDeviceFile},
+			{uniqueID, DRMUniqueIDFile}, // Not all devices will report this
 		}
-		// 1 or more memory banks - sum the values of all of them
-		for _, propFile := range propFiles {
-			fp, err := os.Open(propFile)
-			if err != nil {
-				slog.Warn("failed to open sysfs node", "file", propFile, "erroir", err)
-				continue
-			}
-			defer fp.Close()
-			scanner := bufio.NewScanner(fp)
-			for scanner.Scan() {
-				line := strings.TrimSpace(scanner.Text())
-				if strings.HasPrefix(line, "size_in_bytes") {
-					ver := strings.Fields(line)
-					if len(ver) != 2 {
-						slog.Warn("malformed " + line)
-						continue
-					}
-					bankSizeInBytes, err := strconv.ParseUint(ver[1], 10, 64)
-					if err != nil {
-						slog.Warn("malformed int " + line)
-						continue
-					}
-					totalMemory += bankSizeInBytes
+		slog.Debug("mapping amdgpu to drm sysfs nodes", "amdgpu", match, "vendor", vendor, "device", device, "unique_id", uniqueID)
+		// Map over to DRM location to find the total/free memory
+		drmMatches, _ := filepath.Glob(DRMDeviceDirGlob)
+		for _, devDir := range drmMatches {
+			matched := true
+			for _, m := range mapping {
+				if m.id == 0 {
+					// Null ID means it didn't populate, so we can't use it to match
+					continue
+				}
+				filename := filepath.Join(devDir, m.filename)
+				buf, err := os.ReadFile(filename)
+				if err != nil {
+					slog.Debug("failed to read sysfs node", "file", filename, "error", err)
+					matched = false
+					break
+				}
+				// values here are in hex, strip off the lead 0x and parse so we can compare the numeric (decimal) values in amdgpu
+				cmp, err := strconv.ParseUint(strings.TrimPrefix(strings.TrimSpace(string(buf)), "0x"), 16, 64)
+				if err != nil {
+					slog.Debug("failed to parse sysfs node", "file", filename, "error", err)
+					matched = false
+					break
+				}
+				if cmp != m.id {
+					matched = false
+					break
 				}
 			}
-		}
-		if totalMemory == 0 {
-			slog.Warn("amdgpu reports zero total memory", "gpu", gpuID)
-			continue
-		}
-		usedGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(nodeID), GPUUsedMemoryFileGlob)
-		usedFiles, err := filepath.Glob(usedGlob)
-		if err != nil {
-			slog.Warn("error looking up used GPU memory", "glob", usedGlob, "error", err)
-			continue
-		}
-		for _, usedFile := range usedFiles {
-			fp, err := os.Open(usedFile)
-			if err != nil {
-				slog.Warn("failed to open sysfs node", "file", usedFile, "error", err)
+			if !matched {
 				continue
 			}
-			defer fp.Close()
-			data, err := io.ReadAll(fp)
+
+			// Found the matching DRM directory
+			slog.Debug("matched", "amdgpu", match, "drm", devDir)
+			totalFile := filepath.Join(devDir, DRMTotalMemoryFile)
+			buf, err := os.ReadFile(totalFile)
 			if err != nil {
-				slog.Warn("failed to read sysfs node", "file", usedFile, "error", err)
-				continue
+				slog.Debug("failed to read sysfs node", "file", totalFile, "error", err)
+				break
 			}
-			used, err := strconv.ParseUint(strings.TrimSpace(string(data)), 10, 64)
+			totalMemory, err = strconv.ParseUint(strings.TrimSpace(string(buf)), 10, 64)
 			if err != nil {
-				slog.Warn("malformed used memory", "data", string(data), "error", err)
-				continue
+				slog.Debug("failed to parse sysfs node", "file", totalFile, "error", err)
+				break
+			}
+
+			usedFile = filepath.Join(devDir, DRMUsedMemoryFile)
+			usedMemory, err = getFreeMemory(usedFile)
+			if err != nil {
+				slog.Debug("failed to update used memory", "error", err)
 			}
-			usedMemory += used
+			break
 		}
 
 		// iGPU detection, remove this check once we can support an iGPU variant of the rocm library
@@ -241,18 +266,21 @@ func AMDGetGPUInfo() []GpuInfo {
 
 		slog.Debug("amdgpu memory", "gpu", gpuID, "total", format.HumanBytes2(totalMemory))
 		slog.Debug("amdgpu memory", "gpu", gpuID, "available", format.HumanBytes2(totalMemory-usedMemory))
-		gpuInfo := GpuInfo{
-			Library: "rocm",
-			memInfo: memInfo{
-				TotalMemory: totalMemory,
-				FreeMemory:  (totalMemory - usedMemory),
+		gpuInfo := RocmGPUInfo{
+			GpuInfo: GpuInfo{
+				Library: "rocm",
+				memInfo: memInfo{
+					TotalMemory: totalMemory,
+					FreeMemory:  (totalMemory - usedMemory),
+				},
+				ID:            strconv.Itoa(gpuID),
+				Name:          name,
+				Compute:       fmt.Sprintf("gfx%d%x%x", major, minor, patch),
+				MinimumMemory: rocmMinimumMemory,
+				DriverMajor:   driverMajor,
+				DriverMinor:   driverMinor,
 			},
-			ID:            fmt.Sprintf("%d", gpuID),
-			Name:          name,
-			Compute:       fmt.Sprintf("gfx%d%x%x", major, minor, patch),
-			MinimumMemory: rocmMinimumMemory,
-			DriverMajor:   driverMajor,
-			DriverMinor:   driverMinor,
+			usedFilepath: usedFile,
 		}
 
 		// If the user wants to filter to a subset of devices, filter out if we aren't a match
@@ -276,7 +304,7 @@ func AMDGetGPUInfo() []GpuInfo {
 			libDir, err = AMDValidateLibDir()
 			if err != nil {
 				slog.Warn("unable to verify rocm library, will use cpu", "error", err)
-				return []GpuInfo{}
+				return nil
 			}
 		}
 		gpuInfo.DependencyPath = libDir
@@ -287,7 +315,7 @@ func AMDGetGPUInfo() []GpuInfo {
 				supported, err = GetSupportedGFX(libDir)
 				if err != nil {
 					slog.Warn("failed to lookup supported GFX types, falling back to CPU mode", "error", err)
-					return []GpuInfo{}
+					return nil
 				}
 				slog.Debug("rocm supported GPUs", "types", supported)
 			}
@@ -378,3 +406,31 @@ func AMDDriverVersion() (driverMajor, driverMinor int, err error) {
 	}
 	return driverMajor, driverMinor, nil
 }
+
+func (gpus RocmGPUInfoList) RefreshFreeMemory() error {
+	if len(gpus) == 0 {
+		return nil
+	}
+	for i := range gpus {
+		usedMemory, err := getFreeMemory(gpus[i].usedFilepath)
+		if err != nil {
+			return err
+		}
+		slog.Debug("updating rocm free memory", "gpu", gpus[i].ID, "name", gpus[i].Name, "before", format.HumanBytes2(gpus[i].FreeMemory), "now", format.HumanBytes2(gpus[i].TotalMemory-usedMemory))
+		gpus[i].FreeMemory = gpus[i].TotalMemory - usedMemory
+	}
+	return nil
+}
+
+func getFreeMemory(usedFile string) (uint64, error) {
+	buf, err := os.ReadFile(usedFile)
+	if err != nil {
+		return 0, fmt.Errorf("failed to read sysfs node %s %w", usedFile, err)
+	}
+	usedMemory, err := strconv.ParseUint(strings.TrimSpace(string(buf)), 10, 64)
+	if err != nil {
+		slog.Debug("failed to parse sysfs node", "file", usedFile, "error", err)
+		return 0, fmt.Errorf("failed to parse sysfs node %s %w", usedFile, err)
+	}
+	return usedMemory, nil
+}
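
For reference, a self-contained sketch (not part of the PR) of the ID normalization that drives the KFD-to-DRM mapping above: the KFD properties file reports vendor_id/device_id/unique_id in decimal, while the DRM sysfs nodes report 0x-prefixed hex, so both sides are parsed into uint64 before comparison. The sample values below are illustrative.

package main

import (
	"fmt"
	"strconv"
	"strings"
)

func main() {
	kfdValue := "4098"     // e.g. "vendor_id 4098" from a mem_banks/*/properties file (decimal)
	drmValue := "0x1002\n" // e.g. contents of /sys/class/drm/card*/device/vendor (hex)

	id, _ := strconv.ParseUint(kfdValue, 10, 64)
	cmp, _ := strconv.ParseUint(strings.TrimPrefix(strings.TrimSpace(drmValue), "0x"), 16, 64)

	fmt.Println(id == cmp) // true: 4098 == 0x1002, so this DRM card matches the KFD node
}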

+ 47 - 16
gpu/amd_windows.go

@@ -7,6 +7,7 @@ import (
 	"os"
 	"os"
 	"path/filepath"
 	"path/filepath"
 	"slices"
 	"slices"
+	"strconv"
 	"strings"
 	"strings"
 
 
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/format"
@@ -24,8 +25,8 @@ var (
 	RocmStandardLocations = []string{"C:\\Program Files\\AMD\\ROCm\\5.7\\bin"} // TODO glob?
 )
 
-func AMDGetGPUInfo() []GpuInfo {
-	resp := []GpuInfo{}
+func AMDGetGPUInfo() []RocmGPUInfo {
+	resp := []RocmGPUInfo{}
 	hl, err := NewHipLib()
 	if err != nil {
 		slog.Debug(err.Error())
@@ -117,21 +118,24 @@ func AMDGetGPUInfo() []GpuInfo {
 		// v5.7 only reports VRAM used by this process, so it's completely wrong and unusable
 		slog.Debug("amdgpu memory", "gpu", i, "total", format.HumanBytes2(totalMemory))
 		slog.Debug("amdgpu memory", "gpu", i, "available", format.HumanBytes2(freeMemory))
-		gpuInfo := GpuInfo{
-			Library: "rocm",
-			memInfo: memInfo{
-				TotalMemory: totalMemory,
-				FreeMemory:  freeMemory,
+		gpuInfo := RocmGPUInfo{
+			GpuInfo: GpuInfo{
+				Library: "rocm",
+				memInfo: memInfo{
+					TotalMemory: totalMemory,
+					FreeMemory:  freeMemory,
+				},
+				ID:             strconv.Itoa(i), // TODO this is probably wrong if we specify visible devices
+				DependencyPath: libDir,
+				MinimumMemory:  rocmMinimumMemory,
+				Name:           name,
+				Compute:        gfx,
+
+				// TODO - this information isn't accurate on windows, so don't report it until we find the right way to retrieve
+				// DriverMajor:    driverMajor,
+				// DriverMinor:    driverMinor,
 			},
-			ID:             fmt.Sprintf("%d", i), // TODO this is probably wrong if we specify visible devices
-			DependencyPath: libDir,
-			MinimumMemory:  rocmMinimumMemory,
-			Name:           name,
-			Compute:        gfx,
-
-			// TODO - this information isn't accurate on windows, so don't report it until we find the right way to retrieve
-			// DriverMajor:    driverMajor,
-			// DriverMinor:    driverMinor,
+			index: i,
 		}
 
 		resp = append(resp, gpuInfo)
@@ -159,3 +163,30 @@ func AMDValidateLibDir() (string, error) {
 	slog.Warn("amdgpu detected, but no compatible rocm library found.  Please install ROCm")
 	slog.Warn("amdgpu detected, but no compatible rocm library found.  Please install ROCm")
 	return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
 	return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
 }
 }
+
+func (gpus RocmGPUInfoList) RefreshFreeMemory() error {
+	if len(gpus) == 0 {
+		return nil
+	}
+	hl, err := NewHipLib()
+	if err != nil {
+		slog.Debug(err.Error())
+		return nil
+	}
+	defer hl.Release()
+
+	for i := range gpus {
+		err := hl.HipSetDevice(gpus[i].index)
+		if err != nil {
+			return err
+		}
+		freeMemory, _, err := hl.HipMemGetInfo()
+		if err != nil {
+			slog.Warn("get mem info", "id", i, "error", err)
+			continue
+		}
+		slog.Debug("updating rocm free memory", "gpu", gpus[i].ID, "name", gpus[i].Name, "before", format.HumanBytes2(gpus[i].FreeMemory), "now", format.HumanBytes2(freeMemory))
+		gpus[i].FreeMemory = freeMemory
+	}
+	return nil
+}
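
Design note: on Linux the refresh path re-reads the matched DRM node's mem_info_vram_used file, while on Windows it goes back through the HIP runtime (HipSetDevice plus HipMemGetInfo), but both platforms expose the same RefreshFreeMemory entry point. A hedged sketch of how a caller inside the gpu package might use it (the helper below is hypothetical; the real call site is in gpu.go further down):

// Hypothetical helper; assumes the gpu package context and a log/slog import.
func totalRocmFree(gpus RocmGPUInfoList) uint64 {
	if err := gpus.RefreshFreeMemory(); err != nil {
		slog.Debug("problem refreshing ROCm free memory", "error", err)
	}
	var free uint64
	for _, g := range gpus {
		free += g.FreeMemory // updated in place by RefreshFreeMemory
	}
	return free
}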

+ 4 - 9
gpu/cpu_common.go

@@ -1,21 +1,16 @@
 package gpu
 
 import (
-	"log/slog"
-
 	"golang.org/x/sys/cpu"
 )
 
-func GetCPUVariant() string {
+func GetCPUCapability() CPUCapability {
 	if cpu.X86.HasAVX2 {
-		slog.Debug("CPU has AVX2")
-		return "avx2"
+		return CPUCapabilityAVX2
 	}
 	if cpu.X86.HasAVX {
-		slog.Debug("CPU has AVX")
-		return "avx"
+		return CPUCapabilityAVX
 	}
-	slog.Debug("CPU does not have vector extensions")
 	// else LCD
-	return ""
+	return CPUCapabilityNone
 }
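
GetCPUCapability now returns an ordered capability value instead of a string. The CPUCapability type itself is defined outside this excerpt (likely in gpu/types.go); for the < comparison used in gpu.go below to work, it has to look roughly like the following hypothetical sketch, not the literal definitions from the PR:

// Hypothetical sketch only.
type CPUCapability uint32

const (
	CPUCapabilityNone CPUCapability = iota
	CPUCapabilityAVX
	CPUCapabilityAVX2
)

// GPURunnerCPUCapability would mark the minimum level required to pair a GPU
// runner with this CPU; per the warning in GetGPUInfo, that is AVX on the
// current x86 builds.
var GPURunnerCPUCapability = CPUCapabilityAVX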

+ 334 - 160
gpu/gpu.go

@@ -24,19 +24,37 @@ import (
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/format"
 )
 )
 
 
-type handles struct {
+type cudaHandles struct {
 	deviceCount int
 	deviceCount int
 	cudart      *C.cudart_handle_t
 	cudart      *C.cudart_handle_t
 	nvcuda      *C.nvcuda_handle_t
 	nvcuda      *C.nvcuda_handle_t
+	nvml        *C.nvml_handle_t
+}
+
+type oneapiHandles struct {
 	oneapi      *C.oneapi_handle_t
 	oneapi      *C.oneapi_handle_t
+	deviceCount int
 }
 }
 
 
 const (
 const (
 	cudaMinimumMemory = 457 * format.MebiByte
 	cudaMinimumMemory = 457 * format.MebiByte
 	rocmMinimumMemory = 457 * format.MebiByte
 	rocmMinimumMemory = 457 * format.MebiByte
+	// TODO OneAPI minimum memory
 )
 )
 
 
-var gpuMutex sync.Mutex
+var (
+	gpuMutex      sync.Mutex
+	bootstrapped  bool
+	cpuCapability CPUCapability
+	cpus          []CPUInfo
+	cudaGPUs      []CudaGPUInfo
+	nvcudaLibPath string
+	cudartLibPath string
+	oneapiLibPath string
+	nvmlLibPath   string
+	rocmGPUs      []RocmGPUInfo
+	oneapiGPUs    []OneapiGPUInfo
+)
 
 
 // With our current CUDA compile flags, older than 5.0 will not work properly
 // With our current CUDA compile flags, older than 5.0 will not work properly
 var CudaComputeMin = [2]C.int{5, 0}
 var CudaComputeMin = [2]C.int{5, 0}
@@ -46,113 +64,113 @@ var RocmComputeMin = 9
 // TODO find a better way to detect iGPU instead of minimum memory
 const IGPUMemLimit = 1 * format.GibiByte // 512G is what they typically report, so anything less than 1G must be iGPU
 
-var CudartLinuxGlobs = []string{
-	"/usr/local/cuda/lib64/libcudart.so*",
-	"/usr/lib/x86_64-linux-gnu/nvidia/current/libcudart.so*",
-	"/usr/lib/x86_64-linux-gnu/libcudart.so*",
-	"/usr/lib/wsl/lib/libcudart.so*",
-	"/usr/lib/wsl/drivers/*/libcudart.so*",
-	"/opt/cuda/lib64/libcudart.so*",
-	"/usr/local/cuda*/targets/aarch64-linux/lib/libcudart.so*",
-	"/usr/lib/aarch64-linux-gnu/nvidia/current/libcudart.so*",
-	"/usr/lib/aarch64-linux-gnu/libcudart.so*",
-	"/usr/local/cuda/lib*/libcudart.so*",
-	"/usr/lib*/libcudart.so*",
-	"/usr/local/lib*/libcudart.so*",
-}
-
-var CudartWindowsGlobs = []string{
-	"c:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v*\\bin\\cudart64_*.dll",
-}
-
-var NvcudaLinuxGlobs = []string{
-	"/usr/local/cuda*/targets/*/lib/libcuda.so*",
-	"/usr/lib/*-linux-gnu/nvidia/current/libcuda.so*",
-	"/usr/lib/*-linux-gnu/libcuda.so*",
-	"/usr/lib/wsl/lib/libcuda.so*",
-	"/usr/lib/wsl/drivers/*/libcuda.so*",
-	"/opt/cuda/lib*/libcuda.so*",
-	"/usr/local/cuda/lib*/libcuda.so*",
-	"/usr/lib*/libcuda.so*",
-	"/usr/local/lib*/libcuda.so*",
-}
-
-var NvcudaWindowsGlobs = []string{
-	"c:\\windows\\system*\\nvcuda.dll",
-}
-
-var OneapiWindowsGlobs = []string{
-	"c:\\Windows\\System32\\DriverStore\\FileRepository\\*\\ze_intel_gpu64.dll",
-}
-
-var OneapiLinuxGlobs = []string{
-	"/usr/lib/x86_64-linux-gnu/libze_intel_gpu.so*",
-	"/usr/lib*/libze_intel_gpu.so*",
-}
-
 // Jetson devices have JETSON_JETPACK="x.y.z" factory set to the Jetpack version installed.
 // Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices.
 var CudaTegra string = os.Getenv("JETSON_JETPACK")
 
 // Note: gpuMutex must already be held
-func initGPUHandles() *handles {
+func initCudaHandles() *cudaHandles {
 
 	// TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing
 
-	gpuHandles := &handles{}
-	var cudartMgmtName string
+	cHandles := &cudaHandles{}
+	// Short Circuit if we already know which library to use
+	if nvmlLibPath != "" {
+		cHandles.nvml, _ = LoadNVMLMgmt([]string{nvmlLibPath})
+		return cHandles
+	}
+	if nvcudaLibPath != "" {
+		cHandles.deviceCount, cHandles.nvcuda, _ = LoadNVCUDAMgmt([]string{nvcudaLibPath})
+		return cHandles
+	}
+	if cudartLibPath != "" {
+		cHandles.deviceCount, cHandles.cudart, _ = LoadCUDARTMgmt([]string{cudartLibPath})
+		return cHandles
+	}
+
+	slog.Debug("searching for GPU discovery libraries for NVIDIA")
 	var cudartMgmtPatterns []string
-	var nvcudaMgmtName string
-	var nvcudaMgmtPatterns []string
 
-	tmpDir, _ := PayloadsDir()
-	switch runtime.GOOS {
-	case "windows":
-		cudartMgmtName = "cudart64_*.dll"
+	// Aligned with driver, we can't carry as payloads
+	nvcudaMgmtPatterns := NvcudaGlobs
+
+	if runtime.GOOS == "windows" {
 		localAppData := os.Getenv("LOCALAPPDATA")
-		cudartMgmtPatterns = []string{filepath.Join(localAppData, "Programs", "Ollama", cudartMgmtName)}
-		cudartMgmtPatterns = append(cudartMgmtPatterns, CudartWindowsGlobs...)
-		// Aligned with driver, we can't carry as payloads
-		nvcudaMgmtName = "nvcuda.dll"
-		nvcudaMgmtPatterns = NvcudaWindowsGlobs
-	case "linux":
-		cudartMgmtName = "libcudart.so*"
-		if tmpDir != "" {
-			// TODO - add "payloads" for subprocess
-			cudartMgmtPatterns = []string{filepath.Join(tmpDir, "cuda*", cudartMgmtName)}
+		cudartMgmtPatterns = []string{filepath.Join(localAppData, "Programs", "Ollama", CudartMgmtName)}
+	}
+	tmpDir, _ := PayloadsDir()
+	if tmpDir != "" {
+		// TODO - add "payloads" for subprocess
+		cudartMgmtPatterns = []string{filepath.Join(tmpDir, "cuda*", CudartMgmtName)}
+	}
+	cudartMgmtPatterns = append(cudartMgmtPatterns, CudartGlobs...)
+
+	if len(NvmlGlobs) > 0 {
+		nvmlLibPaths := FindGPULibs(NvmlMgmtName, NvmlGlobs)
+		if len(nvmlLibPaths) > 0 {
+			nvml, libPath := LoadNVMLMgmt(nvmlLibPaths)
+			if nvml != nil {
+				slog.Debug("nvidia-ml loaded", "library", libPath)
+				cHandles.nvml = nvml
+				nvmlLibPath = libPath
+			}
 		}
 		}
-		// Aligned with driver, we can't carry as payloads
-		nvcudaMgmtName = "libcuda.so*"
-		nvcudaMgmtPatterns = NvcudaLinuxGlobs
-	default:
-		return gpuHandles
 	}
 	}
 
-	nvcudaLibPaths := FindGPULibs(nvcudaMgmtName, nvcudaMgmtPatterns)
+	nvcudaLibPaths := FindGPULibs(NvcudaMgmtName, nvcudaMgmtPatterns)
 	if len(nvcudaLibPaths) > 0 {
 	if len(nvcudaLibPaths) > 0 {
 		deviceCount, nvcuda, libPath := LoadNVCUDAMgmt(nvcudaLibPaths)
 		if nvcuda != nil {
 			slog.Debug("detected GPUs", "count", deviceCount, "library", libPath)
-			gpuHandles.deviceCount = deviceCount
-			return gpuHandles
+			cHandles.nvcuda = nvcuda
+			cHandles.deviceCount = deviceCount
+			nvcudaLibPath = libPath
+			return cHandles
 		}
 		}
 	}
 
+	cudartLibPaths := FindGPULibs(CudartMgmtName, cudartMgmtPatterns)
 	if len(cudartLibPaths) > 0 {
 	if len(cudartLibPaths) > 0 {
 		deviceCount, cudart, libPath := LoadCUDARTMgmt(cudartLibPaths)
 		if cudart != nil {
 			slog.Debug("detected GPUs", "library", libPath, "count", deviceCount)
-			gpuHandles.deviceCount = deviceCount
-			return gpuHandles
+			cHandles.cudart = cudart
+			cHandles.deviceCount = deviceCount
+			cudartLibPath = libPath
+			return cHandles
 		}
 		}
 	}
 
+	return cHandles
+}
+
+// Note: gpuMutex must already be held
+func initOneAPIHandles() *oneapiHandles {
+	oHandles := &oneapiHandles{}
+
+	// Short Circuit if we already know which library to use
+	if oneapiLibPath != "" {
+		oHandles.deviceCount, oHandles.oneapi, _ = LoadOneapiMgmt([]string{oneapiLibPath})
+		return oHandles
+	}
+
+	oneapiLibPaths := FindGPULibs(OneapiMgmtName, OneapiGlobs)
+	if len(oneapiLibPaths) > 0 {
+		oHandles.deviceCount, oHandles.oneapi, oneapiLibPath = LoadOneapiMgmt(oneapiLibPaths)
+	}
+
+	return oHandles
+}
+
+func GetCPUInfo() GpuInfoList {
+	gpuMutex.Lock()
+	if !bootstrapped {
+		gpuMutex.Unlock()
+		GetGPUInfo()
+	} else {
+		gpuMutex.Unlock()
+	}
+	return GpuInfoList{cpus[0].GpuInfo}
 }
 }
 
 func GetGPUInfo() GpuInfoList {
 	// GPUs so we can report warnings if we see Nvidia/AMD but fail to load the libraries
 	// GPUs so we can report warnings if we see Nvidia/AMD but fail to load the libraries
 	gpuMutex.Lock()
 	defer gpuMutex.Unlock()
-	gpuHandles := initGPUHandles()
+	needRefresh := true
+	var cHandles *cudaHandles
+	var oHandles *oneapiHandles
 	defer func() {
 	defer func() {
-			C.cudart_release(*gpuHandles.cudart)
+		if cHandles != nil {
+			if cHandles.cudart != nil {
+				C.cudart_release(*cHandles.cudart)
+			}
+			if cHandles.nvcuda != nil {
+				C.nvcuda_release(*cHandles.nvcuda)
+			}
+			if cHandles.nvml != nil {
+				C.nvml_release(*cHandles.nvml)
+			}
 		}
 		}
-			C.nvcuda_release(*gpuHandles.nvcuda)
+		if oHandles != nil {
+			if oHandles.oneapi != nil {
+				// TODO - is this needed?
+				C.oneapi_release(*oHandles.oneapi)
+			}
 		}
 		}
 	}()
 
-	cpuVariant := GetCPUVariant()
-	if cpuVariant == "" && runtime.GOARCH == "amd64" {
-		slog.Warn("CPU does not have AVX or AVX2, disabling GPU support.")
-	}
+	if !bootstrapped {
+		slog.Debug("Detecting GPUs")
+		needRefresh = false
+		cpuCapability = GetCPUCapability()
+		var memInfo C.mem_info_t
 
 
-	depPath := ""
-	if runtime.GOOS == "windows" && envconfig.RunnersDir != "" {
-		depPath = filepath.Dir(envconfig.RunnersDir)
-	}
+		mem, err := GetCPUMem()
+		if err != nil {
+			slog.Warn("error looking up system memory", "error", err)
+		}
+		cpus = []CPUInfo{CPUInfo{
+			GpuInfo: GpuInfo{
+				memInfo: mem,
+				Library: "cpu",
+				Variant: cpuCapability,
+				ID:      "0",
+			},
+		}}
+
+		// Fallback to CPU mode if we're lacking required vector extensions on x86
+		if cpuCapability < GPURunnerCPUCapability && runtime.GOARCH == "amd64" {
+			slog.Warn("CPU does not have minimum vector extensions, GPU inference disabled", "required", GPURunnerCPUCapability, "detected", cpuCapability)
+			bootstrapped = true
+			// No need to do any GPU discovery, since we can't run on them
+			return GpuInfoList{cpus[0].GpuInfo}
+		}
 
 
-	resp := []GpuInfo{}
+		// On windows we bundle the nvidia library one level above the runner dir
+		depPath := ""
+		if runtime.GOOS == "windows" && envconfig.RunnersDir != "" {
+			depPath = filepath.Dir(envconfig.RunnersDir)
+		}
 
 
-	for i := range gpuHandles.deviceCount {
-		// TODO once we support CPU compilation variants of GPU libraries refine this...
-		if cpuVariant == "" && runtime.GOARCH == "amd64" {
-			continue
+		// Load ALL libraries
+		cHandles = initCudaHandles()
+
+		// NVIDIA
+		for i := range cHandles.deviceCount {
+			if cHandles.cudart != nil || cHandles.nvcuda != nil {
+				gpuInfo := CudaGPUInfo{
+					GpuInfo: GpuInfo{
+						Library: "cuda",
+					},
+					index: i,
+				}
+				var driverMajor int
+				var driverMinor int
+				if cHandles.cudart != nil {
+					C.cudart_bootstrap(*cHandles.cudart, C.int(i), &memInfo)
+				} else {
+					C.nvcuda_bootstrap(*cHandles.nvcuda, C.int(i), &memInfo)
+					driverMajor = int(cHandles.nvcuda.driver_major)
+					driverMinor = int(cHandles.nvcuda.driver_minor)
+				}
+				if memInfo.err != nil {
+					slog.Info("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
+					C.free(unsafe.Pointer(memInfo.err))
+					continue
+				}
+				if memInfo.major < CudaComputeMin[0] || (memInfo.major == CudaComputeMin[0] && memInfo.minor < CudaComputeMin[1]) {
+					slog.Info(fmt.Sprintf("[%d] CUDA GPU is too old. Compute Capability detected: %d.%d", i, memInfo.major, memInfo.minor))
+					continue
+				}
+				gpuInfo.TotalMemory = uint64(memInfo.total)
+				gpuInfo.FreeMemory = uint64(memInfo.free)
+				gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
+				gpuInfo.Compute = fmt.Sprintf("%d.%d", memInfo.major, memInfo.minor)
+				gpuInfo.MinimumMemory = cudaMinimumMemory
+				gpuInfo.DependencyPath = depPath
+				gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
+				gpuInfo.DriverMajor = driverMajor
+				gpuInfo.DriverMinor = driverMinor
+
+				// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
+				cudaGPUs = append(cudaGPUs, gpuInfo)
+			}
 		}
 		}
-			gpuInfo := GpuInfo{
-				Library: "cuda",
+
+		// Intel
+		oHandles = initOneAPIHandles()
+		for d := 0; oHandles.oneapi != nil && d < int(oHandles.oneapi.num_drivers); d++ {
+			if oHandles.oneapi == nil {
+				// shouldn't happen
+				slog.Warn("nil oneapi handle with driver count", "count", int(oHandles.oneapi.num_drivers))
+				continue
+			}
+			devCount := C.oneapi_get_device_count(*oHandles.oneapi, C.int(d))
+			for i := range devCount {
+				gpuInfo := OneapiGPUInfo{
+					GpuInfo: GpuInfo{
+						Library: "oneapi",
+					},
+					driverIndex: d,
+					gpuIndex:    int(i),
+				}
+				// TODO - split bootstrapping from updating free memory
+				C.oneapi_check_vram(*oHandles.oneapi, C.int(d), i, &memInfo)
+				// TODO - convert this to MinimumMemory based on testing...
+				var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend.
+				memInfo.free = C.uint64_t(totalFreeMem)
+				gpuInfo.TotalMemory = uint64(memInfo.total)
+				gpuInfo.FreeMemory = uint64(memInfo.free)
+				gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
+				gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
+				// TODO dependency path?
+				oneapiGPUs = append(oneapiGPUs, gpuInfo)
 			}
 			}
-			var driverMinor int
-			if gpuHandles.cudart != nil {
-				C.cudart_check_vram(*gpuHandles.cudart, C.int(i), &memInfo)
+		}
+
+		rocmGPUs = AMDGetGPUInfo()
+		bootstrapped = true
+	}
+
+	// For detected GPUs, load library if not loaded
+
+	// Refresh free memory usage
+	if needRefresh {
+		mem, err := GetCPUMem()
+		if err != nil {
+			slog.Warn("error looking up system memory", "error", err)
+		} else {
+			slog.Debug("updating system memory data",
+				slog.Group(
+					"before",
+					"total", format.HumanBytes2(cpus[0].TotalMemory),
+					"free", format.HumanBytes2(cpus[0].FreeMemory),
+				),
+				slog.Group(
+					"now",
+					"total", format.HumanBytes2(mem.TotalMemory),
+					"free", format.HumanBytes2(mem.FreeMemory),
+				),
+			)
+			cpus[0].FreeMemory = mem.FreeMemory
+		}
+
+		var memInfo C.mem_info_t
+		if cHandles == nil && len(cudaGPUs) > 0 {
+			cHandles = initCudaHandles()
+		}
+		for i, gpu := range cudaGPUs {
+			if cHandles.nvml != nil {
+				C.nvml_get_free(*cHandles.nvml, C.int(gpu.index), &memInfo.free, &memInfo.total, &memInfo.used)
+			} else if cHandles.cudart != nil {
+				C.cudart_bootstrap(*cHandles.cudart, C.int(gpu.index), &memInfo)
+			} else if cHandles.nvcuda != nil {
+				C.nvcuda_get_free(*cHandles.nvcuda, C.int(gpu.index), &memInfo.free, &memInfo.total)
+				memInfo.used = memInfo.total - memInfo.free
 			} else {
 			} else {
-				driverMajor = int(gpuHandles.nvcuda.driver_major)
-				driverMinor = int(gpuHandles.nvcuda.driver_minor)
+				// shouldn't happen
+				slog.Warn("no valid cuda library loaded to refresh vram usage")
+				break
 			}
 			}
 			if memInfo.err != nil {
+				slog.Warn("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
 				C.free(unsafe.Pointer(memInfo.err))
 				C.free(unsafe.Pointer(memInfo.err))
 				continue
 			}
-				slog.Info(fmt.Sprintf("[%d] CUDA GPU is too old. Compute Capability detected: %d.%d", i, memInfo.major, memInfo.minor))
+			if memInfo.free == 0 {
+				slog.Warn("error looking up nvidia GPU memory")
 				continue
 				continue
 			}
-			gpuInfo.FreeMemory = uint64(memInfo.free)
-			gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
-			gpuInfo.Compute = fmt.Sprintf("%d.%d", memInfo.major, memInfo.minor)
-			gpuInfo.MinimumMemory = cudaMinimumMemory
-			gpuInfo.DependencyPath = depPath
-			gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
-			gpuInfo.DriverMajor = driverMajor
-			gpuInfo.DriverMinor = driverMinor
-
-			// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
-			resp = append(resp, gpuInfo)
+			slog.Debug("updating cuda memory data",
+				"gpu", gpu.ID,
+				"name", gpu.Name,
+				slog.Group(
+					"before",
+					"total", format.HumanBytes2(gpu.TotalMemory),
+					"free", format.HumanBytes2(gpu.FreeMemory),
+				),
+				slog.Group(
+					"now",
+					"total", format.HumanBytes2(uint64(memInfo.total)),
+					"free", format.HumanBytes2(uint64(memInfo.free)),
+					"used", format.HumanBytes2(uint64(memInfo.used)),
+				),
+			)
+			cudaGPUs[i].FreeMemory = uint64(memInfo.free)
 		}
 		}
-
-	// Then AMD
-	resp = append(resp, AMDGetGPUInfo()...)
 
 
-		C.cpu_check_ram(&memInfo)
-		if memInfo.err != nil {
-			slog.Info("error looking up CPU memory", "error", C.GoString(memInfo.err))
-			C.free(unsafe.Pointer(memInfo.err))
-			return resp
+		if oHandles == nil && len(oneapiGPUs) > 0 {
+			oHandles = initOneAPIHandles()
 		}
 		}
-			Library: "cpu",
-			Variant: cpuVariant,
+		for i, gpu := range oneapiGPUs {
+			if oHandles.oneapi == nil {
+				// shouldn't happen
+				slog.Warn("nil oneapi handle with device count", "count", oHandles.deviceCount)
+				continue
+			}
+			C.oneapi_check_vram(*oHandles.oneapi, C.int(gpu.driverIndex), C.int(gpu.gpuIndex), &memInfo)
+			// TODO - convert this to MinimumMemory based on testing...
+			var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend.
+			memInfo.free = C.uint64_t(totalFreeMem)
+			oneapiGPUs[i].FreeMemory = uint64(memInfo.free)
 		}
 		}
-		gpuInfo.FreeMemory = uint64(memInfo.free)
-		gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
 
 
+		err = RocmGPUInfoList(rocmGPUs).RefreshFreeMemory()
+		if err != nil {
+			slog.Debug("problem refreshing ROCm free memory", "error", err)
+		}
 	}
 	}
 
-}
-
-func GetCPUMem() (memInfo, error) {
-	var ret memInfo
-	var info C.mem_info_t
-	C.cpu_check_ram(&info)
-	if info.err != nil {
-		defer C.free(unsafe.Pointer(info.err))
-		return ret, fmt.Errorf(C.GoString(info.err))
+	resp := []GpuInfo{}
+	for _, gpu := range cudaGPUs {
+		resp = append(resp, gpu.GpuInfo)
+	}
+	for _, gpu := range rocmGPUs {
+		resp = append(resp, gpu.GpuInfo)
+	}
+	for _, gpu := range oneapiGPUs {
+		resp = append(resp, gpu.GpuInfo)
+	}
+	if len(resp) == 0 {
+		resp = append(resp, cpus[0].GpuInfo)
 	}
 	}
-	ret.TotalMemory = uint64(info.total)
-	return ret, nil
+	return resp
 }
 }
 
 func FindGPULibs(baseLibName string, defaultPatterns []string) []string {
@@ -362,8 +515,26 @@ func LoadNVCUDAMgmt(nvcudaLibPaths []string) (int, *C.nvcuda_handle_t, string) {
 	return 0, nil, ""
 	return 0, nil, ""
 }
 }
 
 
+func LoadNVMLMgmt(nvmlLibPaths []string) (*C.nvml_handle_t, string) {
+	var resp C.nvml_init_resp_t
+	resp.ch.verbose = getVerboseState()
+	for _, libPath := range nvmlLibPaths {
+		lib := C.CString(libPath)
+		defer C.free(unsafe.Pointer(lib))
+		C.nvml_init(lib, &resp)
+		if resp.err != nil {
+			slog.Info(fmt.Sprintf("Unable to load NVML management library %s: %s", libPath, C.GoString(resp.err)))
+			C.free(unsafe.Pointer(resp.err))
+		} else {
+			return &resp.ch, libPath
+		}
+	}
+	return nil, ""
+}
+
 func LoadOneapiMgmt(oneapiLibPaths []string) (int, *C.oneapi_handle_t, string) {
 	var resp C.oneapi_init_resp_t
+	num_devices := 0
 	resp.oh.verbose = getVerboseState()
 	for _, libPath := range oneapiLibPaths {
 		lib := C.CString(libPath)
@@ -373,7 +544,10 @@ func LoadOneapiMgmt(oneapiLibPaths []string) (int, *C.oneapi_handle_t, string) {
 			slog.Debug("Unable to load oneAPI management library", "library", libPath, "error", C.GoString(resp.err))
 			slog.Debug("Unable to load oneAPI management library", "library", libPath, "error", C.GoString(resp.err))
 			C.free(unsafe.Pointer(resp.err))
 			C.free(unsafe.Pointer(resp.err))
 		} else {
 		} else {
-			return int(resp.num_devices), &resp.oh, libPath
+			for i := range resp.oh.num_drivers {
+				num_devices += int(C.oneapi_get_device_count(resp.oh, C.int(i)))
+			}
+			return num_devices, &resp.oh, libPath
 		}
 	}
 	return 0, nil, ""

+ 12 - 1
gpu/gpu_darwin.go

@@ -24,7 +24,7 @@ func GetGPUInfo() GpuInfoList {
 		return []GpuInfo{
 			{
 				Library: "cpu",
-				Variant: GetCPUVariant(),
+				Variant: GetCPUCapability(),
 				memInfo: mem,
 			},
 		}
@@ -42,6 +42,17 @@ func GetGPUInfo() GpuInfoList {
 	return []GpuInfo{info}
 }
 
+func GetCPUInfo() GpuInfoList {
+	mem, _ := GetCPUMem()
+	return []GpuInfo{
+		{
+			Library: "cpu",
+			Variant: GetCPUCapability(),
+			memInfo: mem,
+		},
+	}
+}
+
 func GetCPUMem() (memInfo, error) {
 	return memInfo{
 		TotalMemory: uint64(C.getPhysicalMemory()),

+ 2 - 0
gpu/gpu_info.h

@@ -47,6 +47,7 @@ typedef struct mem_info {
   char gpu_name[GPU_NAME_LEN];
   uint64_t total;
   uint64_t free;
+  uint64_t used;
 
   // Compute Capability
   int major; 
@@ -62,6 +63,7 @@ void cpu_check_ram(mem_info_t *resp);
 
 #include "gpu_info_cudart.h"
 #include "gpu_info_nvcuda.h"
+#include "gpu_info_nvml.h"
 #include "gpu_info_oneapi.h"
 #include "gpu_info_oneapi.h"
 
 
 #endif  // __GPU_INFO_H__
 #endif  // __GPU_INFO_H__

+ 0 - 45
gpu/gpu_info_cpu.c

@@ -1,45 +0,0 @@
-#include "gpu_info.h"
-// Fallbacks for CPU mode
-
-#ifdef _WIN32
-#include <sysinfoapi.h>
-void cpu_check_ram(mem_info_t *resp) {
-  resp->err = NULL;
-  MEMORYSTATUSEX info;
-  info.dwLength = sizeof(info);
-  if (GlobalMemoryStatusEx(&info) != 0) {
-    resp->total = info.ullTotalPhys;
-    resp->free = info.ullAvailPhys;
-    snprintf(&resp->gpu_id[0], GPU_ID_LEN, "0");
-  } else {
-    resp->err = LOAD_ERR();
-  }
-  return;
-}
-
-#elif __linux__
-#include <errno.h>
-#include <string.h>
-#include <sys/sysinfo.h>
-void cpu_check_ram(mem_info_t *resp) {
-  struct sysinfo info;
-  resp->err = NULL;
-  if (sysinfo(&info) != 0) {
-    resp->err = strdup(strerror(errno));
-  } else {
-    resp->total = info.totalram * info.mem_unit;
-    resp->free = info.freeram * info.mem_unit;
-    snprintf(&resp->gpu_id[0], GPU_ID_LEN, "0");
-  }
-  return;
-}
-
-#elif __APPLE__
-// TODO consider an Apple implementation that does something useful
-// mem_info_t cpu_check_ram() {
-//   mem_info_t resp = {0, 0, NULL};
-//   return resp;
-// }
-#else
-#error "Unsupported platform"
-#endif

+ 3 - 1
gpu/gpu_info_cudart.c

@@ -94,7 +94,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
 }
 
 
-void cudart_check_vram(cudart_handle_t h, int i, mem_info_t *resp) {
+void cudart_bootstrap(cudart_handle_t h, int i, mem_info_t *resp) {
   resp->err = NULL;
   cudartMemory_t memInfo = {0,0,0};
   cudartReturn_t ret;
@@ -166,9 +166,11 @@ void cudart_check_vram(cudart_handle_t h, int i, mem_info_t *resp) {
 
   resp->total = memInfo.total;
   resp->free = memInfo.free;
+  resp->used = memInfo.used;
 
   LOG(h.verbose, "[%s] CUDA totalMem %lu\n", resp->gpu_id, resp->total);
   LOG(h.verbose, "[%s] CUDA freeMem %lu\n", resp->gpu_id, resp->free);
+  LOG(h.verbose, "[%s] CUDA usedMem %lu\n", resp->gpu_id, resp->used);
   LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor);
   LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor);
 }
 }
 
 

+ 2 - 1
gpu/gpu_info_cudart.h

@@ -140,7 +140,8 @@ typedef struct cudart_init_resp {
 } cudart_init_resp_t;
 
 void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp);
-void cudart_check_vram(cudart_handle_t ch, int device_id, mem_info_t *resp);
+void cudart_bootstrap(cudart_handle_t ch, int device_id, mem_info_t *resp);
+// TODO - if we keep this library longer term, add cudart_get_free
 void cudart_release(cudart_handle_t ch);
 
 #endif  // __GPU_INFO_CUDART_H__

+ 38 - 3
gpu/gpu_info_nvcuda.c

@@ -96,7 +96,7 @@ void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) {
 }
 
 const int buflen = 256;
-void nvcuda_check_vram(nvcuda_handle_t h, int i, mem_info_t *resp) {
+void nvcuda_bootstrap(nvcuda_handle_t h, int i, mem_info_t *resp) {
   resp->err = NULL;
   nvcudaMemory_t memInfo = {0,0};
   CUresult ret;
@@ -168,7 +168,7 @@ void nvcuda_check_vram(nvcuda_handle_t h, int i, mem_info_t *resp) {
   // To get memory we have to set (and release) a context
   ret = (*h.cuCtxCreate_v3)(&ctx, NULL, 0, 0, device);
   if (ret != CUDA_SUCCESS) {
-    snprintf(buf, buflen, "nvcuda failed to get primary device context %d", ret);
+    snprintf(buf, buflen, "nvcuda failed to get device context %d", ret);
     resp->err = strdup(buf);
     return;
   }
@@ -193,7 +193,42 @@ void nvcuda_check_vram(nvcuda_handle_t h, int i, mem_info_t *resp) {
 
   ret = (*h.cuCtxDestroy)(ctx);
   if (ret != CUDA_SUCCESS) {
-    LOG(1, "nvcuda failed to release primary device context %d", ret);
+    LOG(1, "nvcuda failed to release device context %d", ret);
+  }
+}
+
+void nvcuda_get_free(nvcuda_handle_t h, int i, uint64_t *free, uint64_t *total) {
+  CUresult ret;
+  CUcontext ctx = NULL;
+  CUdevice device = -1;
+  *free = 0;
+  *total = 0;
+
+  ret = (*h.cuDeviceGet)(&device, i);
+  if (ret != CUDA_SUCCESS) {
+    LOG(1, "nvcuda device failed to initialize");
+    return;
+  }
+
+
+  // To get memory we have to set (and release) a context
+  ret = (*h.cuCtxCreate_v3)(&ctx, NULL, 0, 0, device);
+  if (ret != CUDA_SUCCESS) {
+    LOG(1, "nvcuda failed to get device context %d", ret);
+    return;
+  }
+
+  ret = (*h.cuMemGetInfo_v2)(free, total);
+  if (ret != CUDA_SUCCESS) {
+    LOG(1, "nvcuda device memory info lookup failure %d", ret);
+    // Best effort on failure...
+    (*h.cuCtxDestroy)(ctx);
+    return;
+  }
+
+  ret = (*h.cuCtxDestroy)(ctx);
+  if (ret != CUDA_SUCCESS) {
+    LOG(1, "nvcuda failed to release device context %d", ret);
   }
 }
 

+ 2 - 1
gpu/gpu_info_nvcuda.h

@@ -67,7 +67,8 @@ typedef struct nvcuda_init_resp {
 } nvcuda_init_resp_t;
 
 void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp);
-void nvcuda_check_vram(nvcuda_handle_t ch, int device_id, mem_info_t *resp);
+void nvcuda_bootstrap(nvcuda_handle_t ch, int device_id, mem_info_t *resp);
+void nvcuda_get_free(nvcuda_handle_t ch,  int device_id, uint64_t *free, uint64_t *total);
 void nvcuda_release(nvcuda_handle_t ch);
 
 #endif  // __GPU_INFO_NVCUDA_H__

+ 104 - 0
gpu/gpu_info_nvml.c

@@ -0,0 +1,104 @@
+#ifndef __APPLE__  // TODO - maybe consider nvidia support on intel macs?
+
+#include <string.h>
+
+#include "gpu_info_nvml.h"
+
+void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp) {
+  nvmlReturn_t ret;
+  resp->err = NULL;
+  const int buflen = 256;
+  char buf[buflen + 1];
+  int i;
+
+  struct lookup {
+    char *s;
+    void **p;
+  } l[] = {
+      {"nvmlInit_v2", (void *)&resp->ch.nvmlInit_v2},
+      {"nvmlShutdown", (void *)&resp->ch.nvmlShutdown},
+      {"nvmlDeviceGetHandleByIndex", (void *)&resp->ch.nvmlDeviceGetHandleByIndex},
+      {"nvmlDeviceGetMemoryInfo", (void *)&resp->ch.nvmlDeviceGetMemoryInfo},
+      {NULL, NULL},
+  };
+
+  resp->ch.handle = LOAD_LIBRARY(nvml_lib_path, RTLD_LAZY);
+  if (!resp->ch.handle) {
+    char *msg = LOAD_ERR();
+    LOG(resp->ch.verbose, "library %s load err: %s\n", nvml_lib_path, msg);
+    snprintf(buf, buflen,
+             "Unable to load %s library to query for Nvidia GPUs: %s",
+             nvml_lib_path, msg);
+    free(msg);
+    resp->err = strdup(buf);
+    return;
+  }
+
+  // TODO once we've squashed the remaining corner cases remove this log
+  // LOG(resp->ch.verbose, "wiring nvidia management library functions in %s\n", nvml_lib_path);
+  
+  for (i = 0; l[i].s != NULL; i++) {
+    // TODO once we've squashed the remaining corner cases remove this log
+    // LOG(resp->ch.verbose, "dlsym: %s\n", l[i].s);
+
+    *l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
+    if (!l[i].p) {
+      resp->ch.handle = NULL;
+      char *msg = LOAD_ERR();
+      LOG(resp->ch.verbose, "dlerr: %s\n", msg);
+      UNLOAD_LIBRARY(resp->ch.handle);
+      snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s,
+               msg);
+      free(msg);
+      resp->err = strdup(buf);
+      return;
+    }
+  }
+
+  ret = (*resp->ch.nvmlInit_v2)();
+  if (ret != NVML_SUCCESS) {
+    LOG(resp->ch.verbose, "nvmlInit_v2 err: %d\n", ret);
+    UNLOAD_LIBRARY(resp->ch.handle);
+    resp->ch.handle = NULL;
+    snprintf(buf, buflen, "nvml vram init failure: %d", ret);
+    resp->err = strdup(buf);
+    return;
+  }
+}
+
+
+void nvml_get_free(nvml_handle_t h, int device_id, uint64_t *free, uint64_t *total, uint64_t *used) {
+    nvmlDevice_t device;
+    nvmlMemory_t memInfo = {0};
+    nvmlReturn_t ret;
+    ret = (*h.nvmlDeviceGetHandleByIndex)(device_id, &device);
+    if (ret != NVML_SUCCESS) {
+        LOG(1, "unable to get device handle %d: %d", device_id, ret);
+        *free = 0;
+        return;
+    }
+
+    ret = (*h.nvmlDeviceGetMemoryInfo)(device, &memInfo);
+    if (ret != NVML_SUCCESS) {
+        LOG(1, "device memory info lookup failure %d: %d", device_id, ret);
+        *free = 0;
+        return;
+    }
+    *free = memInfo.free;
+    *total = memInfo.total;
+    *used = memInfo.used;
+}
+
+
+void nvml_release(nvml_handle_t h) {
+  LOG(h.verbose, "releasing nvml library\n");
+  nvmlReturn_t ret;
+  ret = (*h.nvmlShutdown)();
+  if (ret != NVML_SUCCESS) {
+    LOG(1, "error during nvmlShutdown %d", ret);
+  }
+  UNLOAD_LIBRARY(h.handle);
+  h.handle = NULL;
+}
+
+#endif  // __APPLE__

+ 48 - 0
gpu/gpu_info_nvml.h

@@ -0,0 +1,48 @@
+#ifndef __APPLE__
+#ifndef __GPU_INFO_NVML_H__
+#define __GPU_INFO_NVML_H__
+#include "gpu_info.h"
+
+// Just enough typedef's to dlopen/dlsym for memory information
+typedef enum nvmlReturn_enum {
+  NVML_SUCCESS = 0,
+  // Other values omitted for now...
+} nvmlReturn_t;
+typedef void *nvmlDevice_t;  // Opaque is sufficient
+typedef struct nvmlMemory_st {
+  unsigned long long total;
+  unsigned long long free;
+  unsigned long long used;
+} nvmlMemory_t;
+
+typedef enum nvmlBrandType_enum
+{
+    NVML_BRAND_UNKNOWN          = 0,
+} nvmlBrandType_t;
+
+typedef struct nvml_handle {
+  void *handle;
+  uint16_t verbose;
+  nvmlReturn_t (*nvmlInit_v2)(void);
+  nvmlReturn_t (*nvmlShutdown)(void);
+  nvmlReturn_t (*nvmlDeviceGetHandleByIndex)(unsigned int, nvmlDevice_t *);
+  nvmlReturn_t (*nvmlDeviceGetMemoryInfo)(nvmlDevice_t, nvmlMemory_t *);
+} nvml_handle_t;
+
+typedef struct nvml_init_resp {
+  char *err;  // If err is non-null handle is invalid
+  nvml_handle_t ch;
+} nvml_init_resp_t;
+
+typedef struct nvml_compute_capability {
+  char *err;
+  int major;
+  int minor;
+} nvml_compute_capability_t;
+
+void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp);
+void nvml_get_free(nvml_handle_t ch,  int device_id, uint64_t *free, uint64_t *total, uint64_t *used);
+void nvml_release(nvml_handle_t ch);
+
+#endif  // __GPU_INFO_NVML_H__
+#endif  // __APPLE__

+ 166 - 123
gpu/gpu_info_oneapi.c

@@ -4,15 +4,17 @@
 
 #include <string.h>
 
-void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp)
-{
+void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp) {
   ze_result_t ret;
   resp->err = NULL;
+  resp->oh.devices = NULL;
+  resp->oh.num_devices = NULL;
+  resp->oh.drivers = NULL;
+  resp->oh.num_drivers = 0;
   const int buflen = 256;
   char buf[buflen + 1];
-  int i;
-  struct lookup
-  {
+  int i, d, count;
+  struct lookup {
     char *s;
     void **p;
   } l[] = {
@@ -28,8 +30,7 @@ void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp)
   };
 
   resp->oh.handle = LOAD_LIBRARY(oneapi_lib_path, RTLD_LAZY);
-  if (!resp->oh.handle)
-  {
+  if (!resp->oh.handle) {
     char *msg = LOAD_ERR();
     snprintf(buf, buflen,
              "Unable to load %s library to query for Intel GPUs: %s\n",
@@ -44,14 +45,12 @@ void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp)
       "wiring Level-Zero management library functions in %s\n",
       "wiring Level-Zero management library functions in %s\n",
       oneapi_lib_path);
       oneapi_lib_path);
 
 
-  for (i = 0; l[i].s != NULL; i++)
-  {
+  for (i = 0; l[i].s != NULL; i++) {
     // TODO once we've squashed the remaining corner cases remove this log
     LOG(resp->oh.verbose, "dlsym: %s\n", l[i].s);
 
     *l[i].p = LOAD_SYMBOL(resp->oh.handle, l[i].s);
-    if (!l[i].p)
-    {
+    if (!l[i].p) {
       resp->oh.handle = NULL;
       char *msg = LOAD_ERR();
       LOG(resp->oh.verbose, "dlerr: %s\n", msg);
@@ -64,22 +63,67 @@ void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp)
   }
 
   ret = (*resp->oh.zesInit)(0);
-  if (ret != ZE_RESULT_SUCCESS)
-  {
-    LOG(resp->oh.verbose, "zesInit err: %d\n", ret);
-    UNLOAD_LIBRARY(resp->oh.handle);
-    resp->oh.handle = NULL;
-    snprintf(buf, buflen, "oneapi vram init failure: %d", ret);
+  if (ret != ZE_RESULT_SUCCESS) {
+    LOG(resp->oh.verbose, "zesInit err: %x\n", ret);
+    snprintf(buf, buflen, "oneapi vram init failure: %x", ret);
     resp->err = strdup(buf);
+    oneapi_release(resp->oh);
+    return;
   }
 
-  (*resp->oh.zesDriverGet)(&resp->num_devices, NULL);
+  count = 0;
+  ret = (*resp->oh.zesDriverGet)(&resp->oh.num_drivers, NULL);
+  if (ret != ZE_RESULT_SUCCESS) {
+    LOG(resp->oh.verbose, "zesDriverGet err: %x\n", ret);
+    snprintf(buf, buflen, "unable to get driver count: %x", ret);
+    resp->err = strdup(buf);
+    oneapi_release(resp->oh);
+    return;
+  }
+  LOG(resp->oh.verbose, "oneapi driver count: %d\n", resp->oh.num_drivers);
+  resp->oh.drivers = malloc(resp->oh.num_drivers * sizeof(zes_driver_handle_t));
+  resp->oh.num_devices = malloc(resp->oh.num_drivers * sizeof(uint32_t));
+  memset(&resp->oh.num_devices[0], 0, resp->oh.num_drivers * sizeof(uint32_t));
+  resp->oh.devices =
+      malloc(resp->oh.num_drivers * sizeof(zes_device_handle_t *));
+  ret = (*resp->oh.zesDriverGet)(&resp->oh.num_drivers, &resp->oh.drivers[0]);
+  if (ret != ZE_RESULT_SUCCESS) {
+    LOG(resp->oh.verbose, "zesDriverGet err: %x\n", ret);
+    snprintf(buf, buflen, "unable to get driver count: %x", ret);
+    resp->err = strdup(buf);
+    oneapi_release(resp->oh);
+    return;
+  }
+
+  for (d = 0; d < resp->oh.num_drivers; d++) {
+    ret = (*resp->oh.zesDeviceGet)(resp->oh.drivers[d],
+                                   &resp->oh.num_devices[d], NULL);
+    if (ret != ZE_RESULT_SUCCESS) {
+      LOG(resp->oh.verbose, "zesDeviceGet err: %x\n", ret);
+      snprintf(buf, buflen, "unable to get device count: %x", ret);
+      resp->err = strdup(buf);
+      oneapi_release(resp->oh);
+      return;
+    }
+    resp->oh.devices[d] =
+        malloc(resp->oh.num_devices[d] * sizeof(zes_device_handle_t));
+    ret = (*resp->oh.zesDeviceGet)(
+        resp->oh.drivers[d], &resp->oh.num_devices[d], resp->oh.devices[d]);
+    if (ret != ZE_RESULT_SUCCESS) {
+      LOG(resp->oh.verbose, "zesDeviceGet err: %x\n", ret);
+      snprintf(buf, buflen, "unable to get device count: %x", ret);
+      resp->err = strdup(buf);
+      oneapi_release(resp->oh);
+      return;
+    }
+    count += resp->oh.num_devices[d];
+  }
 
 
   return;
   return;
 }
 }
 
 
-void oneapi_check_vram(oneapi_handle_t h, mem_info_t *resp)
-{
+void oneapi_check_vram(oneapi_handle_t h, int driver, int device,
+                       mem_info_t *resp) {
   ze_result_t ret;
   ze_result_t ret;
   resp->err = NULL;
   resp->err = NULL;
   uint64_t totalMem = 0;
   uint64_t totalMem = 0;
@@ -88,127 +132,126 @@ void oneapi_check_vram(oneapi_handle_t h, mem_info_t *resp)
   char buf[buflen + 1];
   char buf[buflen + 1];
   int i, d, m;
   int i, d, m;
 
 
-  if (h.handle == NULL)
-  {
+  if (h.handle == NULL) {
     resp->err = strdup("Level-Zero handle not initialized");
     resp->err = strdup("Level-Zero handle not initialized");
     return;
     return;
   }
   }
 
 
-  uint32_t driversCount = 0;
-  ret = (*h.zesDriverGet)(&driversCount, NULL);
-  if (ret != ZE_RESULT_SUCCESS)
-  {
-    snprintf(buf, buflen, "unable to get driver count: %d", ret);
-    resp->err = strdup(buf);
+  if (driver >= h.num_drivers || device >= h.num_devices[driver]) {
+    resp->err = strdup("driver or device index out of bounds");
     return;
     return;
   }
   }
-  LOG(h.verbose, "discovered %d Level-Zero drivers\n", driversCount);
-
-  zes_driver_handle_t *allDrivers =
-      malloc(driversCount * sizeof(zes_driver_handle_t));
-  (*h.zesDriverGet)(&driversCount, allDrivers);
 
 
   resp->total = 0;
   resp->total = 0;
   resp->free = 0;
   resp->free = 0;
 
 
-  for (d = 0; d < driversCount; d++)
-  {
-    uint32_t deviceCount = 0;
-    ret = (*h.zesDeviceGet)(allDrivers[d], &deviceCount, NULL);
-    if (ret != ZE_RESULT_SUCCESS)
-    {
-      snprintf(buf, buflen, "unable to get device count: %d", ret);
+  zes_device_ext_properties_t ext_props;
+  ext_props.stype = ZES_STRUCTURE_TYPE_DEVICE_EXT_PROPERTIES;
+  ext_props.pNext = NULL;
+
+  zes_device_properties_t props;
+  props.stype = ZES_STRUCTURE_TYPE_DEVICE_PROPERTIES;
+  props.pNext = &ext_props;
+
+  ret = (*h.zesDeviceGetProperties)(h.devices[driver][device], &props);
+  if (ret != ZE_RESULT_SUCCESS) {
+    snprintf(buf, buflen, "unable to get device properties: %d", ret);
+    resp->err = strdup(buf);
+    return;
+  }
+
+  snprintf(&resp->gpu_name[0], GPU_NAME_LEN, "%s", props.modelName);
+
+  // TODO this needs to map to ONEAPI_DEVICE_SELECTOR syntax
+  // (this is probably wrong...)
+  // TODO - the driver isn't included - what if there are multiple drivers?
+  snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", device);
+
+  if (h.verbose) {
+    // When in verbose mode, report more information about
+    // the card we discover.
+    LOG(h.verbose, "[%d:%d] oneAPI device name: %s\n", driver, device,
+        props.modelName);
+    LOG(h.verbose, "[%d:%d] oneAPI brand: %s\n", driver, device,
+        props.brandName);
+    LOG(h.verbose, "[%d:%d] oneAPI vendor: %s\n", driver, device,
+        props.vendorName);
+    LOG(h.verbose, "[%d:%d] oneAPI S/N: %s\n", driver, device,
+        props.serialNumber);
+    LOG(h.verbose, "[%d:%d] oneAPI board number: %s\n", driver, device,
+        props.boardNumber);
+  }
+
+  // TODO
+  // Compute Capability equivalent in resp->major, resp->minor, resp->patch
+
+  uint32_t memCount = 0;
+  ret = (*h.zesDeviceEnumMemoryModules)(h.devices[driver][device], &memCount,
+                                        NULL);
+  if (ret != ZE_RESULT_SUCCESS) {
+    snprintf(buf, buflen, "unable to enumerate Level-Zero memory modules: %x",
+             ret);
+    resp->err = strdup(buf);
+    return;
+  }
+
+  LOG(h.verbose, "discovered %d Level-Zero memory modules\n", memCount);
+
+  zes_mem_handle_t *mems = malloc(memCount * sizeof(zes_mem_handle_t));
+  (*h.zesDeviceEnumMemoryModules)(h.devices[driver][device], &memCount, mems);
+
+  for (m = 0; m < memCount; m++) {
+    zes_mem_state_t state;
+    state.stype = ZES_STRUCTURE_TYPE_MEM_STATE;
+    state.pNext = NULL;
+    ret = (*h.zesMemoryGetState)(mems[m], &state);
+    if (ret != ZE_RESULT_SUCCESS) {
+      snprintf(buf, buflen, "unable to get memory state: %x", ret);
       resp->err = strdup(buf);
       resp->err = strdup(buf);
-      free(allDrivers);
+      free(mems);
       return;
       return;
     }
     }
 
 
-    LOG(h.verbose, "discovered %d Level-Zero devices\n", deviceCount);
-
-    zes_device_handle_t *devices =
-        malloc(deviceCount * sizeof(zes_device_handle_t));
-    (*h.zesDeviceGet)(allDrivers[d], &deviceCount, devices);
-
-    for (i = 0; i < deviceCount; i++)
-    {
-      zes_device_ext_properties_t ext_props;
-      ext_props.stype = ZES_STRUCTURE_TYPE_DEVICE_EXT_PROPERTIES;
-      ext_props.pNext = NULL;
-
-      zes_device_properties_t props;
-      props.stype = ZES_STRUCTURE_TYPE_DEVICE_PROPERTIES;
-      props.pNext = &ext_props;
-
-      ret = (*h.zesDeviceGetProperties)(devices[i], &props);
-      if (ret != ZE_RESULT_SUCCESS)
-      {
-        snprintf(buf, buflen, "unable to get device properties: %d", ret);
-        resp->err = strdup(buf);
-        free(allDrivers);
-        free(devices);
-        return;
-      }
-
-      if (h.verbose)
-      {
-        // When in verbose mode, report more information about
-        // the card we discover.
-        LOG(h.verbose, "[%d] oneAPI device name: %s\n", i,
-            props.modelName);
-        LOG(h.verbose, "[%d] oneAPI brand: %s\n", i,
-            props.brandName);
-        LOG(h.verbose, "[%d] oneAPI vendor: %s\n", i,
-            props.vendorName);
-        LOG(h.verbose, "[%d] oneAPI S/N: %s\n", i,
-            props.serialNumber);
-        LOG(h.verbose, "[%d] oneAPI board number: %s\n", i,
-            props.boardNumber);
-      }
-
-      uint32_t memCount = 0;
-      ret = (*h.zesDeviceEnumMemoryModules)(devices[i], &memCount, NULL);
-      if (ret != ZE_RESULT_SUCCESS)
-      {
-        snprintf(buf, buflen,
-                 "unable to enumerate Level-Zero memory modules: %d", ret);
-        resp->err = strdup(buf);
-        free(allDrivers);
-        free(devices);
-        return;
-      }
-
-      LOG(h.verbose, "discovered %d Level-Zero memory modules\n", memCount);
-
-      zes_mem_handle_t *mems = malloc(memCount * sizeof(zes_mem_handle_t));
-      (*h.zesDeviceEnumMemoryModules)(devices[i], &memCount, mems);
-
-      for (m = 0; m < memCount; m++)
-      {
-        zes_mem_state_t state;
-        state.stype = ZES_STRUCTURE_TYPE_MEM_STATE;
-        state.pNext = NULL;
-        ret = (*h.zesMemoryGetState)(mems[m], &state);
-        if (ret != ZE_RESULT_SUCCESS)
-        {
-          snprintf(buf, buflen, "unable to get memory state: %d", ret);
-          resp->err = strdup(buf);
-          free(allDrivers);
-          free(devices);
-          free(mems);
-          return;
-        }
-
-        resp->total += state.size;
-        resp->free += state.free;
-      }
+    resp->total += state.size;
+    resp->free += state.free;
+  }
 
 
-      free(mems);
-    }
+  free(mems);
+}
 
 
-    free(devices);
+void oneapi_release(oneapi_handle_t h) {
+  int d;
+  LOG(h.verbose, "releasing oneapi library\n");
+  for (d = 0; d < h.num_drivers; d++) {
+    if (h.devices != NULL && h.devices[d] != NULL) {
+      free(h.devices[d]);
+    }
+  }
+  if (h.devices != NULL) {
+    free(h.devices);
+    h.devices = NULL;
   }
   }
+  if (h.num_devices != NULL) {
+    free(h.num_devices);
+    h.num_devices = NULL;
+  }
+  if (h.drivers != NULL) {
+    free(h.drivers);
+    h.drivers = NULL;
+  }
+  h.num_drivers = 0;
+  UNLOAD_LIBRARY(h.handle);
+  h.handle = NULL;
+}
 
 
-  free(allDrivers);
+int oneapi_get_device_count(oneapi_handle_t h, int driver) {
+  if (h.handle == NULL || h.num_devices == NULL) {
+    return 0;
+  }
+  if (driver > h.num_drivers) {
+    return 0;
+  }
+  return (int)h.num_devices[driver];
 }
 }
 
 
 #endif // __APPLE__
 #endif // __APPLE__
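
A rough Go-side sketch (not code from this PR) of how the per-driver, per-device API above might be consumed through cgo. The handle type, the mem_info_t fields, and the function names mirror the C declarations in this file and gpu_info_oneapi.h; the wrapper itself and its package placement are hypothetical.

	package gpu

	/*
	#include <stdlib.h>
	#include "gpu_info_oneapi.h"
	*/
	import "C"

	import (
		"log/slog"
		"unsafe"
	)

	// logOneapiDevices walks every driver and device discovered by
	// oneapi_init and logs the VRAM each one reports.
	func logOneapiDevices(oh C.oneapi_handle_t) {
		for d := 0; d < int(oh.num_drivers); d++ {
			devCount := int(C.oneapi_get_device_count(oh, C.int(d)))
			for i := 0; i < devCount; i++ {
				var mem C.mem_info_t
				C.oneapi_check_vram(oh, C.int(d), C.int(i), &mem)
				if mem.err != nil {
					slog.Warn("oneAPI VRAM query failed", "error", C.GoString(mem.err))
					C.free(unsafe.Pointer(mem.err))
					continue
				}
				slog.Info("oneAPI device", "driver", d, "device", i,
					"total", uint64(mem.total), "free", uint64(mem.free))
			}
		}
	}
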

+ 34 - 42
gpu/gpu_info_oneapi.h

@@ -9,8 +9,7 @@
 #define ZE_BIT(_i) (1 << _i)
 #define ZE_BIT(_i) (1 << _i)
 
 
 // Just enough typedef's to dlopen/dlsym for memory information
 // Just enough typedef's to dlopen/dlsym for memory information
-typedef enum ze_result_t
-{
+typedef enum ze_result_t {
   ZE_RESULT_SUCCESS = 0,
   ZE_RESULT_SUCCESS = 0,
   // Other values omitted for now...
   // Other values omitted for now...
 } ze_result_t;
 } ze_result_t;
@@ -20,13 +19,11 @@ typedef struct _zes_driver_handle_t *zes_driver_handle_t;
 typedef struct _zes_device_handle_t *zes_device_handle_t;
 typedef struct _zes_device_handle_t *zes_device_handle_t;
 typedef struct _zes_mem_handle_t *zes_mem_handle_t;
 typedef struct _zes_mem_handle_t *zes_mem_handle_t;
 
 
-typedef enum _ze_structure_type_t
-{
+typedef enum _ze_structure_type_t {
   ZE_STRUCTURE_TYPE_FORCE_UINT32 = 0x7fffffff
   ZE_STRUCTURE_TYPE_FORCE_UINT32 = 0x7fffffff
 } ze_structure_type_t;
 } ze_structure_type_t;
 
 
-typedef enum _zes_structure_type_t
-{
+typedef enum _zes_structure_type_t {
   ZES_STRUCTURE_TYPE_DEVICE_PROPERTIES = 0x1,
   ZES_STRUCTURE_TYPE_DEVICE_PROPERTIES = 0x1,
   ZES_STRUCTURE_TYPE_MEM_PROPERTIES = 0xb,
   ZES_STRUCTURE_TYPE_MEM_PROPERTIES = 0xb,
   ZES_STRUCTURE_TYPE_MEM_STATE = 0x1e,
   ZES_STRUCTURE_TYPE_MEM_STATE = 0x1e,
@@ -34,35 +31,29 @@ typedef enum _zes_structure_type_t
   ZES_STRUCTURE_TYPE_FORCE_UINT32 = 0x7fffffff
   ZES_STRUCTURE_TYPE_FORCE_UINT32 = 0x7fffffff
 } zes_structure_type_t;
 } zes_structure_type_t;
 
 
-typedef enum _zes_mem_type_t
-{
+typedef enum _zes_mem_type_t {
   ZES_MEM_TYPE_FORCE_UINT32 = 0x7fffffff
   ZES_MEM_TYPE_FORCE_UINT32 = 0x7fffffff
 } zes_mem_type_t;
 } zes_mem_type_t;
 
 
-typedef enum _zes_mem_loc_t
-{
+typedef enum _zes_mem_loc_t {
   ZES_MEM_LOC_SYSTEM = 0,
   ZES_MEM_LOC_SYSTEM = 0,
   ZES_MEM_LOC_DEVICE = 1,
   ZES_MEM_LOC_DEVICE = 1,
   ZES_MEM_LOC_FORCE_UINT32 = 0x7fffffff
   ZES_MEM_LOC_FORCE_UINT32 = 0x7fffffff
 } zes_mem_loc_t;
 } zes_mem_loc_t;
 
 
-typedef enum _zes_mem_health_t
-{
+typedef enum _zes_mem_health_t {
   ZES_MEM_HEALTH_FORCE_UINT32 = 0x7fffffff
   ZES_MEM_HEALTH_FORCE_UINT32 = 0x7fffffff
 } zes_mem_health_t;
 } zes_mem_health_t;
 
 
-typedef struct _ze_device_uuid_t
-{
+typedef struct _ze_device_uuid_t {
   uint8_t id[ZE_MAX_DEVICE_UUID_SIZE];
   uint8_t id[ZE_MAX_DEVICE_UUID_SIZE];
 } ze_device_uuid_t;
 } ze_device_uuid_t;
 
 
-typedef struct _zes_uuid_t
-{
+typedef struct _zes_uuid_t {
   uint8_t id[ZE_MAX_DEVICE_UUID_SIZE];
   uint8_t id[ZE_MAX_DEVICE_UUID_SIZE];
 } zes_uuid_t;
 } zes_uuid_t;
 
 
-typedef enum _ze_device_type_t
-{
+typedef enum _ze_device_type_t {
   ZE_DEVICE_TYPE_GPU = 1,
   ZE_DEVICE_TYPE_GPU = 1,
   ZE_DEVICE_TYPE_CPU = 2,
   ZE_DEVICE_TYPE_CPU = 2,
   ZE_DEVICE_TYPE_FPGA = 3,
   ZE_DEVICE_TYPE_FPGA = 3,
@@ -71,8 +62,7 @@ typedef enum _ze_device_type_t
   ZE_DEVICE_TYPE_FORCE_UINT32 = 0x7fffffff
   ZE_DEVICE_TYPE_FORCE_UINT32 = 0x7fffffff
 } ze_device_type_t;
 } ze_device_type_t;
 
 
-typedef enum _zes_device_type_t
-{
+typedef enum _zes_device_type_t {
   ZES_DEVICE_TYPE_GPU = 1,
   ZES_DEVICE_TYPE_GPU = 1,
   ZES_DEVICE_TYPE_CPU = 2,
   ZES_DEVICE_TYPE_CPU = 2,
   ZES_DEVICE_TYPE_FPGA = 3,
   ZES_DEVICE_TYPE_FPGA = 3,
@@ -82,8 +72,7 @@ typedef enum _zes_device_type_t
 } zes_device_type_t;
 } zes_device_type_t;
 
 
 typedef uint32_t ze_device_property_flags_t;
 typedef uint32_t ze_device_property_flags_t;
-typedef enum _ze_device_property_flag_t
-{
+typedef enum _ze_device_property_flag_t {
   ZE_DEVICE_PROPERTY_FLAG_INTEGRATED = ZE_BIT(0),
   ZE_DEVICE_PROPERTY_FLAG_INTEGRATED = ZE_BIT(0),
   ZE_DEVICE_PROPERTY_FLAG_SUBDEVICE = ZE_BIT(1),
   ZE_DEVICE_PROPERTY_FLAG_SUBDEVICE = ZE_BIT(1),
   ZE_DEVICE_PROPERTY_FLAG_ECC = ZE_BIT(2),
   ZE_DEVICE_PROPERTY_FLAG_ECC = ZE_BIT(2),
@@ -92,8 +81,7 @@ typedef enum _ze_device_property_flag_t
 } ze_device_property_flag_t;
 } ze_device_property_flag_t;
 
 
 typedef uint32_t zes_device_property_flags_t;
 typedef uint32_t zes_device_property_flags_t;
-typedef enum _zes_device_property_flag_t
-{
+typedef enum _zes_device_property_flag_t {
   ZES_DEVICE_PROPERTY_FLAG_INTEGRATED = ZE_BIT(0),
   ZES_DEVICE_PROPERTY_FLAG_INTEGRATED = ZE_BIT(0),
   ZES_DEVICE_PROPERTY_FLAG_SUBDEVICE = ZE_BIT(1),
   ZES_DEVICE_PROPERTY_FLAG_SUBDEVICE = ZE_BIT(1),
   ZES_DEVICE_PROPERTY_FLAG_ECC = ZE_BIT(2),
   ZES_DEVICE_PROPERTY_FLAG_ECC = ZE_BIT(2),
@@ -101,8 +89,7 @@ typedef enum _zes_device_property_flag_t
   ZES_DEVICE_PROPERTY_FLAG_FORCE_UINT32 = 0x7fffffff
   ZES_DEVICE_PROPERTY_FLAG_FORCE_UINT32 = 0x7fffffff
 } zes_device_property_flag_t;
 } zes_device_property_flag_t;
 
 
-typedef struct _ze_device_properties_t
-{
+typedef struct _ze_device_properties_t {
   ze_structure_type_t stype;
   ze_structure_type_t stype;
   void *pNext;
   void *pNext;
   ze_device_type_t type;
   ze_device_type_t type;
@@ -126,8 +113,7 @@ typedef struct _ze_device_properties_t
   char name[ZE_MAX_DEVICE_NAME];
   char name[ZE_MAX_DEVICE_NAME];
 } ze_device_properties_t;
 } ze_device_properties_t;
 
 
-typedef struct _zes_device_properties_t
-{
+typedef struct _zes_device_properties_t {
   zes_structure_type_t stype;
   zes_structure_type_t stype;
   void *pNext;
   void *pNext;
   ze_device_properties_t core;
   ze_device_properties_t core;
@@ -140,8 +126,7 @@ typedef struct _zes_device_properties_t
   char driverVersion[ZES_STRING_PROPERTY_SIZE];
   char driverVersion[ZES_STRING_PROPERTY_SIZE];
 } zes_device_properties_t;
 } zes_device_properties_t;
 
 
-typedef struct _zes_device_ext_properties_t
-{
+typedef struct _zes_device_ext_properties_t {
   zes_structure_type_t stype;
   zes_structure_type_t stype;
   void *pNext;
   void *pNext;
   zes_uuid_t uuid;
   zes_uuid_t uuid;
@@ -149,8 +134,7 @@ typedef struct _zes_device_ext_properties_t
   zes_device_property_flags_t flags;
   zes_device_property_flags_t flags;
 } zes_device_ext_properties_t;
 } zes_device_ext_properties_t;
 
 
-typedef struct _zes_mem_properties_t
-{
+typedef struct _zes_mem_properties_t {
   zes_structure_type_t stype;
   zes_structure_type_t stype;
   void *pNext;
   void *pNext;
   zes_mem_type_t type;
   zes_mem_type_t type;
@@ -162,8 +146,7 @@ typedef struct _zes_mem_properties_t
   int32_t numChannels;
   int32_t numChannels;
 } zes_mem_properties_t;
 } zes_mem_properties_t;
 
 
-typedef struct _zes_mem_state_t
-{
+typedef struct _zes_mem_state_t {
   zes_structure_type_t stype;
   zes_structure_type_t stype;
   const void *pNext;
   const void *pNext;
   zes_mem_health_t health;
   zes_mem_health_t health;
@@ -171,10 +154,19 @@ typedef struct _zes_mem_state_t
   uint64_t size;
   uint64_t size;
 } zes_mem_state_t;
 } zes_mem_state_t;
 
 
-typedef struct oneapi_handle
-{
+typedef struct oneapi_handle {
   void *handle;
   void *handle;
   uint16_t verbose;
   uint16_t verbose;
+
+  uint32_t num_drivers;
+  zes_driver_handle_t *drivers;
+  uint32_t *num_devices;
+  zes_device_handle_t **devices;
+
+  // TODO Driver major, minor information
+  // int driver_major;
+  // int driver_minor;
+
   ze_result_t (*zesInit)(int);
   ze_result_t (*zesInit)(int);
   ze_result_t (*zesDriverGet)(uint32_t *pCount, zes_driver_handle_t *phDrivers);
   ze_result_t (*zesDriverGet)(uint32_t *pCount, zes_driver_handle_t *phDrivers);
   ze_result_t (*zesDeviceGet)(zes_driver_handle_t hDriver, uint32_t *pCount,
   ze_result_t (*zesDeviceGet)(zes_driver_handle_t hDriver, uint32_t *pCount,
@@ -191,21 +183,21 @@ typedef struct oneapi_handle
 
 
 } oneapi_handle_t;
 } oneapi_handle_t;
 
 
-typedef struct oneapi_init_resp
-{
+typedef struct oneapi_init_resp {
   char *err; // If err is non-null handle is invalid
   char *err; // If err is non-null handle is invalid
-  int num_devices;
   oneapi_handle_t oh;
   oneapi_handle_t oh;
 } oneapi_init_resp_t;
 } oneapi_init_resp_t;
 
 
-typedef struct oneapi_version_resp
-{
+typedef struct oneapi_version_resp {
   ze_result_t status;
   ze_result_t status;
   char *str; // Contains version or error string if status != 0
   char *str; // Contains version or error string if status != 0
 } oneapi_version_resp_t;
 } oneapi_version_resp_t;
 
 
 void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp);
 void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp);
-void oneapi_check_vram(oneapi_handle_t rh, mem_info_t *resp);
+void oneapi_check_vram(oneapi_handle_t h, int driver, int device,
+                       mem_info_t *resp);
+void oneapi_release(oneapi_handle_t h);
+int oneapi_get_device_count(oneapi_handle_t h, int driver);
 
 
 #endif // __GPU_INFO_INTEL_H__
 #endif // __GPU_INFO_INTEL_H__
 #endif // __APPLE__
 #endif // __APPLE__

+ 89 - 0
gpu/gpu_linux.go

@@ -0,0 +1,89 @@
+package gpu
+
+import (
+	"bufio"
+	"fmt"
+	"os"
+	"strings"
+
+	"github.com/ollama/ollama/format"
+)
+
+var CudartGlobs = []string{
+	"/usr/local/cuda/lib64/libcudart.so*",
+	"/usr/lib/x86_64-linux-gnu/nvidia/current/libcudart.so*",
+	"/usr/lib/x86_64-linux-gnu/libcudart.so*",
+	"/usr/lib/wsl/lib/libcudart.so*",
+	"/usr/lib/wsl/drivers/*/libcudart.so*",
+	"/opt/cuda/lib64/libcudart.so*",
+	"/usr/local/cuda*/targets/aarch64-linux/lib/libcudart.so*",
+	"/usr/lib/aarch64-linux-gnu/nvidia/current/libcudart.so*",
+	"/usr/lib/aarch64-linux-gnu/libcudart.so*",
+	"/usr/local/cuda/lib*/libcudart.so*",
+	"/usr/lib*/libcudart.so*",
+	"/usr/local/lib*/libcudart.so*",
+}
+
+var NvmlGlobs = []string{}
+
+var NvcudaGlobs = []string{
+	"/usr/local/cuda*/targets/*/lib/libcuda.so*",
+	"/usr/lib/*-linux-gnu/nvidia/current/libcuda.so*",
+	"/usr/lib/*-linux-gnu/libcuda.so*",
+	"/usr/lib/wsl/lib/libcuda.so*",
+	"/usr/lib/wsl/drivers/*/libcuda.so*",
+	"/opt/cuda/lib*/libcuda.so*",
+	"/usr/local/cuda/lib*/libcuda.so*",
+	"/usr/lib*/libcuda.so*",
+	"/usr/local/lib*/libcuda.so*",
+}
+
+var OneapiGlobs = []string{
+	"/usr/lib/x86_64-linux-gnu/libze_intel_gpu.so*",
+	"/usr/lib*/libze_intel_gpu.so*",
+}
+
+var CudartMgmtName = "libcudart.so*"
+var NvcudaMgmtName = "libcuda.so*"
+var NvmlMgmtName = "" // not currently wired on linux
+var OneapiMgmtName = "libze_intel_gpu.so"
+
+func GetCPUMem() (memInfo, error) {
+	var mem memInfo
+	var total, available, free, buffers, cached uint64
+	f, err := os.Open("/proc/meminfo")
+	if err != nil {
+		return mem, err
+	}
+	defer f.Close()
+	s := bufio.NewScanner(f)
+	for s.Scan() {
+		line := s.Text()
+		switch {
+		case strings.HasPrefix(line, "MemTotal:"):
+			_, err = fmt.Sscanf(line, "MemTotal:%d", &total)
+		case strings.HasPrefix(line, "MemAvailable:"):
+			_, err = fmt.Sscanf(line, "MemAvailable:%d", &available)
+		case strings.HasPrefix(line, "MemFree:"):
+			_, err = fmt.Sscanf(line, "MemFree:%d", &free)
+		case strings.HasPrefix(line, "Buffers:"):
+			_, err = fmt.Sscanf(line, "Buffers:%d", &buffers)
+		case strings.HasPrefix(line, "Cached:"):
+			_, err = fmt.Sscanf(line, "Cached:%d", &cached)
+		default:
+			continue
+		}
+		if err != nil {
+			return mem, err
+		}
+
+		if total > 0 && available > 0 {
+			mem.TotalMemory = total * format.KibiByte
+			mem.FreeMemory = available * format.KibiByte
+			return mem, nil
+		}
+	}
+	mem.TotalMemory = total * format.KibiByte
+	mem.FreeMemory = (free + buffers + cached) * format.KibiByte
+	return mem, nil
+}
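
For context on the parser above: /proc/meminfo reports sizes in KiB, and the function returns early once both MemTotal and MemAvailable have been parsed, falling back to MemFree+Buffers+Cached only on older kernels that lack MemAvailable. A minimal, hypothetical caller inside the gpu package (assuming log/slog is imported; not part of this PR):

	// logSystemMemory reports system memory in human-readable form using
	// the same format helpers as the rest of the package.
	func logSystemMemory() {
		mem, err := GetCPUMem()
		if err != nil {
			slog.Warn("failed to read /proc/meminfo", "error", err)
			return
		}
		slog.Info("system memory",
			"total", format.HumanBytes2(mem.TotalMemory),
			"free", format.HumanBytes2(mem.FreeMemory))
	}
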

+ 55 - 0
gpu/gpu_windows.go

@@ -0,0 +1,55 @@
+package gpu
+
+import (
+	"fmt"
+	"syscall"
+	"unsafe"
+)
+
+type MEMORYSTATUSEX struct {
+	length               uint32
+	MemoryLoad           uint32
+	TotalPhys            uint64
+	AvailPhys            uint64
+	TotalPageFile        uint64
+	AvailPageFile        uint64
+	TotalVirtual         uint64
+	AvailVirtual         uint64
+	AvailExtendedVirtual uint64
+}
+
+var (
+	k32                      = syscall.NewLazyDLL("kernel32.dll")
+	globalMemoryStatusExProc = k32.NewProc("GlobalMemoryStatusEx")
+	sizeofMemoryStatusEx     = uint32(unsafe.Sizeof(MEMORYSTATUSEX{}))
+)
+
+var CudartGlobs = []string{
+	"c:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v*\\bin\\cudart64_*.dll",
+}
+
+var NvmlGlobs = []string{
+	"c:\\Windows\\System32\\nvml.dll",
+}
+
+var NvcudaGlobs = []string{
+	"c:\\windows\\system*\\nvcuda.dll",
+}
+
+var OneapiGlobs = []string{
+	"c:\\Windows\\System32\\DriverStore\\FileRepository\\*\\ze_intel_gpu64.dll",
+}
+
+var CudartMgmtName = "cudart64_*.dll"
+var NvcudaMgmtName = "nvcuda.dll"
+var NvmlMgmtName = "nvml.dll"
+var OneapiMgmtName = "ze_intel_gpu64.dll"
+
+func GetCPUMem() (memInfo, error) {
+	memStatus := MEMORYSTATUSEX{length: sizeofMemoryStatusEx}
+	r1, _, err := globalMemoryStatusExProc.Call(uintptr(unsafe.Pointer(&memStatus)))
+	if r1 == 0 {
+		return memInfo{}, fmt.Errorf("GlobalMemoryStatusEx failed: %w", err)
+	}
+	return memInfo{TotalMemory: memStatus.TotalPhys, FreeMemory: memStatus.AvailPhys}, nil
+}

+ 50 - 3
gpu/types.go

@@ -18,7 +18,7 @@ type GpuInfo struct {
 	Library string `json:"library,omitempty"`
 	Library string `json:"library,omitempty"`
 
 
 	// Optional variant to select (e.g. versions, cpu feature flags)
 	// Optional variant to select (e.g. versions, cpu feature flags)
-	Variant string `json:"variant,omitempty"`
+	Variant CPUCapability `json:"variant"`
 
 
 	// MinimumMemory represents the minimum memory required to use the GPU
 	// MinimumMemory represents the minimum memory required to use the GPU
 	MinimumMemory uint64 `json:"-"`
 	MinimumMemory uint64 `json:"-"`
@@ -38,6 +38,30 @@ type GpuInfo struct {
 	// TODO other performance capability info to help in scheduling decisions
 	// TODO other performance capability info to help in scheduling decisions
 }
 }
 
 
+type CPUInfo struct {
+	GpuInfo
+}
+
+type CudaGPUInfo struct {
+	GpuInfo
+	index int //nolint:unused,nolintlint
+}
+type CudaGPUInfoList []CudaGPUInfo
+
+type RocmGPUInfo struct {
+	GpuInfo
+	usedFilepath string //nolint:unused,nolintlint
+	index        int    //nolint:unused,nolintlint
+}
+type RocmGPUInfoList []RocmGPUInfo
+
+type OneapiGPUInfo struct {
+	GpuInfo
+	driverIndex int //nolint:unused,nolintlint
+	gpuIndex    int //nolint:unused,nolintlint
+}
+type OneapiGPUInfoList []OneapiGPUInfo
+
 type GpuInfoList []GpuInfo
 type GpuInfoList []GpuInfo
 
 
 // Split up the set of gpu info's by Library and variant
 // Split up the set of gpu info's by Library and variant
@@ -47,8 +71,8 @@ func (l GpuInfoList) ByLibrary() []GpuInfoList {
 	for _, info := range l {
 	for _, info := range l {
 		found := false
 		found := false
 		requested := info.Library
 		requested := info.Library
-		if info.Variant != "" {
-			requested += "_" + info.Variant
+		if info.Variant != CPUCapabilityNone {
+			requested += "_" + info.Variant.String()
 		}
 		}
 		for i, lib := range libs {
 		for i, lib := range libs {
 			if lib == requested {
 			if lib == requested {
@@ -86,3 +110,26 @@ type ByFreeMemory []GpuInfo
 func (a ByFreeMemory) Len() int           { return len(a) }
 func (a ByFreeMemory) Len() int           { return len(a) }
 func (a ByFreeMemory) Swap(i, j int)      { a[i], a[j] = a[j], a[i] }
 func (a ByFreeMemory) Swap(i, j int)      { a[i], a[j] = a[j], a[i] }
 func (a ByFreeMemory) Less(i, j int) bool { return a[i].FreeMemory < a[j].FreeMemory }
 func (a ByFreeMemory) Less(i, j int) bool { return a[i].FreeMemory < a[j].FreeMemory }
+
+type CPUCapability uint32
+
+// Override at build time when building base GPU runners
+var GPURunnerCPUCapability = CPUCapabilityAVX
+
+const (
+	CPUCapabilityNone CPUCapability = iota
+	CPUCapabilityAVX
+	CPUCapabilityAVX2
+	// TODO AVX512
+)
+
+func (c CPUCapability) String() string {
+	switch c {
+	case CPUCapabilityAVX:
+		return "avx"
+	case CPUCapabilityAVX2:
+		return "avx2"
+	default:
+		return "no vector extensions"
+	}
+}

+ 59 - 17
integration/concurrency_test.go

@@ -19,17 +19,19 @@ func TestMultiModelConcurrency(t *testing.T) {
 	var (
 	var (
 		req = [2]api.GenerateRequest{
 		req = [2]api.GenerateRequest{
 			{
 			{
-				Model:  "orca-mini",
-				Prompt: "why is the ocean blue?",
-				Stream: &stream,
+				Model:     "orca-mini",
+				Prompt:    "why is the ocean blue?",
+				Stream:    &stream,
+				KeepAlive: &api.Duration{Duration: 10 * time.Second},
 				Options: map[string]interface{}{
 				Options: map[string]interface{}{
 					"seed":        42,
 					"seed":        42,
 					"temperature": 0.0,
 					"temperature": 0.0,
 				},
 				},
 			}, {
 			}, {
-				Model:  "tinydolphin",
-				Prompt: "what is the origin of the us thanksgiving holiday?",
-				Stream: &stream,
+				Model:     "tinydolphin",
+				Prompt:    "what is the origin of the us thanksgiving holiday?",
+				Stream:    &stream,
+				KeepAlive: &api.Duration{Duration: 10 * time.Second},
 				Options: map[string]interface{}{
 				Options: map[string]interface{}{
 					"seed":        42,
 					"seed":        42,
 					"temperature": 0.0,
 					"temperature": 0.0,
@@ -38,42 +40,64 @@ func TestMultiModelConcurrency(t *testing.T) {
 		}
 		}
 		resp = [2][]string{
 		resp = [2][]string{
 			[]string{"sunlight"},
 			[]string{"sunlight"},
-			[]string{"england", "english", "massachusetts", "pilgrims"},
+			[]string{"england", "english", "massachusetts", "pilgrims", "british"},
 		}
 		}
 	)
 	)
 	var wg sync.WaitGroup
 	var wg sync.WaitGroup
 	wg.Add(len(req))
 	wg.Add(len(req))
-	ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*240)
 	defer cancel()
 	defer cancel()
+
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
+
+	for i := 0; i < len(req); i++ {
+		require.NoError(t, PullIfMissing(ctx, client, req[i].Model))
+	}
+
 	for i := 0; i < len(req); i++ {
 	for i := 0; i < len(req); i++ {
 		go func(i int) {
 		go func(i int) {
 			defer wg.Done()
 			defer wg.Done()
-			GenerateTestHelper(ctx, t, req[i], resp[i])
+			DoGenerate(ctx, t, client, req[i], resp[i], 60*time.Second, 10*time.Second)
 		}(i)
 		}(i)
 	}
 	}
 	wg.Wait()
 	wg.Wait()
 }
 }
 
 
 func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
 func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
-	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) // GTX 750 2G card takes ~9 minutes
+	req, resp := GenerateRequests()
+	reqLimit := len(req)
+	iterLimit := 5
+
+	vram := os.Getenv("OLLAMA_MAX_VRAM")
+	if vram != "" {
+		max, err := strconv.ParseUint(vram, 10, 64)
+		require.NoError(t, err)
+		// Don't hammer on small VRAM cards...
+		if max < 4*1024*1024*1024 {
+			reqLimit = min(reqLimit, 2)
+			iterLimit = 2
+		}
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), 9*time.Minute)
 	defer cancel()
 	defer cancel()
 	client, _, cleanup := InitServerConnection(ctx, t)
 	client, _, cleanup := InitServerConnection(ctx, t)
 	defer cleanup()
 	defer cleanup()
 
 
-	req, resp := GenerateRequests()
 	// Get the server running (if applicable) warm the model up with a single initial request
 	// Get the server running (if applicable) warm the model up with a single initial request
-	DoGenerate(ctx, t, client, req[0], resp[0], 60*time.Second, 5*time.Second)
+	DoGenerate(ctx, t, client, req[0], resp[0], 60*time.Second, 10*time.Second)
 
 
 	var wg sync.WaitGroup
 	var wg sync.WaitGroup
-	wg.Add(len(req))
-	for i := 0; i < len(req); i++ {
+	wg.Add(reqLimit)
+	for i := 0; i < reqLimit; i++ {
 		go func(i int) {
 		go func(i int) {
 			defer wg.Done()
 			defer wg.Done()
-			for j := 0; j < 5; j++ {
+			for j := 0; j < iterLimit; j++ {
 				slog.Info("Starting", "req", i, "iter", j)
 				slog.Info("Starting", "req", i, "iter", j)
-				// On slower GPUs it can take a while to process the 4 concurrent requests
+				// On slower GPUs it can take a while to process the concurrent requests
 				// so we allow a much longer initial timeout
 				// so we allow a much longer initial timeout
-				DoGenerate(ctx, t, client, req[i], resp[i], 90*time.Second, 5*time.Second)
+				DoGenerate(ctx, t, client, req[i], resp[i], 120*time.Second, 20*time.Second)
 			}
 			}
 		}(i)
 		}(i)
 	}
 	}
@@ -221,5 +245,23 @@ func TestMultiModelStress(t *testing.T) {
 			}
 			}
 		}(i)
 		}(i)
 	}
 	}
+	go func() {
+		for {
+			time.Sleep(2 * time.Second)
+			select {
+			case <-ctx.Done():
+				return
+			default:
+				models, err := client.ListRunning(ctx)
+				if err != nil {
+					slog.Warn("failed to list running models", "error", err)
+					continue
+				}
+				for _, m := range models.Models {
+					slog.Info("loaded model snapshot", "model", m)
+				}
+			}
+		}
+	}()
 	wg.Wait()
 	wg.Wait()
 }
 }

+ 2 - 1
integration/context_test.go

@@ -11,7 +11,8 @@ import (
 )
 )
 
 
 func TestContextExhaustion(t *testing.T) {
 func TestContextExhaustion(t *testing.T) {
-	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) // TODO maybe shorter?
+	// A longer timeout is needed for small-footprint GPUs
+	ctx, cancel := context.WithTimeout(context.Background(), 6*time.Minute)
 	defer cancel()
 	defer cancel()
 	// Set up the test data
 	// Set up the test data
 	req := api.GenerateRequest{
 	req := api.GenerateRequest{

+ 5 - 1
integration/llm_image_test.go

@@ -32,7 +32,11 @@ func TestIntegrationMultimodal(t *testing.T) {
 	resp := "the ollam"
 	resp := "the ollam"
 	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
 	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
 	defer cancel()
 	defer cancel()
-	GenerateTestHelper(ctx, t, req, []string{resp})
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
+	require.NoError(t, PullIfMissing(ctx, client, req.Model))
+	// llava models on CPU can be quite slow to start, so allow a generous initial timeout
+	DoGenerate(ctx, t, client, req, []string{resp}, 120*time.Second, 30*time.Second)
 }
 }
 
 
 const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb
 const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb

+ 22 - 17
integration/utils_test.go

@@ -140,7 +140,7 @@ func PullIfMissing(ctx context.Context, client *api.Client, modelName string) er
 
 
 	showCtx, cancel := context.WithDeadlineCause(
 	showCtx, cancel := context.WithDeadlineCause(
 		ctx,
 		ctx,
-		time.Now().Add(5*time.Second),
+		time.Now().Add(10*time.Second),
 		fmt.Errorf("show for existing model %s took too long", modelName),
 		fmt.Errorf("show for existing model %s took too long", modelName),
 	)
 	)
 	defer cancel()
 	defer cancel()
@@ -287,41 +287,46 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
 func GenerateRequests() ([]api.GenerateRequest, [][]string) {
 func GenerateRequests() ([]api.GenerateRequest, [][]string) {
 	return []api.GenerateRequest{
 	return []api.GenerateRequest{
 			{
 			{
-				Model:  "orca-mini",
-				Prompt: "why is the ocean blue?",
-				Stream: &stream,
+				Model:     "orca-mini",
+				Prompt:    "why is the ocean blue?",
+				Stream:    &stream,
+				KeepAlive: &api.Duration{Duration: 10 * time.Second},
 				Options: map[string]interface{}{
 				Options: map[string]interface{}{
 					"seed":        42,
 					"seed":        42,
 					"temperature": 0.0,
 					"temperature": 0.0,
 				},
 				},
 			}, {
 			}, {
-				Model:  "orca-mini",
-				Prompt: "why is the color of dirt brown?",
-				Stream: &stream,
+				Model:     "orca-mini",
+				Prompt:    "why is the color of dirt brown?",
+				Stream:    &stream,
+				KeepAlive: &api.Duration{Duration: 10 * time.Second},
 				Options: map[string]interface{}{
 				Options: map[string]interface{}{
 					"seed":        42,
 					"seed":        42,
 					"temperature": 0.0,
 					"temperature": 0.0,
 				},
 				},
 			}, {
 			}, {
-				Model:  "orca-mini",
-				Prompt: "what is the origin of the us thanksgiving holiday?",
-				Stream: &stream,
+				Model:     "orca-mini",
+				Prompt:    "what is the origin of the us thanksgiving holiday?",
+				Stream:    &stream,
+				KeepAlive: &api.Duration{Duration: 10 * time.Second},
 				Options: map[string]interface{}{
 				Options: map[string]interface{}{
 					"seed":        42,
 					"seed":        42,
 					"temperature": 0.0,
 					"temperature": 0.0,
 				},
 				},
 			}, {
 			}, {
-				Model:  "orca-mini",
-				Prompt: "what is the origin of independence day?",
-				Stream: &stream,
+				Model:     "orca-mini",
+				Prompt:    "what is the origin of independence day?",
+				Stream:    &stream,
+				KeepAlive: &api.Duration{Duration: 10 * time.Second},
 				Options: map[string]interface{}{
 				Options: map[string]interface{}{
 					"seed":        42,
 					"seed":        42,
 					"temperature": 0.0,
 					"temperature": 0.0,
 				},
 				},
 			}, {
 			}, {
-				Model:  "orca-mini",
-				Prompt: "what is the composition of air?",
-				Stream: &stream,
+				Model:     "orca-mini",
+				Prompt:    "what is the composition of air?",
+				Stream:    &stream,
+				KeepAlive: &api.Duration{Duration: 10 * time.Second},
 				Options: map[string]interface{}{
 				Options: map[string]interface{}{
 					"seed":        42,
 					"seed":        42,
 					"temperature": 0.0,
 					"temperature": 0.0,
@@ -331,7 +336,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
 		[][]string{
 		[][]string{
 			[]string{"sunlight"},
 			[]string{"sunlight"},
 			[]string{"soil", "organic", "earth", "black", "tan"},
 			[]string{"soil", "organic", "earth", "black", "tan"},
-			[]string{"england", "english", "massachusetts", "pilgrims"},
+			[]string{"england", "english", "massachusetts", "pilgrims", "british"},
 			[]string{"fourth", "july", "declaration", "independence"},
 			[]string{"fourth", "july", "declaration", "independence"},
 			[]string{"nitrogen", "oxygen", "carbon", "dioxide"},
 			[]string{"nitrogen", "oxygen", "carbon", "dioxide"},
 		}
 		}

+ 7 - 7
llm/ext_server/server.cpp

@@ -2335,9 +2335,9 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, g
                 invalid_param = true;
                 invalid_param = true;
                 break;
                 break;
             }
             }
-#ifndef GGML_USE_CUBLAS
-            fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. Setting the split mode has no effect.\n");
-#endif // GGML_USE_CUBLAS
+#ifndef GGML_USE_CUDA
+            fprintf(stderr, "warning: llama.cpp was compiled without CUDA. Setting the split mode has no effect.\n");
+#endif // GGML_USE_CUDA
         }
         }
         else if (arg == "--tensor-split" || arg == "-ts")
         else if (arg == "--tensor-split" || arg == "-ts")
         {
         {
@@ -2346,7 +2346,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, g
                 invalid_param = true;
                 invalid_param = true;
                 break;
                 break;
             }
             }
-#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_SYCL)
+#if defined(GGML_USE_CUDA) || defined(GGML_USE_SYCL)
             std::string arg_next = argv[i];
             std::string arg_next = argv[i];
 
 
             // split string by , and /
             // split string by , and /
@@ -2367,8 +2367,8 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, g
                 }
                 }
             }
             }
 #else
 #else
-            LOG_WARNING("llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n", {});
-#endif // GGML_USE_CUBLAS
+            LOG_WARNING("llama.cpp was compiled without CUDA. It is not possible to set a tensor split.\n", {});
+#endif // GGML_USE_CUDA
         }
         }
         else if (arg == "--main-gpu" || arg == "-mg")
         else if (arg == "--main-gpu" || arg == "-mg")
         {
         {
@@ -2377,7 +2377,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, g
                 invalid_param = true;
                 invalid_param = true;
                 break;
                 break;
             }
             }
-#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_SYCL)
+#if defined(GGML_USE_CUDA) || defined(GGML_USE_SYCL)
             params.main_gpu = std::stoi(argv[i]);
             params.main_gpu = std::stoi(argv[i]);
 #else
 #else
             LOG_WARNING("llama.cpp was compiled without cuBLAS. It is not possible to set a main GPU.", {});
             LOG_WARNING("llama.cpp was compiled without cuBLAS. It is not possible to set a main GPU.", {});

+ 1 - 0
llm/ggml.go

@@ -307,6 +307,7 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
 
 
 		partialOffload = 4 * batch * embedding
 		partialOffload = 4 * batch * embedding
 		partialOffload += max(
 		partialOffload += max(
+			// 4*batch*(4+6*embedding+context*(2*heads)+llm.KV().GQA()),
 			4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embedding/heads*headsKV),
 			4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embedding/heads*headsKV),
 			4*batch*(embedding+vocab)+embedding*vocab*105/128,
 			4*batch*(embedding+vocab)+embedding*vocab*105/128,
 		)
 		)

+ 202 - 55
llm/memory.go

@@ -1,11 +1,11 @@
 package llm
 package llm
 
 
 import (
 import (
-	"fmt"
 	"log/slog"
 	"log/slog"
+	"strconv"
+	"strings"
 
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/api"
-	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/gpu"
 	"github.com/ollama/ollama/gpu"
 )
 )
@@ -16,7 +16,8 @@ func PredictServerFit(allGpus gpu.GpuInfoList, ggml *GGML, adapters, projectors
 	var estimatedVRAM uint64
 	var estimatedVRAM uint64
 	for _, gpus := range allGpus.ByLibrary() {
 	for _, gpus := range allGpus.ByLibrary() {
 		var layerCount int
 		var layerCount int
-		layerCount, estimatedVRAM, _ = EstimateGPULayers(gpus, ggml, projectors, opts)
+		estimate := EstimateGPULayers(gpus, ggml, projectors, opts)
+		layerCount, estimatedVRAM = estimate.Layers, estimate.VRAMSize
 		if opts.NumGPU < 0 {
 		if opts.NumGPU < 0 {
 			if layerCount > 0 && layerCount >= int(ggml.KV().BlockCount()+1) {
 			if layerCount > 0 && layerCount >= int(ggml.KV().BlockCount()+1) {
 				return true, estimatedVRAM
 				return true, estimatedVRAM
@@ -30,24 +31,64 @@ func PredictServerFit(allGpus gpu.GpuInfoList, ggml *GGML, adapters, projectors
 	return false, estimatedVRAM
 	return false, estimatedVRAM
 }
 }
 
 
+type MemoryEstimate struct {
+	// How many layers we predict we can load
+	Layers int
+
+	// The size of the graph which occupies the main GPU
+	Graph uint64
+
+	// How much VRAM will be allocated given the number of layers we predict
+	VRAMSize uint64
+
+	// The total size of the model if loaded into VRAM.  If all layers are loaded, VRAMSize == TotalSize
+	TotalSize uint64
+
+	// For multi-GPU scenarios, this provides the tensor split parameter
+	TensorSplit string
+
+	// For multi-GPU scenarios, this is the size in bytes per GPU
+	GPUSizes []uint64
+}
+
 // Given a model and one or more GPU targets, predict how many layers and bytes we can load, and the total size
 // Given a model and one or more GPU targets, predict how many layers and bytes we can load, and the total size
 // The GPUs provided must all be the same Library
 // The GPUs provided must all be the same Library
-func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts api.Options) (int, uint64, uint64) {
-	var memoryAvailable uint64
-	for _, info := range gpus {
-		memoryAvailable += info.FreeMemory
-	}
-	if envconfig.MaxVRAM > 0 {
-		memoryAvailable = envconfig.MaxVRAM
-	}
+func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts api.Options) MemoryEstimate {
+	// Graph size for a partial offload, applies to all GPUs
+	var graphPartialOffload uint64
+
+	// Graph size when all layers are offloaded, applies to all GPUs
+	var graphFullOffload uint64
+
+	// Final graph offload once we know full or partial
+	var graphOffload uint64
+
+	// Projectors loaded into GPU0 only
+	var projectorSize uint64
+
+	// Conditional output size on GPU 0
+	var memoryLayerOutput uint64
 
 
-	slog.Debug("evaluating", "library", gpus[0].Library, "gpu_count", len(gpus), "available", format.HumanBytes2(memoryAvailable))
+	// The size of a single layer
+	var layerSize uint64
 
 
-	// TODO - this is probably wrong, first GPU vs secondaries will have different overheads
-	memoryMinimum := gpus[0].MinimumMemory
+	// The sum of all the layer sizes (just for logging)
+	var memoryWeights uint64
+
+	// True if all the layers are loaded
+	var fullyLoaded bool
+
+	// Overflow that didn't fit into the GPU
+	var overflow uint64
+
+	availableList := make([]string, len(gpus))
+	for i, gpu := range gpus {
+		availableList[i] = format.HumanBytes2(gpu.FreeMemory)
+	}
+	slog.Debug("evaluating", "library", gpus[0].Library, "gpu_count", len(gpus), "available", availableList)
 
 
 	for _, projector := range projectors {
 	for _, projector := range projectors {
-		memoryMinimum += projectorMemoryRequirements(projector)
+		projectorSize += projectorMemoryRequirements(projector)
 
 
 		// multimodal models require at least 2048 context
 		// multimodal models require at least 2048 context
 		opts.NumCtx = max(opts.NumCtx, 2048)
 		opts.NumCtx = max(opts.NumCtx, 2048)
@@ -56,79 +97,160 @@ func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts
 	layers := ggml.Tensors().Layers()
 	layers := ggml.Tensors().Layers()
 	// add one layer worth of memory as a buffer
 	// add one layer worth of memory as a buffer
 	if blk0, ok := layers["blk.0"]; ok {
 	if blk0, ok := layers["blk.0"]; ok {
-		memoryMinimum += blk0.size()
+		layerSize = blk0.size()
+	} else {
+		slog.Warn("model missing blk.0 layer size")
 	}
 	}
 
 
 	// fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv
 	// fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv
 	var kv uint64 = 2 * 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * ggml.KV().EmbeddingLength() / ggml.KV().HeadCount() * ggml.KV().HeadCountKV()
 	var kv uint64 = 2 * 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * ggml.KV().EmbeddingLength() / ggml.KV().HeadCount() * ggml.KV().HeadCountKV()
 
 
-	graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)))
+	// KV is proportional to the number of layers
+	layerSize += kv / ggml.KV().BlockCount()
+
+	graphPartialOffload, graphFullOffload = ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)))
 	if graphPartialOffload == 0 {
 	if graphPartialOffload == 0 {
 		graphPartialOffload = ggml.KV().GQA() * kv / 6
 		graphPartialOffload = ggml.KV().GQA() * kv / 6
 	}
 	}
-
 	if graphFullOffload == 0 {
 	if graphFullOffload == 0 {
 		graphFullOffload = graphPartialOffload
 		graphFullOffload = graphPartialOffload
 	}
 	}
 
 
-	graphFullOffload *= uint64(len(gpus))
-	graphPartialOffload *= uint64(len(gpus))
-
 	// on metal there's no partial offload overhead
 	// on metal there's no partial offload overhead
 	if gpus[0].Library == "metal" {
 	if gpus[0].Library == "metal" {
 		graphPartialOffload = graphFullOffload
 		graphPartialOffload = graphFullOffload
+	} else if len(gpus) > 1 {
+		// multigpu should always use the partial graph size
+		graphFullOffload = graphPartialOffload
 	}
 	}
 
 
-	// memoryRequiredTotal represents the memory required for full GPU offloading (all layers)
-	memoryRequiredTotal := memoryMinimum + graphFullOffload
-
-	// memoryRequiredPartial represents the memory required for partial GPU offloading (n > 0, n < layers)
-	memoryRequiredPartial := memoryMinimum + graphPartialOffload
-
-	var memoryLayerOutput uint64
 	if layer, ok := layers["output_norm"]; ok {
 	if layer, ok := layers["output_norm"]; ok {
 		memoryLayerOutput += layer.size()
 		memoryLayerOutput += layer.size()
 	}
 	}
-
 	if layer, ok := layers["output"]; ok {
 	if layer, ok := layers["output"]; ok {
 		memoryLayerOutput += layer.size()
 		memoryLayerOutput += layer.size()
 	} else if layer, ok := layers["token_embd"]; ok {
 	} else if layer, ok := layers["token_embd"]; ok {
 		memoryLayerOutput += layer.size()
 		memoryLayerOutput += layer.size()
 	}
 	}
 
 
-	if gpus[0].Library == "metal" && opts.UseMMap {
-		// memory is preallocated for output tensors
-		memoryRequiredTotal += memoryLayerOutput
-		memoryRequiredPartial += memoryLayerOutput
-	}
+	// Output layer handled at the end if we have space
+	gpuZeroOverhead := projectorSize
 
 
+	// Reduce set of GPUs to only those that have sufficient space to fit overhead and at least one layer
 	var layerCount int
 	var layerCount int
+	layerCounts := make([]int, len(gpus))
+	gpuAllocations := make([]uint64, len(gpus))
+	type gs struct {
+		i int
+		g *gpu.GpuInfo
+	}
+	gpusWithSpace := []gs{}
+	for i := range gpus {
+		var gzo uint64
+		if len(gpusWithSpace) == 0 {
+			gzo = gpuZeroOverhead
+		}
+		// Only include GPUs that can fit the graph, gpu minimum, the layer buffer and at least one more layer
+		if gpus[i].FreeMemory < gzo+max(graphPartialOffload, graphFullOffload)+gpus[i].MinimumMemory+2*layerSize {
+			slog.Debug("gpu has too little memory to allocate any layers", "gpu", gpus[i])
+			continue
+		}
+		gpusWithSpace = append(gpusWithSpace, gs{i, &gpus[i]})
+		gpuAllocations[i] += gpus[i].MinimumMemory + layerSize // We hold off on graph until we know partial vs. full
+	}
+
+	var gpuZeroID int
+	if len(gpusWithSpace) > 0 {
+		gpuZeroID = gpusWithSpace[0].i
+		gpuAllocations[gpuZeroID] += gpuZeroOverhead
+	}
+
+	// For all the layers, find where they can fit on the GPU(s)
 	for i := range int(ggml.KV().BlockCount()) {
 	for i := range int(ggml.KV().BlockCount()) {
-		if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok {
-			memoryLayer := blk.size()
+		memoryWeights += layerSize
 
 
-			// KV is proportional to the number of layers
-			memoryLayer += kv / ggml.KV().BlockCount()
+		if opts.NumGPU >= 0 && layerCount >= opts.NumGPU {
+			// Stop allocating on GPU(s) once we hit the user's target NumGPU
+			continue
+		}
 
 
-			memoryRequiredTotal += memoryLayer
-			if (opts.NumGPU >= 0 && layerCount+1 <= opts.NumGPU) || (opts.NumGPU < 0 && memoryAvailable > memoryRequiredPartial+memoryLayer) {
-				memoryRequiredPartial += memoryLayer
+		// distribute the layers across the GPU(s) that have space
+		for j := len(gpusWithSpace); j > 0; j-- {
+			g := gpusWithSpace[i%j]
+			used := gpuAllocations[g.i] + max(graphPartialOffload, graphFullOffload)
+			if g.g.FreeMemory > used+layerSize {
+				gpuAllocations[g.i] += layerSize
+				layerCounts[g.i]++
 				layerCount++
 				layerCount++
+				break
+			} else {
+				gpusWithSpace = append(gpusWithSpace[:i%j], gpusWithSpace[i%j+1:]...)
 			}
 			}
 		}
 		}
 	}
 	}
+	if layerCount >= int(ggml.KV().BlockCount()) {
+		fullyLoaded = true
+	} else {
+		for i := layerCount; i < int(ggml.KV().BlockCount()); i++ {
+			overflow += layerSize
+		}
+	}
+
+	// Determine if we need to consider output then find where it fits
+	if memoryLayerOutput > 0 && (opts.NumGPU < 0 || layerCount < opts.NumGPU) {
+		for j := len(gpusWithSpace); j > 0; j-- {
+			g := gpusWithSpace[layerCount%j]
+			used := gpuAllocations[g.i] + max(graphPartialOffload, graphFullOffload)
+			if g.g.FreeMemory > used+memoryLayerOutput {
+				gpuAllocations[g.i] += memoryLayerOutput
+				layerCounts[g.i]++
+				layerCount++
+				break
+			}
+		}
 
 
-	if gpus[0].Library != "metal" || !opts.UseMMap {
-		// memory was not preallocated for output tensors
-		memoryRequiredTotal += memoryLayerOutput
+		if layerCount < int(ggml.KV().BlockCount())+1 {
+			fullyLoaded = false
+			overflow += memoryLayerOutput
+		}
 	}
 	}
 
 
-	if (opts.NumGPU >= 0 && layerCount+1 <= opts.NumGPU) || (opts.NumGPU < 0 && memoryAvailable > memoryRequiredTotal) {
-		layerCount = int(ggml.KV().BlockCount()) + 1
-		memoryRequiredPartial = memoryRequiredTotal
+	// Add the applicable (full or partial) graph allocations
+	for i := range gpus {
+		if layerCounts[i] <= 0 {
+			continue
+		}
+		if fullyLoaded {
+			gpuAllocations[i] += graphFullOffload
+		} else {
+			gpuAllocations[i] += graphPartialOffload
+		}
+	}
+	if fullyLoaded {
+		graphOffload = graphFullOffload
+	} else {
+		graphOffload = graphPartialOffload
 	}
 	}
 
 
-	memoryWeights := memoryRequiredTotal - memoryMinimum - graphFullOffload - kv
+	// Summaries for the log
+	var memoryRequiredPartial, memoryRequiredTotal uint64
+	for i := range gpuAllocations {
+		memoryRequiredPartial += gpuAllocations[i]
+	}
+	memoryRequiredTotal = memoryRequiredPartial + overflow
+
+	tensorSplit := ""
+	if len(gpus) > 1 {
+		splits := make([]string, len(gpus))
+		for i, count := range layerCounts {
+			splits[i] = strconv.Itoa(count)
+		}
+		tensorSplit = strings.Join(splits, ",")
+	}
+	allocationsList := []string{}
+	for _, a := range gpuAllocations {
+		allocationsList = append(allocationsList, format.HumanBytes2(a))
+	}
 
 
 	slog.Info(
 	slog.Info(
 		"offload to gpu",
 		"offload to gpu",
@@ -136,13 +258,17 @@ func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts
 			"layers",
 			"layers",
 			// requested number of layers to offload
 			// requested number of layers to offload
 			"requested", opts.NumGPU,
 			"requested", opts.NumGPU,
+			// The number of layers the model has (including output)
+			"model", int(ggml.KV().BlockCount())+1,
 			// estimated number of layers that can be offloaded
 			// estimated number of layers that can be offloaded
-			"real", layerCount,
+			"offload", layerCount,
+			// multi-gpu split for tensors
+			"split", tensorSplit,
 		),
 		),
 		slog.Group(
 		slog.Group(
 			"memory",
 			"memory",
-			// memory available for offloading
-			"available", format.HumanBytes2(memoryAvailable),
+			// memory available by GPU for offloading
+			"available", availableList,
 			slog.Group(
 			slog.Group(
 				"required",
 				"required",
 				// memory required for full offloading
 				// memory required for full offloading
@@ -151,6 +277,8 @@ func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts
 				"partial", format.HumanBytes2(memoryRequiredPartial),
 				"partial", format.HumanBytes2(memoryRequiredPartial),
 				// memory of KV cache
 				// memory of KV cache
 				"kv", format.HumanBytes2(kv),
 				"kv", format.HumanBytes2(kv),
+				// Allocations across the GPUs
+				"allocations", allocationsList,
 			),
 			),
 			slog.Group(
 			slog.Group(
 				"weights",
 				"weights",
@@ -171,12 +299,31 @@ func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts
 		),
 		),
 	)
 	)
 	if gpus[0].Library == "cpu" {
 	if gpus[0].Library == "cpu" {
-		return 0, 0, memoryRequiredTotal
+		return MemoryEstimate{
+			Layers:    0,
+			Graph:     0,
+			VRAMSize:  0,
+			TotalSize: memoryRequiredTotal,
+			GPUSizes:  []uint64{},
+		}
 	}
 	}
-	if memoryRequiredPartial > memoryAvailable {
+	if layerCount == 0 {
 		slog.Debug("insufficient VRAM to load any model layers")
 		slog.Debug("insufficient VRAM to load any model layers")
-		return 0, 0, memoryRequiredTotal
+		return MemoryEstimate{
+			Layers:    0,
+			Graph:     0,
+			VRAMSize:  0,
+			TotalSize: memoryRequiredTotal,
+			GPUSizes:  []uint64{},
+		}
 	}
 	}
 
 
-	return layerCount, memoryRequiredPartial, memoryRequiredTotal
+	return MemoryEstimate{
+		Layers:      layerCount,
+		Graph:       graphOffload,
+		VRAMSize:    memoryRequiredPartial,
+		TotalSize:   memoryRequiredTotal,
+		TensorSplit: tensorSplit,
+		GPUSizes:    gpuAllocations,
+	}
 }
 }

+ 127 - 0
llm/memory_test.go

@@ -0,0 +1,127 @@
+package llm
+
+import (
+	"bytes"
+	"encoding/binary"
+	"fmt"
+	"os"
+	"testing"
+
+	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/envconfig"
+	"github.com/ollama/ollama/gpu"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestEstimateGPULayers(t *testing.T) {
+	envconfig.Debug = true
+	modelName := "dummy"
+	f, err := os.CreateTemp(t.TempDir(), modelName)
+	require.NoError(t, err)
+	defer f.Close()
+	gguf := NewGGUFV3(binary.LittleEndian)
+	inputLayerCount := 5
+	tensors := []Tensor{
+		{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
+		{Name: "blk.1.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
+		{Name: "blk.2.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
+		{Name: "blk.3.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
+		{Name: "blk.4.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
+		{Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
+	}
+	assert.Len(t, tensors, inputLayerCount+1)
+	err = gguf.Encode(f, KV{
+		"general.architecture":          "llama",
+		"general.name":                  "name",
+		"llama.context_length":          uint32(32),
+		"llama.embedding_length":        uint32(4096),
+		"llama.block_count":             uint32(inputLayerCount),
+		"llama.attention.head_count":    uint32(32),
+		"llama.attention.head_count_kv": uint32(32),
+		"tokenizer.ggml.tokens":         []string{" "},
+		"tokenizer.ggml.scores":         []float32{0},
+		"tokenizer.ggml.token_type":     []int32{0},
+	}, tensors)
+	require.NoError(t, err)
+
+	ggml, err := LoadModel(f.Name())
+	require.NoError(t, err)
+
+	// Simple CPU scenario
+	gpus := []gpu.GpuInfo{
+		{
+			Library: "cpu",
+		},
+	}
+	projectors := []string{}
+	opts := api.DefaultOptions()
+	t.Run("cpu", func(t *testing.T) {
+		estimate := EstimateGPULayers(gpus, ggml, projectors, opts)
+		assert.Equal(t, 0, estimate.Layers)
+		assert.Equal(t, uint64(0), estimate.Graph)
+	})
+
+	// derived from the dummy ggml file above
+	graphPartialOffload := uint64(202377216)
+	graphFullOffload := uint64(171968512)
+	layerSize := uint64(33554436)
+	projectorSize := uint64(0)
+	memoryLayerOutput := uint64(4)
+
+	// Dual CUDA scenario with asymmetry
+	gpuMinimumMemory := uint64(2048)
+	gpus = []gpu.GpuInfo{
+		{
+			Library:       "cuda",
+			MinimumMemory: gpuMinimumMemory,
+		},
+		{
+			Library:       "cuda",
+			MinimumMemory: gpuMinimumMemory,
+		},
+	}
+	// Scenarios: layer capacity on GPU0 and GPU1, and the expected layer count placed on each
+	for i, s := range []struct {
+		layer0, layer1   uint64
+		expect0, expect1 uint64
+	}{
+		{1, 1, 1, 1},
+		{2, 1, 2, 1},
+		{2, 2, 2, 2},
+		{1, 2, 1, 2},
+		{3, 3, 3, 3},
+		{4, 4, 3, 3},
+		{6, 6, 3, 3},
+		{0, 3, 0, 3},
+	} {
+		t.Run(fmt.Sprintf("%v", s), func(t *testing.T) {
+			gpus[0].FreeMemory = 0
+			gpus[1].FreeMemory = 0
+			gpus[0].FreeMemory += projectorSize
+			if s.layer0 > 0 {
+				gpus[0].FreeMemory += memoryLayerOutput
+			} else {
+				gpus[1].FreeMemory += memoryLayerOutput
+			}
+			gpus[0].FreeMemory += gpuMinimumMemory + layerSize + s.layer0*layerSize + 1
+			gpus[1].FreeMemory += gpuMinimumMemory + layerSize + s.layer1*layerSize + 1
+			gpus[0].FreeMemory += max(graphFullOffload, graphPartialOffload)
+			gpus[1].FreeMemory += max(graphFullOffload, graphPartialOffload)
+			estimate := EstimateGPULayers(gpus, ggml, projectors, opts)
+			assert.Equal(t, int(s.expect0+s.expect1), estimate.Layers, "scenario %d: %v", i, s)
+			assert.Equal(t, fmt.Sprintf("%d,%d", s.expect0, s.expect1), estimate.TensorSplit, "scenario %d: %v", i, s)
+			var layerSums uint64
+			for _, b := range estimate.GPUSizes {
+				layerSums += b
+			}
+			if estimate.Layers < inputLayerCount+1 {
+				assert.Less(t, estimate.VRAMSize, estimate.TotalSize, "scenario %d: %v %+v", i, s, estimate)
+				assert.Equal(t, estimate.VRAMSize, layerSums, "scenario %d: %v %+v", i, s, estimate)
+			} else {
+				assert.Equal(t, estimate.VRAMSize, estimate.TotalSize, "scenario %d: %v %+v", i, s, estimate)
+				assert.Equal(t, estimate.TotalSize, layerSums, "scenario %d: %v %+v", i, s, estimate)
+			}
+		})
+	}
+}
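
A note on the scenario arithmetic above: each GPU's FreeMemory is built up from the driver's minimum reservation, one layer of slack, the requested number of layers, and the larger of the two graph buffers (the small output-layer and projector terms are ignored here). Below is a minimal standalone sketch of that arithmetic using the test's constants; the helper name is illustrative, not part of the codebase.

package main

import "fmt"

// freeMemoryForLayers mirrors how the dual-GPU scenarios size each GPU so that
// roughly n additional layers fit: minimum reservation + one layer of slack +
// n layers + the graph buffer + 1 byte.
func freeMemoryForLayers(n, layerSize, minimumMemory, graph uint64) uint64 {
	return minimumMemory + layerSize + n*layerSize + graph + 1
}

func main() {
	const (
		layerSize     = uint64(33554436)  // per-layer size derived from the dummy GGUF
		minimumMemory = uint64(2048)      // gpuMinimumMemory in the test
		graph         = uint64(202377216) // max(graphFullOffload, graphPartialOffload)
	)
	for _, n := range []uint64{1, 2, 3} {
		fmt.Printf("layers=%d free=%d bytes\n", n, freeMemoryForLayers(n, layerSize, minimumMemory, graph))
	}
}
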

+ 8 - 8
llm/payload.go

@@ -82,8 +82,8 @@ func serversForGpu(info gpu.GpuInfo) []string {
 	// glob workDir for files that start with ollama_
 	// glob workDir for files that start with ollama_
 	availableServers := availableServers()
 	availableServers := availableServers()
 	requested := info.Library
 	requested := info.Library
-	if info.Variant != "" {
-		requested += "_" + info.Variant
+	if info.Variant != gpu.CPUCapabilityNone {
+		requested += "_" + info.Variant.String()
 	}
 	}
 
 
 	servers := []string{}
 	servers := []string{}
@@ -117,14 +117,14 @@ func serversForGpu(info gpu.GpuInfo) []string {
 
 
 	// Load up the best CPU variant if not primary requested
 	// Load up the best CPU variant if not primary requested
 	if info.Library != "cpu" {
 	if info.Library != "cpu" {
-		variant := gpu.GetCPUVariant()
+		variant := gpu.GetCPUCapability()
 		// If no variant, then we fall back to default
 		// If no variant, then we fall back to default
 		// If we have a variant, try that if we find an exact match
 		// If we have a variant, try that if we find an exact match
 		// Attempting to run the wrong CPU instructions will panic the
 		// Attempting to run the wrong CPU instructions will panic the
 		// process
 		// process
-		if variant != "" {
+		if variant != gpu.CPUCapabilityNone {
 			for cmp := range availableServers {
 			for cmp := range availableServers {
-				if cmp == "cpu_"+variant {
+				if cmp == "cpu_"+variant.String() {
 					servers = append(servers, cmp)
 					servers = append(servers, cmp)
 					break
 					break
 				}
 				}
@@ -146,11 +146,11 @@ func serverForCpu() string {
 	if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
 	if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
 		return "metal"
 		return "metal"
 	}
 	}
-	variant := gpu.GetCPUVariant()
+	variant := gpu.GetCPUCapability()
 	availableServers := availableServers()
 	availableServers := availableServers()
-	if variant != "" {
+	if variant != gpu.CPUCapabilityNone {
 		for cmp := range availableServers {
 		for cmp := range availableServers {
-			if cmp == "cpu_"+variant {
+			if cmp == "cpu_"+variant.String() {
 				return cmp
 				return cmp
 			}
 			}
 		}
 		}
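
The change above replaces the string-typed CPU variant with a gpu.CPUCapability value that has a String() method and a CPUCapabilityNone sentinel. A hedged sketch of the matching idea follows; the type and constants below are illustrative stand-ins, not the real gpu package API.

package main

import "fmt"

type cpuCapability int

const (
	capabilityNone cpuCapability = iota
	capabilityAVX
	capabilityAVX2
)

func (c cpuCapability) String() string {
	switch c {
	case capabilityAVX:
		return "avx"
	case capabilityAVX2:
		return "avx2"
	default:
		return ""
	}
}

// pickCPUServer prefers an exact "cpu_<variant>" runner build and falls back to
// plain "cpu", since running a build compiled for unsupported instructions would
// crash the process.
func pickCPUServer(available map[string]string, variant cpuCapability) string {
	if variant != capabilityNone {
		if _, ok := available["cpu_"+variant.String()]; ok {
			return "cpu_" + variant.String()
		}
	}
	return "cpu"
}

func main() {
	available := map[string]string{"cpu": "/runners/cpu", "cpu_avx2": "/runners/cpu_avx2"}
	fmt.Println(pickCPUServer(available, capabilityAVX2)) // cpu_avx2
	fmt.Println(pickCPUServer(available, capabilityAVX))  // cpu
}
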

+ 50 - 36
llm/server.go

@@ -37,8 +37,9 @@ type LlamaServer interface {
 	Tokenize(ctx context.Context, content string) ([]int, error)
 	Tokenize(ctx context.Context, content string) ([]int, error)
 	Detokenize(ctx context.Context, tokens []int) (string, error)
 	Detokenize(ctx context.Context, tokens []int) (string, error)
 	Close() error
 	Close() error
-	EstimatedVRAM() uint64
+	EstimatedVRAM() uint64 // Total VRAM across all GPUs
 	EstimatedTotal() uint64
 	EstimatedTotal() uint64
+	EstimatedVRAMByGPU(gpuID string) uint64
 }
 }
 
 
 // llmServer is an instance of the llama.cpp server
 // llmServer is an instance of the llama.cpp server
@@ -49,13 +50,12 @@ type llmServer struct {
 	status  *StatusWriter
 	status  *StatusWriter
 	options api.Options
 	options api.Options
 
 
-	// TODO - this should be broken down by GPU
-	estimatedVRAM  uint64 // Estimated usage of VRAM by the loaded model
-	estimatedTotal uint64 // Total size of model
-	totalLayers    uint64
-	gpuCount       int
-	loadDuration   time.Duration // Record how long it took the model to load
-	loadProgress   float32
+	estimate    MemoryEstimate
+	totalLayers uint64
+	// gpuCount     int
+	gpus         gpu.GpuInfoList // Recorded just before the model loaded, free space will be incorrect
+	loadDuration time.Duration   // Record how long it took the model to load
+	loadProgress float32
 
 
 	sem *semaphore.Weighted
 	sem *semaphore.Weighted
 }
 }
@@ -80,16 +80,16 @@ func LoadModel(model string) (*GGML, error) {
 func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, projectors []string, opts api.Options) (LlamaServer, error) {
 func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, projectors []string, opts api.Options) (LlamaServer, error) {
 	var err error
 	var err error
 	var cpuRunner string
 	var cpuRunner string
-	var estimatedVRAM uint64
-	var estimatedTotal uint64
+	var estimate MemoryEstimate
 	var systemMemory uint64
 	var systemMemory uint64
-	gpuCount := len(gpus)
-	if (len(gpus) == 1 && gpus[0].Library == "cpu") || opts.NumGPU == 0 {
-		// TODO evaluate system memory to see if we should block the load, or force an unload of another CPU runner
 
 
+	// If the user wants zero GPU layers, reset the gpu list to be CPU/system ram info
+	if opts.NumGPU == 0 {
+		gpus = gpu.GetCPUInfo()
+	}
+	if len(gpus) == 1 && gpus[0].Library == "cpu" {
 		cpuRunner = serverForCpu()
 		cpuRunner = serverForCpu()
-		gpuCount = 0
-		_, _, estimatedTotal = EstimateGPULayers(gpus, ggml, projectors, opts)
+		estimate = EstimateGPULayers(gpus, ggml, projectors, opts)
 	} else {
 	} else {
 		if gpus[0].Library == "metal" {
 		if gpus[0].Library == "metal" {
 			memInfo, err := gpu.GetCPUMem()
 			memInfo, err := gpu.GetCPUMem()
@@ -100,20 +100,19 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 				slog.Debug("system memory", "total", format.HumanBytes2(systemMemory))
 				slog.Debug("system memory", "total", format.HumanBytes2(systemMemory))
 			}
 			}
 		}
 		}
-		var layers int
-		layers, estimatedVRAM, estimatedTotal = EstimateGPULayers(gpus, ggml, projectors, opts)
+		estimate = EstimateGPULayers(gpus, ggml, projectors, opts)
 
 
 		switch {
 		switch {
-		case gpus[0].Library == "metal" && estimatedVRAM > systemMemory:
+		case gpus[0].Library == "metal" && estimate.VRAMSize > systemMemory:
 			// disable partial offloading when model is greater than total system memory as this
 			// disable partial offloading when model is greater than total system memory as this
 			// can lead to locking up the system
 			// can lead to locking up the system
 			opts.NumGPU = 0
 			opts.NumGPU = 0
-		case gpus[0].Library != "metal" && layers == 0:
+		case gpus[0].Library != "metal" && estimate.Layers == 0:
 			// Don't bother loading into the GPU if no layers can fit
 			// Don't bother loading into the GPU if no layers can fit
 			cpuRunner = serverForCpu()
 			cpuRunner = serverForCpu()
-			gpuCount = 0
-		case opts.NumGPU < 0 && layers > 0 && gpus[0].Library != "cpu":
-			opts.NumGPU = layers
+			gpus = gpu.GetCPUInfo()
+		case opts.NumGPU < 0 && estimate.Layers > 0 && gpus[0].Library != "cpu":
+			opts.NumGPU = estimate.Layers
 		}
 		}
 	}
 	}
 
 
@@ -232,6 +231,14 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 
 
 	params = append(params, "--parallel", fmt.Sprintf("%d", numParallel))
 	params = append(params, "--parallel", fmt.Sprintf("%d", numParallel))
 
 
+	if estimate.TensorSplit != "" {
+		params = append(params, "--tensor-split", estimate.TensorSplit)
+	}
+
 	for i := range len(servers) {
 	for i := range len(servers) {
 		dir := availableServers[servers[i]]
 		dir := availableServers[servers[i]]
 		if dir == "" {
 		if dir == "" {
@@ -242,8 +249,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 		}
 		}
 
 
 		if strings.HasPrefix(servers[i], "cpu") {
 		if strings.HasPrefix(servers[i], "cpu") {
-			// TODO if we tried a gpu runner first, and it failed, record the error and bubble that back up
-			gpuCount = 0
+			gpus = gpu.GetCPUInfo()
 		}
 		}
 
 
 		// Find an available port, retry on each iteration in case the failure was a port conflict race
 		// Find an available port, retry on each iteration in case the failure was a port conflict race
@@ -299,16 +305,15 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 		}
 		}
 
 
 		s := &llmServer{
 		s := &llmServer{
-			port:           port,
-			cmd:            exec.Command(server, finalParams...),
-			status:         NewStatusWriter(os.Stderr),
-			options:        opts,
-			estimatedVRAM:  estimatedVRAM,
-			estimatedTotal: estimatedTotal,
-			sem:            semaphore.NewWeighted(int64(numParallel)),
-			totalLayers:    ggml.KV().BlockCount() + 1,
-			gpuCount:       gpuCount,
-			done:           make(chan error, 1),
+			port:        port,
+			cmd:         exec.Command(server, finalParams...),
+			status:      NewStatusWriter(os.Stderr),
+			options:     opts,
+			estimate:    estimate,
+			sem:         semaphore.NewWeighted(int64(numParallel)),
+			totalLayers: ggml.KV().BlockCount() + 1,
+			gpus:        gpus,
+			done:        make(chan error, 1),
 		}
 		}
 
 
 		s.cmd.Env = os.Environ()
 		s.cmd.Env = os.Environ()
@@ -1004,11 +1009,20 @@ func (s *llmServer) Close() error {
 }
 }
 
 
 func (s *llmServer) EstimatedVRAM() uint64 {
 func (s *llmServer) EstimatedVRAM() uint64 {
-	return s.estimatedVRAM
+	return s.estimate.VRAMSize
 }
 }
 
 
 func (s *llmServer) EstimatedTotal() uint64 {
 func (s *llmServer) EstimatedTotal() uint64 {
-	return s.estimatedTotal
+	return s.estimate.TotalSize
+}
+
+func (s *llmServer) EstimatedVRAMByGPU(gpuID string) uint64 {
+	for i, gpu := range s.gpus {
+		if gpu.ID == gpuID {
+			return s.estimate.GPUSizes[i]
+		}
+	}
+	return 0
 }
 }
 
 
 func parseDurationMs(ms float64) time.Duration {
 func parseDurationMs(ms float64) time.Duration {
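
Of the new fields above, estimate.TensorSplit is passed straight through as the runner's --tensor-split argument. Here is a small sketch under the assumption, consistent with the memory test earlier, that the string is simply the per-GPU layer counts joined with commas; the real construction happens inside EstimateGPULayers and is not shown in this diff.

package main

import (
	"fmt"
	"strconv"
	"strings"
)

// tensorSplit joins per-GPU layer counts into the comma-separated form asserted
// in the memory test, e.g. []int{3, 3} -> "3,3".
func tensorSplit(layersPerGPU []int) string {
	parts := make([]string, 0, len(layersPerGPU))
	for _, n := range layersPerGPU {
		parts = append(parts, strconv.Itoa(n))
	}
	return strings.Join(parts, ",")
}

func main() {
	fmt.Println(tensorSplit([]int{3, 3})) // "3,3"
	fmt.Println(tensorSplit([]int{2, 1})) // "2,1"
}
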

+ 127 - 42
server/sched.go

@@ -7,7 +7,6 @@ import (
 	"log/slog"
 	"log/slog"
 	"reflect"
 	"reflect"
 	"runtime"
 	"runtime"
-	"slices"
 	"sort"
 	"sort"
 	"strings"
 	"strings"
 	"sync"
 	"sync"
@@ -27,6 +26,7 @@ type LlmRequest struct {
 	sessionDuration time.Duration
 	sessionDuration time.Duration
 	successCh       chan *runnerRef
 	successCh       chan *runnerRef
 	errCh           chan error
 	errCh           chan error
+	schedAttempts   uint
 }
 }
 
 
 type Scheduler struct {
 type Scheduler struct {
@@ -38,9 +38,11 @@ type Scheduler struct {
 	loaded   map[string]*runnerRef
 	loaded   map[string]*runnerRef
 	loadedMu sync.Mutex
 	loadedMu sync.Mutex
 
 
-	loadFn      func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList)
-	newServerFn func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error)
-	getGpuFn    func() gpu.GpuInfoList
+	loadFn       func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList)
+	newServerFn  func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error)
+	getGpuFn     func() gpu.GpuInfoList
+	getCpuFn     func() gpu.GpuInfoList
+	reschedDelay time.Duration
 }
 }
 
 
 var ErrMaxQueue = fmt.Errorf("server busy, please try again.  maximum pending requests exceeded")
 var ErrMaxQueue = fmt.Errorf("server busy, please try again.  maximum pending requests exceeded")
@@ -54,6 +56,8 @@ func InitScheduler(ctx context.Context) *Scheduler {
 		loaded:        make(map[string]*runnerRef),
 		loaded:        make(map[string]*runnerRef),
 		newServerFn:   llm.NewLlamaServer,
 		newServerFn:   llm.NewLlamaServer,
 		getGpuFn:      gpu.GetGPUInfo,
 		getGpuFn:      gpu.GetGPUInfo,
+		getCpuFn:      gpu.GetCPUInfo,
+		reschedDelay:  250 * time.Millisecond,
 	}
 	}
 	sched.loadFn = sched.load
 	sched.loadFn = sched.load
 	return sched
 	return sched
@@ -105,6 +109,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
 			return
 			return
 		case pending := <-s.pendingReqCh:
 		case pending := <-s.pendingReqCh:
 			// Block other requests until we get this pending request running
 			// Block other requests until we get this pending request running
+			pending.schedAttempts++
 
 
 			if pending.ctx.Err() != nil {
 			if pending.ctx.Err() != nil {
 				slog.Debug("pending request cancelled or timed out, skipping scheduling")
 				slog.Debug("pending request cancelled or timed out, skipping scheduling")
@@ -131,7 +136,12 @@ func (s *Scheduler) processPending(ctx context.Context) {
 				} else {
 				} else {
 					// Either no models are loaded or below envconfig.MaxRunners
 					// Either no models are loaded or below envconfig.MaxRunners
 					// Get a refreshed GPU list
 					// Get a refreshed GPU list
-					gpus := s.getGpuFn()
+					var gpus gpu.GpuInfoList
+					if pending.opts.NumGPU == 0 {
+						gpus = s.getCpuFn()
+					} else {
+						gpus = s.getGpuFn()
+					}
 
 
 					// Load model for fitting
 					// Load model for fitting
 					ggml, err := llm.LoadModel(pending.model.ModelPath)
 					ggml, err := llm.LoadModel(pending.model.ModelPath)
@@ -140,16 +150,22 @@ func (s *Scheduler) processPending(ctx context.Context) {
 						break
 						break
 					}
 					}
 
 
-					// If we're CPU only mode, just limit by envconfig.MaxRunners above
-					// TODO handle system memory exhaustion
-					if (len(gpus) == 1 && gpus[0].Library == "cpu") || pending.opts.NumGPU == 0 {
-						slog.Debug("cpu mode with existing models, loading")
-						s.loadFn(pending, ggml, gpus)
-						break
-					}
-
-					// No models loaded. Load the model but prefer the best fit.
-					if loadedCount == 0 {
+					// Evaluate if the model will fit in the available system memory, or if we should unload a model first
+					if len(gpus) == 1 && gpus[0].Library == "cpu" {
+						if loadedCount == 0 {
+							slog.Debug("cpu mode with first model, loading")
+							s.loadFn(pending, ggml, gpus)
+							break
+						}
+						runnerToExpire = s.maybeFindCPURunnerToUnload(pending, ggml, gpus)
+						if runnerToExpire == nil {
+							slog.Debug("cpu mode with available system memory or first model, loading")
+							s.loadFn(pending, ggml, gpus)
+							break
+						}
+						// else we need to expire a runner
+					} else if loadedCount == 0 {
+						// No models loaded. Load the model but prefer the best fit.
 						slog.Debug("loading first model", "model", pending.model.ModelPath)
 						slog.Debug("loading first model", "model", pending.model.ModelPath)
 						g := pickBestFitGPUs(pending, ggml, gpus)
 						g := pickBestFitGPUs(pending, ggml, gpus)
 						if g != nil {
 						if g != nil {
@@ -159,16 +175,44 @@ func (s *Scheduler) processPending(ctx context.Context) {
 						break
 						break
 					}
 					}
 
 
-					// More than one loaded model, so we have to see if the new one fits
-					// Update free memory from currently loaded models
-					s.updateFreeSpace(gpus)
-					gpus = pickBestFitGPUs(pending, ggml, gpus)
-					if gpus != nil {
-						slog.Debug("new model fits with existing models, loading")
-						s.loadFn(pending, ggml, gpus)
-						break
+					if runnerToExpire == nil {
+						// More than one loaded model, so we have to see if the
+						// new one fits
+						//
+						// Avoid loading on any GPUs that still have other models
+						// loading on them, to prevent races with VRAM consumption
+						// ramping up during load
+						availGpus := s.filterGPUsWithoutLoadingModels(gpus)
+
+						// Update free memory from currently loaded models
+						s.updateFreeSpace(availGpus)
+						fitGpus := pickBestFitGPUs(pending, ggml, availGpus)
+						if fitGpus != nil {
+							slog.Debug("new model fits with existing models, loading")
+							s.loadFn(pending, ggml, fitGpus)
+							break
+						}
+
+						// We couldn't find a set of GPUs to fully load the new
+						// model. If no other models are loading (both GPU lists
+						// are the same) then we need to unload another model to
+						// make room
+						if len(availGpus) < len(gpus) {
+							// There are other requests pending, and this one
+							// needs more time, so put it on the back of the
+							// queue so that we might satisfy other pending
+							// requests that aren't blocked
+							go func() {
+								// Process in a go routine to avoid deadlocking
+								// the scheduler if our queue is full
+								slog.Debug("delaying scheduling while other models finish loading", "attempts", pending.schedAttempts, "model", pending.model.ModelPath)
+								time.Sleep(s.reschedDelay)
+								s.pendingReqCh <- pending
+							}()
+							break
+						}
+						runnerToExpire = s.findRunnerToUnload()
 					}
 					}
-					runnerToExpire = s.findRunnerToUnload()
 				}
 				}
 
 
 				if runnerToExpire == nil {
 				if runnerToExpire == nil {
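
The requeue path added above deliberately sleeps and re-sends from a goroutine so that a full pendingReqCh can never deadlock the scheduler loop. A self-contained illustration of that pattern; the request type here is a stand-in, not the real LlmRequest.

package main

import (
	"fmt"
	"time"
)

type request struct{ model string }

func main() {
	pending := make(chan *request, 1)
	reschedDelay := 250 * time.Millisecond

	pending <- &request{model: "m1"}
	req := <-pending

	// Placement failed because other models are still loading; requeue after a
	// short delay without blocking the scheduler goroutine.
	go func() {
		time.Sleep(reschedDelay)
		pending <- req
	}()

	fmt.Println("rescheduled:", (<-pending).model)
}
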
@@ -368,17 +412,9 @@ func (s *Scheduler) updateFreeSpace(allGpus gpu.GpuInfoList) {
 	s.loadedMu.Lock()
 	s.loadedMu.Lock()
 	for _, r := range s.loaded {
 	for _, r := range s.loaded {
 		r.refMu.Lock()
 		r.refMu.Lock()
-		gpuIDs := make([]string, 0, len(r.gpus))
 		if r.llama != nil {
 		if r.llama != nil {
-			// TODO this should be broken down by GPU instead of assuming uniform spread
-			estimatedVRAMPerGPU := r.llama.EstimatedVRAM() / uint64(len(r.gpus))
-			for _, gpu := range r.gpus {
-				gpuIDs = append(gpuIDs, gpu.ID)
-			}
 			for _, gpu := range allGpus {
 			for _, gpu := range allGpus {
-				if slices.Contains(gpuIDs, gpu.ID) {
-					predMap[predKey{gpu.Library, gpu.ID}] += estimatedVRAMPerGPU
-				}
+				predMap[predKey{gpu.Library, gpu.ID}] += r.llama.EstimatedVRAMByGPU(gpu.ID)
 			}
 			}
 		} else {
 		} else {
 			slog.Warn("unexpected nil runner reference, memory prediction may be incorrect")
 			slog.Warn("unexpected nil runner reference, memory prediction may be incorrect")
@@ -401,11 +437,36 @@ func (s *Scheduler) updateFreeSpace(allGpus gpu.GpuInfoList) {
 				// after we start our first runner, then we'll never account for that, so picking the smallest free value seems prudent.
 				// after we start our first runner, then we'll never account for that, so picking the smallest free value seems prudent.
 				allGpus[i].FreeMemory = allGpus[i].TotalMemory - p
 				allGpus[i].FreeMemory = allGpus[i].TotalMemory - p
 			}
 			}
-			slog.Info("updated VRAM", "gpu", allGpus[i].ID, "library", allGpus[i].Library, "total", format.HumanBytes2(allGpus[i].TotalMemory), "available", format.HumanBytes2(allGpus[i].FreeMemory))
+			slog.Info("updated VRAM based on existing loaded models", "gpu", allGpus[i].ID, "library", allGpus[i].Library, "total", format.HumanBytes2(allGpus[i].TotalMemory), "available", format.HumanBytes2(allGpus[i].FreeMemory))
 		}
 		}
 	}
 	}
 }
 }
 
 
+// While models are loading, their VRAM consumption is indeterminate, so we have to
+// avoid scheduling another model on the same GPU(s) until usage has stabilized.
+// This routine returns the set of GPUs that do not have a model actively loading.
+// If every GPU has a loading model, an empty list is returned (not a single CPU entry).
+func (s *Scheduler) filterGPUsWithoutLoadingModels(allGpus gpu.GpuInfoList) gpu.GpuInfoList {
+	ret := append(gpu.GpuInfoList{}, allGpus...)
+	s.loadedMu.Lock()
+	defer s.loadedMu.Unlock()
+	for _, runner := range s.loaded {
+		if runner.loading {
+			slog.Debug("overlapping loads detected", "gpus", runner.gpus, "model", runner.modelPath)
+			for _, busyGPU := range runner.gpus {
+				for i := range ret {
+					if ret[i].ID == busyGPU.ID {
+						ret = append(ret[:i], ret[i+1:]...)
+						break
+					}
+				}
+			}
+		}
+	}
+	return ret
+}
+
+// TODO consolidate sched_types.go
 type runnerRef struct {
 type runnerRef struct {
 	refMu sync.Mutex
 	refMu sync.Mutex
 	// refCond   sync.Cond // Signaled on transition from 1 -> 0 refCount
 	// refCond   sync.Cond // Signaled on transition from 1 -> 0 refCount
@@ -487,8 +548,11 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
 func (runner *runnerRef) waitForVRAMRecovery() chan interface{} {
 func (runner *runnerRef) waitForVRAMRecovery() chan interface{} {
 	finished := make(chan interface{}, 1)
 	finished := make(chan interface{}, 1)
 
 
-	// CPU or Metal don't need checking, so no waiting required, windows can page VRAM, and the APIs we query tend to be optimistic on free space
-	if (len(runner.gpus) == 1 && (runner.gpus[0].Library == "cpu" || runner.gpus[0].Library == "metal")) || runtime.GOOS == "windows" {
+	// CPU or Metal don't need checking, so no waiting is required.
+	// Windows can page VRAM, and only CUDA currently reports accurate used VRAM.
+	if len(runner.gpus) == 0 ||
+		(len(runner.gpus) == 1 && (runner.gpus[0].Library == "cpu" || runner.gpus[0].Library == "metal")) ||
+		(runtime.GOOS == "windows" && runner.gpus[0].Library != "cuda") {
 		finished <- struct{}{}
 		finished <- struct{}{}
 		return finished
 		return finished
 	}
 	}
@@ -508,7 +572,7 @@ func (runner *runnerRef) waitForVRAMRecovery() chan interface{} {
 		for {
 		for {
 			<-ticker.C
 			<-ticker.C
 			if time.Now().After(expiresAt) {
 			if time.Now().After(expiresAt) {
-				slog.Warn("gpu VRAM usage didn't recover within timeout", "seconds", time.Since(start).Seconds())
+				slog.Warn("gpu VRAM usage didn't recover within timeout", "seconds", time.Since(start).Seconds(), "model", runner.modelPath)
 				finished <- struct{}{}
 				finished <- struct{}{}
 			}
 			}
 
 
@@ -521,7 +585,7 @@ func (runner *runnerRef) waitForVRAMRecovery() chan interface{} {
 			}
 			}
 			// If we're within ~80% of the estimated memory usage recovered, bail out
 			// If we're within ~80% of the estimated memory usage recovered, bail out
 			if float32(freeMemoryNow-freeMemoryBefore) > float32(runner.estimatedVRAM)*0.8 {
 			if float32(freeMemoryNow-freeMemoryBefore) > float32(runner.estimatedVRAM)*0.8 {
-				slog.Debug(fmt.Sprintf("gpu VRAM free memory converged after %0.2f seconds", time.Since(start).Seconds()))
+				slog.Debug(fmt.Sprintf("gpu VRAM free memory converged after %0.2f seconds", time.Since(start).Seconds()), "model", runner.modelPath)
 				finished <- struct{}{}
 				finished <- struct{}{}
 				return
 				return
 			}
 			}
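
The recovery loop above polls free VRAM and bails out once roughly 80% of the runner's estimated usage has been returned, rather than waiting for an exact match. A hedged sketch of that convergence check in isolation; the function name is illustrative.

package main

import "fmt"

// vramRecovered reports whether enough VRAM has been freed since the runner was
// asked to stop: the gain in free memory exceeds ~80% of the runner's estimate.
func vramRecovered(freeBefore, freeNow, estimatedVRAM uint64) bool {
	return float32(freeNow-freeBefore) > float32(estimatedVRAM)*0.8
}

func main() {
	fmt.Println(vramRecovered(1_000_000, 1_900_000, 1_000_000)) // true  (90% freed)
	fmt.Println(vramRecovered(1_000_000, 1_500_000, 1_000_000)) // false (only 50% freed)
}
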
@@ -558,10 +622,12 @@ func pickBestFitGPUs(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList) gpu.
 		sort.Sort(sort.Reverse(gpu.ByFreeMemory(sgl)))
 		sort.Sort(sort.Reverse(gpu.ByFreeMemory(sgl)))
 
 
 		// First attempt to fit the model into a single GPU
 		// First attempt to fit the model into a single GPU
-		for _, g := range sgl {
-			if ok, estimatedVRAM = llm.PredictServerFit([]gpu.GpuInfo{g}, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
-				slog.Debug("new model will fit in available VRAM in single GPU, loading", "model", req.model.ModelPath, "gpu", g.ID, "available", g.FreeMemory, "required", format.HumanBytes2(estimatedVRAM))
-				return []gpu.GpuInfo{g}
+		if !envconfig.SchedSpread {
+			for _, g := range sgl {
+				if ok, estimatedVRAM = llm.PredictServerFit([]gpu.GpuInfo{g}, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
+					slog.Debug("new model will fit in available VRAM in single GPU, loading", "model", req.model.ModelPath, "gpu", g.ID, "available", g.FreeMemory, "required", format.HumanBytes2(estimatedVRAM))
+					return []gpu.GpuInfo{g}
+				}
 			}
 			}
 		}
 		}
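
The envconfig.SchedSpread check above only gates the first pass: by default the scheduler tries to place the whole model on a single GPU (largest free VRAM first) and only then considers spreading across multiple GPUs; with spread enabled, the single-GPU pass is skipped. Below is a simplified stand-in for that selection order, not the real PredictServerFit-based fit test.

package main

import (
	"fmt"
	"sort"
)

type gpuInfo struct {
	ID   string
	Free uint64
}

// placeModel mimics the selection order: single-GPU first unless spread is requested,
// then all GPUs together if their combined free memory is enough.
func placeModel(gpus []gpuInfo, need uint64, spread bool) []string {
	sort.Slice(gpus, func(i, j int) bool { return gpus[i].Free > gpus[j].Free })
	if !spread {
		for _, g := range gpus {
			if g.Free >= need {
				return []string{g.ID} // fits on one GPU
			}
		}
	}
	ids := []string{}
	var total uint64
	for _, g := range gpus {
		ids = append(ids, g.ID)
		total += g.Free
	}
	if total >= need {
		return ids // spread across all GPUs
	}
	return nil
}

func main() {
	gpus := []gpuInfo{{ID: "0", Free: 8 << 30}, {ID: "1", Free: 24 << 30}}
	fmt.Println(placeModel(gpus, 10<<30, false)) // [1]
	fmt.Println(placeModel(gpus, 10<<30, true))  // [1 0]
}
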
 
 
@@ -586,6 +652,10 @@ func (s *Scheduler) findRunnerToUnload() *runnerRef {
 		runnerList = append(runnerList, r)
 		runnerList = append(runnerList, r)
 	}
 	}
 	s.loadedMu.Unlock()
 	s.loadedMu.Unlock()
+	if len(runnerList) == 0 {
+		slog.Debug("no loaded runner to unload")
+		return nil
+	}
 
 
 	// In the future we can enhance the algorithm to be smarter about picking the optimal runner to unload
 	// In the future we can enhance the algorithm to be smarter about picking the optimal runner to unload
 	// e.g., if we have multiple options, will one make room for the request?
 	// e.g., if we have multiple options, will one make room for the request?
@@ -616,3 +686,18 @@ func (s *Scheduler) unloadAllRunners() {
 		}
 		}
 	}
 	}
 }
 }
+
+// If other runners are loaded, make sure the pending request will fit in system memory.
+// If it won't, pick a runner to unload; otherwise return nil and the request can be loaded.
+func (s *Scheduler) maybeFindCPURunnerToUnload(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList) *runnerRef {
+	slog.Debug("evaluating if CPU model load will fit in available system memory")
+	estimate := llm.EstimateGPULayers(gpus, ggml, req.model.ProjectorPaths, req.opts)
+	if estimate.TotalSize <= gpus[0].FreeMemory {
+		slog.Debug("cpu inference mode, model fits in available system memory", "model", format.HumanBytes2(estimate.TotalSize), "available", format.HumanBytes2(gpus[0].FreeMemory))
+		return nil
+	}
+
+	// TODO - optimization: try to find CPU only runners first, or partial offloads with enough in system memory to make room
+
+	return s.findRunnerToUnload()
+}

+ 71 - 28
server/sched_test.go

@@ -60,7 +60,7 @@ func TestLoad(t *testing.T) {
 	err := <-req.errCh
 	err := <-req.errCh
 	require.Contains(t, err.Error(), "this model may be incompatible")
 	require.Contains(t, err.Error(), "this model may be incompatible")
 
 
-	server := &mockLlm{estimatedVRAM: 10}
+	server := &mockLlm{estimatedVRAM: 10, estimatedVRAMByGPU: map[string]uint64{}}
 	s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) {
 	s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) {
 		return server, nil
 		return server, nil
 	}
 	}
@@ -129,6 +129,7 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
 		"tokenizer.ggml.token_type":     []int32{0},
 		"tokenizer.ggml.token_type":     []int32{0},
 	}, []llm.Tensor{
 	}, []llm.Tensor{
 		{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
 		{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
+		{Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
 	})
 	})
 	require.NoError(t, err)
 	require.NoError(t, err)
 
 
@@ -145,7 +146,7 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
 		successCh:       make(chan *runnerRef, 1),
 		successCh:       make(chan *runnerRef, 1),
 		errCh:           make(chan error, 1),
 		errCh:           make(chan error, 1),
 	}
 	}
-	scenario.srv = &mockLlm{estimatedVRAM: estimatedVRAM}
+	scenario.srv = &mockLlm{estimatedVRAM: estimatedVRAM, estimatedVRAMByGPU: map[string]uint64{"": estimatedVRAM}}
 	return scenario
 	return scenario
 }
 }
 
 
@@ -155,7 +156,7 @@ func TestRequests(t *testing.T) {
 
 
 	// Same model, same request
 	// Same model, same request
 	scenario1a := newScenario(t, ctx, "ollama-model-1", 10)
 	scenario1a := newScenario(t, ctx, "ollama-model-1", 10)
-	scenario1a.req.sessionDuration = 0
+	scenario1a.req.sessionDuration = 5 * time.Millisecond
 	scenario1b := newScenario(t, ctx, "ollama-model-1", 11)
 	scenario1b := newScenario(t, ctx, "ollama-model-1", 11)
 	scenario1b.req.model = scenario1a.req.model
 	scenario1b.req.model = scenario1a.req.model
 	scenario1b.ggml = scenario1a.ggml
 	scenario1b.ggml = scenario1a.ggml
@@ -166,6 +167,7 @@ func TestRequests(t *testing.T) {
 	tmpModel := *scenario1a.req.model
 	tmpModel := *scenario1a.req.model
 	scenario2a.req.model = &tmpModel
 	scenario2a.req.model = &tmpModel
 	scenario2a.ggml = scenario1a.ggml
 	scenario2a.ggml = scenario1a.ggml
+	scenario2a.req.sessionDuration = 5 * time.Millisecond
 
 
 	// Multiple loaded models
 	// Multiple loaded models
 	scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte)
 	scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte)
@@ -181,6 +183,12 @@ func TestRequests(t *testing.T) {
 		g.FreeMemory = 12 * format.GigaByte
 		g.FreeMemory = 12 * format.GigaByte
 		return []gpu.GpuInfo{g}
 		return []gpu.GpuInfo{g}
 	}
 	}
+	s.getCpuFn = func() gpu.GpuInfoList {
+		g := gpu.GpuInfo{Library: "cpu"}
+		g.TotalMemory = 32 * format.GigaByte
+		g.FreeMemory = 26 * format.GigaByte
+		return []gpu.GpuInfo{g}
+	}
 	s.newServerFn = scenario1a.newServer
 	s.newServerFn = scenario1a.newServer
 	slog.Info("scenario1a")
 	slog.Info("scenario1a")
 	s.pendingReqCh <- scenario1a.req
 	s.pendingReqCh <- scenario1a.req
@@ -309,7 +317,6 @@ func TestGetRunner(t *testing.T) {
 	ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
 	ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
 	defer done()
 	defer done()
 
 
-	// Same model, same request
 	scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
 	scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
 	scenario1a.req.sessionDuration = 0
 	scenario1a.req.sessionDuration = 0
 	scenario1b := newScenario(t, ctx, "ollama-model-1b", 10)
 	scenario1b := newScenario(t, ctx, "ollama-model-1b", 10)
@@ -419,7 +426,7 @@ func TestUseLoadedRunner(t *testing.T) {
 		sessionDuration: 2,
 		sessionDuration: 2,
 	}
 	}
 	finished := make(chan *LlmRequest)
 	finished := make(chan *LlmRequest)
-	llm1 := &mockLlm{}
+	llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
 	r1 := &runnerRef{llama: llm1, sessionDuration: 1}
 	r1 := &runnerRef{llama: llm1, sessionDuration: 1}
 	req.useLoadedRunner(r1, finished)
 	req.useLoadedRunner(r1, finished)
 	require.Equal(t, uint(1), r1.refCount)
 	require.Equal(t, uint(1), r1.refCount)
@@ -452,8 +459,8 @@ func TestUpdateFreeSpace(t *testing.T) {
 	gpus[0].FreeMemory = 900
 	gpus[0].FreeMemory = 900
 	gpus[1].TotalMemory = 2000
 	gpus[1].TotalMemory = 2000
 	gpus[1].FreeMemory = 1900
 	gpus[1].FreeMemory = 1900
-	llm1 := &mockLlm{estimatedVRAM: 100}
-	llm2 := &mockLlm{estimatedVRAM: 200}
+	llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{"1": 50, "2": 50}}
+	llm2 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{"1": 125, "2": 75}}
 	r1 := &runnerRef{llama: llm1, gpus: gpus}
 	r1 := &runnerRef{llama: llm1, gpus: gpus}
 	r2 := &runnerRef{llama: llm2, gpus: gpus}
 	r2 := &runnerRef{llama: llm2, gpus: gpus}
 
 
@@ -464,8 +471,42 @@ func TestUpdateFreeSpace(t *testing.T) {
 	s.loadedMu.Unlock()
 	s.loadedMu.Unlock()
 
 
 	s.updateFreeSpace(gpus)
 	s.updateFreeSpace(gpus)
-	require.Equal(t, uint64(850), gpus[0].FreeMemory)
-	require.Equal(t, uint64(1850), gpus[1].FreeMemory)
+	require.Equal(t, uint64(1000-50-125), gpus[0].FreeMemory)
+	require.Equal(t, uint64(2000-50-75), gpus[1].FreeMemory)
+}
+
+func TestFilterGPUsWithoutLoadingModels(t *testing.T) {
+	ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
+	defer done()
+	gpus := gpu.GpuInfoList{
+		{
+			Library: "cuda",
+			ID:      "0",
+		},
+		{
+			Library: "cuda",
+			ID:      "1",
+		},
+	}
+	r1 := &runnerRef{gpus: gpu.GpuInfoList{gpus[0]}, loading: true}
+
+	s := InitScheduler(ctx)
+	s.loadedMu.Lock()
+	s.loaded["a"] = r1
+	s.loadedMu.Unlock()
+
+	tmp := s.filterGPUsWithoutLoadingModels(gpus)
+	require.Len(t, tmp, 1)
+	require.Equal(t, "1", tmp[0].ID)
+
+	r1.gpus = gpu.GpuInfoList{gpus[1]}
+	tmp = s.filterGPUsWithoutLoadingModels(gpus)
+	require.Len(t, tmp, 1)
+	require.Equal(t, "0", tmp[0].ID)
+
+	r1.gpus = gpu.GpuInfoList{}
+	tmp = s.filterGPUsWithoutLoadingModels(gpus)
+	require.Len(t, tmp, 2)
 }
 }
 
 
 func TestFindRunnerToUnload(t *testing.T) {
 func TestFindRunnerToUnload(t *testing.T) {
@@ -492,7 +533,7 @@ func TestNeedsReload(t *testing.T) {
 	ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
 	ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
 	defer done()
 	defer done()
 
 
-	llm := &mockLlm{}
+	llm := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
 	do := api.DefaultOptions()
 	do := api.DefaultOptions()
 	runner := &runnerRef{
 	runner := &runnerRef{
 		model:   &Model{AdapterPaths: []string{"adapter1"}, ProjectorPaths: []string{"projector1"}},
 		model:   &Model{AdapterPaths: []string{"adapter1"}, ProjectorPaths: []string{"projector1"}},
@@ -535,8 +576,8 @@ func TestUnloadAllRunners(t *testing.T) {
 	ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
 	ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
 	defer done()
 	defer done()
 
 
-	llm1 := &mockLlm{}
-	llm2 := &mockLlm{}
+	llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
+	llm2 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
 	s := InitScheduler(ctx)
 	s := InitScheduler(ctx)
 	s.unloadAllRunners()
 	s.unloadAllRunners()
 
 
@@ -554,7 +595,7 @@ func TestUnloadAllRunners(t *testing.T) {
 }
 }
 
 
 func TestUnload(t *testing.T) {
 func TestUnload(t *testing.T) {
-	llm1 := &mockLlm{}
+	llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
 	r1 := &runnerRef{llama: llm1}
 	r1 := &runnerRef{llama: llm1}
 	r2 := &runnerRef{model: &Model{AdapterPaths: []string{"A"}}}
 	r2 := &runnerRef{model: &Model{AdapterPaths: []string{"A"}}}
 	r1.unload()
 	r1.unload()
@@ -564,19 +605,20 @@ func TestUnload(t *testing.T) {
 }
 }
 
 
 type mockLlm struct {
 type mockLlm struct {
-	pingResp          error
-	waitResp          error
-	completionResp    error
-	embeddingResp     []float64
-	embeddingRespErr  error
-	tokenizeResp      []int
-	tokenizeRespErr   error
-	detokenizeResp    string
-	detonekizeRespErr error
-	closeResp         error
-	closeCalled       bool
-	estimatedVRAM     uint64
-	estimatedTotal    uint64
+	pingResp           error
+	waitResp           error
+	completionResp     error
+	embeddingResp      []float64
+	embeddingRespErr   error
+	tokenizeResp       []int
+	tokenizeRespErr    error
+	detokenizeResp     string
+	detonekizeRespErr  error
+	closeResp          error
+	closeCalled        bool
+	estimatedVRAM      uint64
+	estimatedTotal     uint64
+	estimatedVRAMByGPU map[string]uint64
 }
 }
 
 
 func (s *mockLlm) Ping(ctx context.Context) error             { return s.pingResp }
 func (s *mockLlm) Ping(ctx context.Context) error             { return s.pingResp }
@@ -597,5 +639,6 @@ func (s *mockLlm) Close() error {
 	s.closeCalled = true
 	s.closeCalled = true
 	return s.closeResp
 	return s.closeResp
 }
 }
-func (s *mockLlm) EstimatedVRAM() uint64  { return s.estimatedVRAM }
-func (s *mockLlm) EstimatedTotal() uint64 { return s.estimatedTotal }
+func (s *mockLlm) EstimatedVRAM() uint64                  { return s.estimatedVRAM }
+func (s *mockLlm) EstimatedTotal() uint64                 { return s.estimatedTotal }
+func (s *mockLlm) EstimatedVRAMByGPU(gpuid string) uint64 { return s.estimatedVRAMByGPU[gpuid] }
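
To close the loop on the new EstimatedVRAMByGPU interface: TestUpdateFreeSpace above expects each GPU's free memory to end up at its total memory minus the summed per-GPU estimates of every loaded runner. Here is a tiny standalone version of that bookkeeping; the maps stand in for mockLlm.estimatedVRAMByGPU.

package main

import "fmt"

func main() {
	total := map[string]uint64{"1": 1000, "2": 2000}
	runners := []map[string]uint64{
		{"1": 50, "2": 50},  // llm1
		{"1": 125, "2": 75}, // llm2
	}

	free := map[string]uint64{}
	for id, t := range total {
		used := uint64(0)
		for _, r := range runners {
			used += r[id]
		}
		free[id] = t - used
	}
	fmt.Println(free["1"], free["2"]) // 825 1875
}
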