
Request and model concurrency

This change adds support for multiple concurrent requests, as well as
loading multiple models by spawning multiple runners. The defaults are
currently 1 concurrent request per model and only 1 loaded model at a
time, but these can be adjusted by setting OLLAMA_NUM_PARALLEL and
OLLAMA_MAX_LOADED_MODELS.
Daniel Hiltgen, 1 year ago
parent
commit 34b9db5afc
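
Both settings are plain environment variables that default to 1. The snippet below is a minimal standalone sketch of how a server process might read them; the envInt helper, its rejection of values below 1, and the log messages are assumptions for illustration, not code from this change.

package main

import (
	"log/slog"
	"os"
	"strconv"
)

// envInt is a hypothetical helper: read a positive integer from the
// environment, falling back to def when the variable is unset or malformed.
func envInt(key string, def int) int {
	val := os.Getenv(key)
	if val == "" {
		return def
	}
	n, err := strconv.Atoi(val)
	if err != nil || n < 1 {
		slog.Warn("invalid setting, using default", "key", key, "value", val)
		return def
	}
	return n
}

func main() {
	numParallel := envInt("OLLAMA_NUM_PARALLEL", 1)    // concurrent requests per loaded model
	maxLoaded := envInt("OLLAMA_MAX_LOADED_MODELS", 1) // models resident at once
	slog.Info("concurrency settings", "parallel", numParallel, "max_loaded_models", maxLoaded)
}

Running with OLLAMA_NUM_PARALLEL=4 and OLLAMA_MAX_LOADED_MODELS=2 would then allow four in-flight requests per model and two resident models.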

+ 7 - 0
api/client.go

@@ -91,6 +91,13 @@ func ClientFromEnvironment() (*Client, error) {
 	}, nil
 }
 
+func NewClient(base *url.URL, http *http.Client) *Client {
+	return &Client{
+		base: base,
+		http: http,
+	}
+}
+
 func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error {
 	var reqBody io.Reader
 	var data []byte
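
The newly exported NewClient constructor makes it possible to build a Client against an explicit endpoint instead of going through ClientFromEnvironment. A short illustrative use; the address is a placeholder:

package main

import (
	"log"
	"net/http"
	"net/url"

	"github.com/ollama/ollama/api"
)

func main() {
	base, err := url.Parse("http://127.0.0.1:11434") // placeholder endpoint
	if err != nil {
		log.Fatal(err)
	}
	client := api.NewClient(base, http.DefaultClient)
	_ = client // use the client as you would one from ClientFromEnvironment
}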

+ 1 - 0
format/bytes.go

@@ -15,6 +15,7 @@ const (
 
 	KibiByte = Byte * 1024
 	MebiByte = KibiByte * 1024
+	GibiByte = MebiByte * 1024
 )
 
 func HumanBytes(b int64) string {
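
The new GibiByte constant is consumed later in this change for the iGPU memory cutoff in gpu/gpu.go; a tiny illustration with made-up values:

package main

import (
	"fmt"

	"github.com/ollama/ollama/format"
)

func main() {
	// Mirrors the iGPU cutoff added in gpu/gpu.go: anything reporting
	// less than 1 GiB of VRAM is treated as an integrated GPU.
	iGPUMemLimit := uint64(1 * format.GibiByte)
	totalMemory := uint64(512 * format.MebiByte) // illustrative value
	fmt.Println(totalMemory < iGPUMemLimit)      // true
}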

+ 56 - 14
gpu/amd_common.go

@@ -7,7 +7,7 @@ import (
 	"log/slog"
 	"os"
 	"path/filepath"
-	"strconv"
+	"runtime"
 	"strings"
 )
 
@@ -35,22 +35,64 @@ func GetSupportedGFX(libDir string) ([]string, error) {
 	return ret, nil
 }
 
-func amdSetVisibleDevices(ids []int, skip map[int]interface{}) {
-	// Set the visible devices if not already set
-	// TODO - does sort order matter?
-	devices := []string{}
-	for i := range ids {
-		if _, skipped := skip[i]; skipped {
+func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
+	ids := []string{}
+	for _, info := range gpuInfo {
+		if info.Library != "rocm" {
+			// TODO shouldn't happen if things are wired correctly...
+			slog.Debug("rocmGetVisibleDevicesEnv skipping over non-rocm device", "library", info.Library)
 			continue
 		}
-		devices = append(devices, strconv.Itoa(i))
+		ids = append(ids, info.ID)
 	}
+	return "HIP_VISIBLE_DEVICES", strings.Join(ids, ",")
+}
 
-	val := strings.Join(devices, ",")
-	err := os.Setenv("HIP_VISIBLE_DEVICES", val)
-	if err != nil {
-		slog.Warn(fmt.Sprintf("failed to set env: %s", err))
-	} else {
-		slog.Info("Setting HIP_VISIBLE_DEVICES=" + val)
+func commonAMDValidateLibDir() (string, error) {
+	// We try to favor system paths first, so that we can wire up the subprocess to use
+	// the system version.  Only use our bundled version if the system version doesn't work
+	// This gives users more recovery options if versions have subtle problems at runtime
+
+	// Prefer explicit HIP env var
+	hipPath := os.Getenv("HIP_PATH")
+	if hipPath != "" {
+		hipLibDir := filepath.Join(hipPath, "bin")
+		if rocmLibUsable(hipLibDir) {
+			slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
+			return hipLibDir, nil
+		}
+	}
+
+	// Scan the LD_LIBRARY_PATH or PATH
+	pathEnv := "LD_LIBRARY_PATH"
+	if runtime.GOOS == "windows" {
+		pathEnv = "PATH"
+	}
+
+	paths := os.Getenv(pathEnv)
+	for _, path := range filepath.SplitList(paths) {
+		d, err := filepath.Abs(path)
+		if err != nil {
+			continue
+		}
+		if rocmLibUsable(d) {
+			return d, nil
+		}
+	}
+
+	// Well known location(s)
+	if rocmLibUsable(RocmStandardLocation) {
+		return RocmStandardLocation, nil
+	}
+
+	// Installer payload location if we're running the installed binary
+	exe, err := os.Executable()
+	if err == nil {
+		rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
+		if rocmLibUsable(rocmTargetDir) {
+			slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
+			return rocmTargetDir, nil
+		}
 	}
+	return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
 }
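
rocmGetVisibleDevicesEnv joins the IDs of the ROCm devices a runner should see into a HIP_VISIBLE_DEVICES value. A simplified standalone sketch of that behavior, using a local struct in place of gpu.GpuInfo:

package main

import (
	"fmt"
	"strings"
)

type gpuInfo struct {
	Library string
	ID      string
}

// rocmVisibleDevices mirrors the ID-joining logic in rocmGetVisibleDevicesEnv,
// skipping anything that is not a ROCm device.
func rocmVisibleDevices(gpus []gpuInfo) (string, string) {
	ids := []string{}
	for _, g := range gpus {
		if g.Library != "rocm" {
			continue
		}
		ids = append(ids, g.ID)
	}
	return "HIP_VISIBLE_DEVICES", strings.Join(ids, ",")
}

func main() {
	gpus := []gpuInfo{{Library: "rocm", ID: "0"}, {Library: "rocm", ID: "2"}}
	key, val := rocmVisibleDevices(gpus)
	fmt.Printf("%s=%s\n", key, val) // HIP_VISIBLE_DEVICES=0,2
}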

+ 2 - 2
gpu/amd_hip_windows.go

@@ -69,7 +69,7 @@ func NewHipLib() (*HipLib, error) {
 func (hl *HipLib) Release() {
 	err := windows.FreeLibrary(hl.dll)
 	if err != nil {
-		slog.Warn(fmt.Sprintf("failed to unload amdhip64.dll: %s", err))
+		slog.Warn("failed to unload amdhip64.dll", "error", err)
 	}
 	hl.dll = 0
 }
@@ -98,7 +98,7 @@ func (hl *HipLib) HipGetDeviceCount() int {
 		return 0
 	}
 	if status != hipSuccess {
-		slog.Warn(fmt.Sprintf("failed call to hipGetDeviceCount: %d %s", status, err))
+		slog.Warn("failed call to hipGetDeviceCount", "status", status, "error", err)
 	}
 	return count
 }

+ 174 - 285
gpu/amd_linux.go

@@ -11,6 +11,8 @@ import (
 	"slices"
 	"strconv"
 	"strings"
+
+	"github.com/ollama/ollama/format"
 )
 
 // Discovery logic for AMD/ROCm GPUs
@@ -24,9 +26,6 @@ const (
 	GPUTotalMemoryFileGlob = "mem_banks/*/properties" // size_in_bytes line
 	GPUUsedMemoryFileGlob  = "mem_banks/*/used_memory"
 	RocmStandardLocation   = "/opt/rocm/lib"
-
-	// TODO find a better way to detect iGPU instead of minimum memory
-	IGPUMemLimit = 1024 * 1024 * 1024 // 512G is what they typically report, so anything less than 1G must be iGPU
 )
 
 var (
@@ -35,14 +34,11 @@ var (
 )
 
 // Gather GPU information from the amdgpu driver if any supported GPUs are detected
-// HIP_VISIBLE_DEVICES will be set if we detect a mix of unsupported and supported devices
-// and the user hasn't already set this variable
-func AMDGetGPUInfo(resp *GpuInfo) {
-	// TODO - DRY this out with windows
+func AMDGetGPUInfo() []GpuInfo {
+	resp := []GpuInfo{}
 	if !AMDDetected() {
-		return
+		return resp
 	}
-	skip := map[int]interface{}{}
 
 	// Opportunistic logging of driver version to aid in troubleshooting
 	ver, err := AMDDriverVersion()
@@ -50,160 +46,117 @@ func AMDGetGPUInfo(resp *GpuInfo) {
 		slog.Info("AMD Driver: " + ver)
 	} else {
 		// TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU
-		slog.Warn(fmt.Sprintf("ollama recommends running the https://www.amd.com/en/support/linux-drivers: %s", err))
+		slog.Warn("ollama recommends running the https://www.amd.com/en/support/linux-drivers", "error", err)
 	}
 
-	// If the user has specified exactly which GPUs to use, look up their memory
-	visibleDevices := os.Getenv("HIP_VISIBLE_DEVICES")
-	if visibleDevices != "" {
-		ids := []int{}
-		for _, idStr := range strings.Split(visibleDevices, ",") {
-			id, err := strconv.Atoi(idStr)
-			if err != nil {
-				slog.Warn(fmt.Sprintf("malformed HIP_VISIBLE_DEVICES=%s %s", visibleDevices, err))
-			} else {
-				ids = append(ids, id)
-			}
-		}
-		amdProcMemLookup(resp, nil, ids)
-		return
+	// Determine if the user has already pre-selected which GPUs to look at, then ignore the others
+	var visibleDevices []string
+	hipVD := os.Getenv("HIP_VISIBLE_DEVICES")   // zero based index only
+	rocrVD := os.Getenv("ROCR_VISIBLE_DEVICES") // zero based index or UUID, but consumer cards seem to not support UUID
+	gpuDO := os.Getenv("GPU_DEVICE_ORDINAL")    // zero based index
+	switch {
+	// TODO is this priority order right?
+	case hipVD != "":
+		visibleDevices = strings.Split(hipVD, ",")
+	case rocrVD != "":
+		visibleDevices = strings.Split(rocrVD, ",")
+		// TODO - since we don't yet support UUIDs, consider detecting and reporting here
+		// all our test systems show GPU-XX indicating UUID is not supported
+	case gpuDO != "":
+		visibleDevices = strings.Split(gpuDO, ",")
 	}
 
-	// Gather GFX version information from all detected cards
-	gfx := AMDGFXVersions()
-	verStrings := []string{}
-	for i, v := range gfx {
-		verStrings = append(verStrings, v.ToGFXString())
-		if v.Major == 0 {
-			// Silently skip CPUs
-			skip[i] = struct{}{}
+	gfxOverride := os.Getenv("HSA_OVERRIDE_GFX_VERSION")
+	var supported []string
+	libDir := ""
+
+	// The amdgpu driver always exposes the host CPU(s) first, but we have to skip them and subtract
+	// from the other IDs to get alignment with the HIP libraries expectations (zero is the first GPU, not the CPU)
+	matches, _ := filepath.Glob(GPUPropertiesFileGlob)
+	cpuCount := 0
+	for _, match := range matches {
+		slog.Debug("evaluating amdgpu node " + match)
+		fp, err := os.Open(match)
+		if err != nil {
+			slog.Debug("failed to open sysfs node", "file", match, "error", err)
 			continue
 		}
-		if v.Major < 9 {
-			// TODO consider this a build-time setting if we can support 8xx family GPUs
-			slog.Warn(fmt.Sprintf("amdgpu [%d] too old %s", i, v.ToGFXString()))
-			skip[i] = struct{}{}
+		defer fp.Close()
+		nodeID, err := strconv.Atoi(filepath.Base(filepath.Dir(match)))
+		if err != nil {
+			slog.Debug("failed to parse node ID", "error", err)
+			continue
 		}
-	}
-	slog.Info(fmt.Sprintf("detected amdgpu versions %v", verStrings))
 
-	// Abort if all GPUs are skipped
-	if len(skip) >= len(gfx) {
-		slog.Info("all detected amdgpus are skipped, falling back to CPU")
-		return
-	}
+		scanner := bufio.NewScanner(fp)
+		isCPU := false
+		var major, minor, patch uint64
+		for scanner.Scan() {
+			line := strings.TrimSpace(scanner.Text())
+			// Note: we could also use "cpu_cores_count X" where X is greater than zero to detect CPUs
+			if strings.HasPrefix(line, "gfx_target_version") {
+				ver := strings.Fields(line)
 
-	// If we got this far, then we have at least 1 GPU that's a ROCm candidate, so make sure we have a lib
-	libDir, err := AMDValidateLibDir()
-	if err != nil {
-		slog.Warn(fmt.Sprintf("unable to verify rocm library, will use cpu: %s", err))
-		return
-	}
+				// Detect CPUs
+				if len(ver) == 2 && ver[1] == "0" {
+					slog.Debug("detected CPU " + match)
+					isCPU = true
+					break
+				}
 
-	updateLibPath(libDir)
+				if len(ver) != 2 || len(ver[1]) < 5 {
+					slog.Warn("malformed "+match, "gfx_target_version", line)
+					// If this winds up being a CPU, our offsets may be wrong
+					continue
+				}
+				l := len(ver[1])
+				var err1, err2, err3 error
+				patch, err1 = strconv.ParseUint(ver[1][l-2:l], 10, 32)
+				minor, err2 = strconv.ParseUint(ver[1][l-4:l-2], 10, 32)
+				major, err3 = strconv.ParseUint(ver[1][:l-4], 10, 32)
+				if err1 != nil || err2 != nil || err3 != nil {
+					slog.Debug("malformed int " + line)
+					continue
+				}
+			}
 
-	gfxOverride := os.Getenv("HSA_OVERRIDE_GFX_VERSION")
-	if gfxOverride == "" {
-		supported, err := GetSupportedGFX(libDir)
-		if err != nil {
-			slog.Warn(fmt.Sprintf("failed to lookup supported GFX types, falling back to CPU mode: %s", err))
-			return
+			// TODO - any other properties we want to extract and record?
+			// vendor_id + device_id -> pci lookup for "Name"
+			// Other metrics that may help us understand relative performance between multiple GPUs
 		}
-		slog.Debug(fmt.Sprintf("rocm supported GPU types %v", supported))
 
-		for i, v := range gfx {
-			if !slices.Contains[[]string, string](supported, v.ToGFXString()) {
-				slog.Warn(fmt.Sprintf("amdgpu [%d] %s is not supported by %s %v", i, v.ToGFXString(), libDir, supported))
-				// TODO - consider discrete markdown just for ROCM troubleshooting?
-				slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/gpu.md#overrides for HSA_OVERRIDE_GFX_VERSION usage")
-				skip[i] = struct{}{}
-			} else {
-				slog.Info(fmt.Sprintf("amdgpu [%d] %s is supported", i, v.ToGFXString()))
-			}
+		if isCPU {
+			cpuCount++
+			continue
 		}
-	} else {
-		slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
-	}
 
-	if len(skip) >= len(gfx) {
-		slog.Info("all detected amdgpus are skipped, falling back to CPU")
-		return
-	}
+		// CPUs are always first in the list
+		gpuID := nodeID - cpuCount
 
-	ids := make([]int, len(gfx))
-	i := 0
-	for k := range gfx {
-		ids[i] = k
-		i++
-	}
-	amdProcMemLookup(resp, skip, ids)
-	if resp.memInfo.DeviceCount == 0 {
-		return
-	}
-	if len(skip) > 0 {
-		amdSetVisibleDevices(ids, skip)
-	}
-}
-
-func updateLibPath(libDir string) {
-	ldPaths := []string{}
-	if val, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
-		ldPaths = strings.Split(val, ":")
-	}
-	for _, d := range ldPaths {
-		if d == libDir {
-			return
-		}
-	}
-	val := strings.Join(append(ldPaths, libDir), ":")
-	slog.Debug("updated lib path", "LD_LIBRARY_PATH", val)
-	os.Setenv("LD_LIBRARY_PATH", val)
-}
-
-// Walk the sysfs nodes for the available GPUs and gather information from them
-// skipping over any devices in the skip map
-func amdProcMemLookup(resp *GpuInfo, skip map[int]interface{}, ids []int) {
-	resp.memInfo.DeviceCount = 0
-	resp.memInfo.TotalMemory = 0
-	resp.memInfo.FreeMemory = 0
-	slog.Debug("discovering VRAM for amdgpu devices")
-	if len(ids) == 0 {
-		entries, err := os.ReadDir(AMDNodesSysfsDir)
-		if err != nil {
-			slog.Warn(fmt.Sprintf("failed to read amdgpu sysfs %s - %s", AMDNodesSysfsDir, err))
-			return
-		}
-		for _, node := range entries {
-			if !node.IsDir() {
-				continue
-			}
-			id, err := strconv.Atoi(node.Name())
-			if err != nil {
-				slog.Warn("malformed amdgpu sysfs node id " + node.Name())
-				continue
-			}
-			ids = append(ids, id)
+		// Shouldn't happen, but just in case...
+		if gpuID < 0 {
+			slog.Error("unexpected amdgpu sysfs data resulted in negative GPU ID, please set OLLAMA_DEBUG=1 and report an issue")
+			return []GpuInfo{}
 		}
-	}
-	slog.Debug(fmt.Sprintf("amdgpu devices %v", ids))
 
-	for _, id := range ids {
-		if _, skipped := skip[id]; skipped {
+		if int(major) < RocmComputeMin {
+			slog.Warn(fmt.Sprintf("amdgpu too old gfx%d%d%d", major, minor, patch), "gpu", gpuID)
 			continue
 		}
+
+		// Look up the memory for the current node
 		totalMemory := uint64(0)
 		usedMemory := uint64(0)
-		// Adjust for sysfs vs HIP ids
-		propGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id+1), GPUTotalMemoryFileGlob)
+		propGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(nodeID), GPUTotalMemoryFileGlob)
 		propFiles, err := filepath.Glob(propGlob)
 		if err != nil {
-			slog.Warn(fmt.Sprintf("error looking up total GPU memory: %s %s", propGlob, err))
+			slog.Warn("error looking up total GPU memory", "glob", propGlob, "error", err)
 		}
 		// 1 or more memory banks - sum the values of all of them
 		for _, propFile := range propFiles {
 			fp, err := os.Open(propFile)
 			if err != nil {
-				slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", propFile, err))
+				slog.Warn("failed to open sysfs node", "file", propFile, "error", err)
 				continue
 			}
 			defer fp.Close()
@@ -226,49 +179,113 @@ func amdProcMemLookup(resp *GpuInfo, skip map[int]interface{}, ids []int) {
 			}
 		}
 		if totalMemory == 0 {
-			slog.Warn(fmt.Sprintf("amdgpu [%d] reports zero total memory, skipping", id))
-			skip[id] = struct{}{}
+			slog.Warn("amdgpu reports zero total memory", "gpu", gpuID)
 			continue
 		}
-		if totalMemory < IGPUMemLimit {
-			slog.Info(fmt.Sprintf("amdgpu [%d] appears to be an iGPU with %dM reported total memory, skipping", id, totalMemory/1024/1024))
-			skip[id] = struct{}{}
-			continue
-		}
-		usedGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id), GPUUsedMemoryFileGlob)
+		usedGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(nodeID), GPUUsedMemoryFileGlob)
 		usedFiles, err := filepath.Glob(usedGlob)
 		if err != nil {
-			slog.Warn(fmt.Sprintf("error looking up used GPU memory: %s %s", usedGlob, err))
+			slog.Warn("error looking up used GPU memory", "glob", usedGlob, "error", err)
 			continue
 		}
 		for _, usedFile := range usedFiles {
 			fp, err := os.Open(usedFile)
 			if err != nil {
-				slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", usedFile, err))
+				slog.Warn("failed to open sysfs node", "file", usedFile, "error", err)
 				continue
 			}
 			defer fp.Close()
 			data, err := io.ReadAll(fp)
 			if err != nil {
-				slog.Warn(fmt.Sprintf("failed to read sysfs node file %s: %s", usedFile, err))
+				slog.Warn("failed to read sysfs node", "file", usedFile, "error", err)
 				continue
 			}
 			used, err := strconv.ParseUint(strings.TrimSpace(string(data)), 10, 64)
 			if err != nil {
-				slog.Warn(fmt.Sprintf("malformed used memory %s: %s", string(data), err))
+				slog.Warn("malformed used memory", "data", string(data), "error", err)
 				continue
 			}
 			usedMemory += used
 		}
-		slog.Info(fmt.Sprintf("[%d] amdgpu totalMemory %dM", id, totalMemory/1024/1024))
-		slog.Info(fmt.Sprintf("[%d] amdgpu freeMemory  %dM", id, (totalMemory-usedMemory)/1024/1024))
-		resp.memInfo.DeviceCount++
-		resp.memInfo.TotalMemory += totalMemory
-		resp.memInfo.FreeMemory += (totalMemory - usedMemory)
+
+		// iGPU detection, remove this check once we can support an iGPU variant of the rocm library
+		if totalMemory < IGPUMemLimit {
+			slog.Info("amdgpu appears to be an iGPU, skipping", "gpu", gpuID, "total", format.HumanBytes2(totalMemory))
+			continue
+		}
+
+		slog.Info("amdgpu memory", "gpu", gpuID, "total", format.HumanBytes2(totalMemory))
+		slog.Info("amdgpu memory", "gpu", gpuID, "available", format.HumanBytes2(totalMemory-usedMemory))
+		gpuInfo := GpuInfo{
+			Library: "rocm",
+			memInfo: memInfo{
+				TotalMemory: totalMemory,
+				FreeMemory:  (totalMemory - usedMemory),
+			},
+			ID: fmt.Sprintf("%d", gpuID),
+			// Name: not exposed in sysfs directly, would require pci device id lookup
+			Major:         int(major),
+			Minor:         int(minor),
+			Patch:         int(patch),
+			MinimumMemory: rocmMinimumMemory,
+		}
+
+		// If the user wants to filter to a subset of devices, filter out if we aren't a match
+		if len(visibleDevices) > 0 {
+			include := false
+			for _, visible := range visibleDevices {
+				if visible == gpuInfo.ID {
+					include = true
+					break
+				}
+			}
+			if !include {
+				slog.Info("filtering out device per user request", "id", gpuInfo.ID, "visible_devices", visibleDevices)
+				continue
+			}
+		}
+
+		// Final validation is gfx compatibility - load the library if we haven't already loaded it
+		// even if the user overrides, we still need to validate the library
+		if libDir == "" {
+			libDir, err = AMDValidateLibDir()
+			if err != nil {
+				slog.Warn("unable to verify rocm library, will use cpu", "error", err)
+				return []GpuInfo{}
+			}
+		}
+		gpuInfo.DependencyPath = libDir
+
+		if gfxOverride == "" {
+			// Only load supported list once
+			if len(supported) == 0 {
+				supported, err = GetSupportedGFX(libDir)
+				if err != nil {
+					slog.Warn("failed to lookup supported GFX types, falling back to CPU mode", "error", err)
+					return []GpuInfo{}
+				}
+				slog.Debug("rocm supported GPUs", "types", supported)
+			}
+			gfx := fmt.Sprintf("gfx%d%d%d", gpuInfo.Major, gpuInfo.Minor, gpuInfo.Patch)
+			if !slices.Contains[[]string, string](supported, gfx) {
+				slog.Warn("amdgpu is not supported", "gpu", gpuInfo.ID, "gpu_type", gfx, "library", libDir, "supported_types", supported)
+				// TODO - consider discrete markdown just for ROCM troubleshooting?
+				slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/gpu.md#overrides for HSA_OVERRIDE_GFX_VERSION usage")
+				continue
+			} else {
+				slog.Info("amdgpu is supported", "gpu", gpuInfo.ID, "gpu_type", gfx)
+			}
+		} else {
+			slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
+		}
+
+		// The GPU has passed all the verification steps and is supported
+		resp = append(resp, gpuInfo)
 	}
-	if resp.memInfo.DeviceCount > 0 {
-		resp.Library = "rocm"
+	if len(resp) == 0 {
+		slog.Info("no compatible amdgpu devices detected")
 	}
+	return resp
 }
 
 // Quick check for AMD driver so we can skip amdgpu discovery if not present
@@ -280,87 +297,24 @@ func AMDDetected() bool {
 		slog.Debug("amdgpu driver not detected " + sysfsDir)
 		return false
 	} else if err != nil {
-		slog.Debug(fmt.Sprintf("error looking up amd driver %s %s", sysfsDir, err))
+		slog.Debug("error looking up amd driver", "path", sysfsDir, "error", err)
 		return false
 	}
 	return true
 }
 
-func setupLink(source, target string) error {
-	if err := os.RemoveAll(target); err != nil {
-		return fmt.Errorf("failed to remove old rocm directory %s %w", target, err)
-	}
-	if err := os.Symlink(source, target); err != nil {
-		return fmt.Errorf("failed to create link %s => %s %w", source, target, err)
-	}
-	slog.Debug(fmt.Sprintf("host rocm linked %s => %s", source, target))
-	return nil
-}
-
-// Ensure the AMD rocm lib dir is wired up
 // Prefer to use host installed ROCm, as long as it meets our minimum requirements
 // failing that, tell the user how to download it on their own
 func AMDValidateLibDir() (string, error) {
-	// We rely on the rpath compiled into our library to find rocm
-	// so we establish a symlink to wherever we find it on the system
-	// to <payloads>/rocm
-	payloadsDir, err := PayloadsDir()
-	if err != nil {
-		return "", err
-	}
-
-	// If we already have a rocm dependency wired, nothing more to do
-	rocmTargetDir := filepath.Clean(filepath.Join(payloadsDir, "..", "rocm"))
-	if rocmLibUsable(rocmTargetDir) {
-		return rocmTargetDir, nil
-	}
-
-	// next to the running binary
-	exe, err := os.Executable()
+	libDir, err := commonAMDValidateLibDir()
 	if err == nil {
-		peerDir := filepath.Dir(exe)
-		if rocmLibUsable(peerDir) {
-			slog.Debug("detected ROCM next to ollama executable " + peerDir)
-			return rocmTargetDir, setupLink(peerDir, rocmTargetDir)
-		}
-		peerDir = filepath.Join(filepath.Dir(exe), "rocm")
-		if rocmLibUsable(peerDir) {
-			slog.Debug("detected ROCM next to ollama executable " + peerDir)
-			return rocmTargetDir, setupLink(peerDir, rocmTargetDir)
-		}
+		return libDir, nil
 	}
 
 	// Well known ollama installer path
 	installedRocmDir := "/usr/share/ollama/lib/rocm"
 	if rocmLibUsable(installedRocmDir) {
-		return rocmTargetDir, setupLink(installedRocmDir, rocmTargetDir)
-	}
-
-	// Prefer explicit HIP env var
-	hipPath := os.Getenv("HIP_PATH")
-	if hipPath != "" {
-		hipLibDir := filepath.Join(hipPath, "lib")
-		if rocmLibUsable(hipLibDir) {
-			slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
-			return rocmTargetDir, setupLink(hipLibDir, rocmTargetDir)
-		}
-	}
-
-	// Scan the library path for potential matches
-	ldPaths := strings.Split(os.Getenv("LD_LIBRARY_PATH"), ":")
-	for _, ldPath := range ldPaths {
-		d, err := filepath.Abs(ldPath)
-		if err != nil {
-			continue
-		}
-		if rocmLibUsable(d) {
-			return rocmTargetDir, setupLink(d, rocmTargetDir)
-		}
-	}
-
-	// Well known location(s)
-	if rocmLibUsable("/opt/rocm/lib") {
-		return rocmTargetDir, setupLink("/opt/rocm/lib", rocmTargetDir)
+		return installedRocmDir, nil
 	}
 
 	// If we still haven't found a usable rocm, the user will have to install it on their own
@@ -384,68 +338,3 @@ func AMDDriverVersion() (string, error) {
 	}
 	return strings.TrimSpace(string(verString)), nil
 }
-
-func AMDGFXVersions() map[int]Version {
-	// The amdgpu driver always exposes the host CPU as node 0, but we have to skip that and subtract one
-	// from the other IDs to get alignment with the HIP libraries expectations (zero is the first GPU, not the CPU)
-	res := map[int]Version{}
-	matches, _ := filepath.Glob(GPUPropertiesFileGlob)
-	for _, match := range matches {
-		fp, err := os.Open(match)
-		if err != nil {
-			slog.Debug(fmt.Sprintf("failed to open sysfs node file %s: %s", match, err))
-			continue
-		}
-		defer fp.Close()
-		i, err := strconv.Atoi(filepath.Base(filepath.Dir(match)))
-		if err != nil {
-			slog.Debug(fmt.Sprintf("failed to parse node ID %s", err))
-			continue
-		}
-
-		if i == 0 {
-			// Skipping the CPU
-			continue
-		}
-		// Align with HIP IDs (zero is first GPU, not CPU)
-		i -= 1
-
-		scanner := bufio.NewScanner(fp)
-		for scanner.Scan() {
-			line := strings.TrimSpace(scanner.Text())
-			if strings.HasPrefix(line, "gfx_target_version") {
-				ver := strings.Fields(line)
-				if len(ver) != 2 || len(ver[1]) < 5 {
-					if ver[1] != "0" {
-						slog.Debug("malformed " + line)
-					}
-					res[i] = Version{
-						Major: 0,
-						Minor: 0,
-						Patch: 0,
-					}
-					continue
-				}
-				l := len(ver[1])
-				patch, err1 := strconv.ParseUint(ver[1][l-2:l], 10, 32)
-				minor, err2 := strconv.ParseUint(ver[1][l-4:l-2], 10, 32)
-				major, err3 := strconv.ParseUint(ver[1][:l-4], 10, 32)
-				if err1 != nil || err2 != nil || err3 != nil {
-					slog.Debug("malformed int " + line)
-					continue
-				}
-
-				res[i] = Version{
-					Major: uint(major),
-					Minor: uint(minor),
-					Patch: uint(patch),
-				}
-			}
-		}
-	}
-	return res
-}
-
-func (v Version) ToGFXString() string {
-	return fmt.Sprintf("gfx%d%d%d", v.Major, v.Minor, v.Patch)
-}
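
The sysfs gfx_target_version value packs the GFX version as decimal digits: the last two digits are the patch level, the two before that are the minor version, and whatever remains is the major version. A standalone sketch of the decoding; decodeGFX is a hypothetical helper and the sample value is illustrative:

package main

import (
	"fmt"
	"strconv"
)

// decodeGFX mirrors the digit slicing in AMDGetGPUInfo.
func decodeGFX(v string) (major, minor, patch uint64, err error) {
	if len(v) < 5 {
		return 0, 0, 0, fmt.Errorf("malformed gfx_target_version %q", v)
	}
	l := len(v)
	if patch, err = strconv.ParseUint(v[l-2:], 10, 32); err != nil {
		return
	}
	if minor, err = strconv.ParseUint(v[l-4:l-2], 10, 32); err != nil {
		return
	}
	major, err = strconv.ParseUint(v[:l-4], 10, 32)
	return
}

func main() {
	// e.g. a gfx1030 (RDNA2) card reports gfx_target_version 100300
	major, minor, patch, _ := decodeGFX("100300")
	fmt.Printf("gfx%d%d%d\n", major, minor, patch) // gfx1030
}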

+ 80 - 74
gpu/amd_windows.go

@@ -7,7 +7,10 @@ import (
 	"os"
 	"path/filepath"
 	"slices"
+	"strconv"
 	"strings"
+
+	"github.com/ollama/ollama/format"
 )
 
 const (
@@ -22,36 +25,32 @@ var (
 	ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // TODO - probably include more coverage of files here...
 )
 
-func AMDGetGPUInfo(resp *GpuInfo) {
+func AMDGetGPUInfo() []GpuInfo {
+	resp := []GpuInfo{}
 	hl, err := NewHipLib()
 	if err != nil {
 		slog.Debug(err.Error())
-		return
+		return nil
 	}
 	defer hl.Release()
-	skip := map[int]interface{}{}
-	ids := []int{}
-	resp.memInfo.DeviceCount = 0
-	resp.memInfo.TotalMemory = 0
-	resp.memInfo.FreeMemory = 0
 
 	ver, err := hl.AMDDriverVersion()
 	if err == nil {
 		slog.Info("AMD Driver: " + ver)
 	} else {
 		// For now this is benign, but we may eventually need to fail compatibility checks
-		slog.Debug(fmt.Sprintf("error looking up amd driver version: %s", err))
+		slog.Debug("error looking up amd driver version", "error", err)
 	}
 
-	// Note: the HIP library automatically handles HIP_VISIBLE_DEVICES
+	// Note: the HIP library automatically handles subsetting to any HIP_VISIBLE_DEVICES the user specified
 	count := hl.HipGetDeviceCount()
 	if count == 0 {
-		return
+		return nil
 	}
 	libDir, err := AMDValidateLibDir()
 	if err != nil {
-		slog.Warn(fmt.Sprintf("unable to verify rocm library, will use cpu: %s", err))
-		return
+		slog.Warn("unable to verify rocm library, will use cpu", "error", err)
+		return nil
 	}
 
 	var supported []string
@@ -59,95 +58,117 @@ func AMDGetGPUInfo(resp *GpuInfo) {
 	if gfxOverride == "" {
 		supported, err = GetSupportedGFX(libDir)
 		if err != nil {
-			slog.Warn(fmt.Sprintf("failed to lookup supported GFX types, falling back to CPU mode: %s", err))
-			return
+			slog.Warn("failed to lookup supported GFX types, falling back to CPU mode", "error", err)
+			return nil
 		}
 	} else {
 		slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
 	}
 
-	slog.Info(fmt.Sprintf("detected %d hip devices", count))
+	slog.Info("detected hip devices", "count", count)
+	// TODO how to determine the underlying device ID when visible devices is causing this to subset?
 	for i := 0; i < count; i++ {
-		ids = append(ids, i)
 		err = hl.HipSetDevice(i)
 		if err != nil {
-			slog.Warn(fmt.Sprintf("[%d] %s", i, err))
-			skip[i] = struct{}{}
+			slog.Warn("set device", "id", i, "error", err)
 			continue
 		}
 
 		props, err := hl.HipGetDeviceProperties(i)
 		if err != nil {
-			slog.Warn(fmt.Sprintf("[%d] %s", i, err))
-			skip[i] = struct{}{}
+			slog.Warn("get properties", "id", i, "error", err)
 			continue
 		}
 		n := bytes.IndexByte(props.Name[:], 0)
 		name := string(props.Name[:n])
-		slog.Info(fmt.Sprintf("[%d] Name: %s", i, name))
+		// TODO is UUID actually populated on windows?
+		// Can luid be used on windows for setting visible devices (and is it actually set?)
 		n = bytes.IndexByte(props.GcnArchName[:], 0)
 		gfx := string(props.GcnArchName[:n])
-		slog.Info(fmt.Sprintf("[%d] GcnArchName: %s", i, gfx))
+		slog.Info("hip device", "id", i, "name", name, "gfx", gfx)
+		var major, minor, patch string
+		switch len(gfx) {
+		case 6:
+			major, minor, patch = gfx[3:4], gfx[4:5], gfx[5:]
+		case 7:
+			major, minor, patch = gfx[3:5], gfx[5:6], gfx[6:]
+		}
 		//slog.Info(fmt.Sprintf("[%d] Integrated: %d", i, props.iGPU)) // DOESN'T REPORT CORRECTLY!  Always 0
 		// TODO  Why isn't props.iGPU accurate!?
 		if strings.EqualFold(name, iGPUName) {
-			slog.Info(fmt.Sprintf("iGPU detected [%d] skipping", i))
-			skip[i] = struct{}{}
+			slog.Info("iGPU detected skipping", "id", i)
 			continue
 		}
 		if gfxOverride == "" {
 			if !slices.Contains[[]string, string](supported, gfx) {
-				slog.Warn(fmt.Sprintf("amdgpu [%d] %s is not supported by %s %v", i, gfx, libDir, supported))
+				slog.Warn("amdgpu is not supported", "gpu", i, "gpu_type", gfx, "library", libDir, "supported_types", supported)
 				// TODO - consider discrete markdown just for ROCM troubleshooting?
 				slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage")
-				skip[i] = struct{}{}
 				continue
 			} else {
-				slog.Info(fmt.Sprintf("amdgpu [%d] %s is supported", i, gfx))
+				slog.Info("amdgpu is supported", "gpu", i, "gpu_type", gfx)
 			}
 		}
 
-		totalMemory, freeMemory, err := hl.HipMemGetInfo()
+		freeMemory, totalMemory, err := hl.HipMemGetInfo()
 		if err != nil {
-			slog.Warn(fmt.Sprintf("[%d] %s", i, err))
+			slog.Warn("get mem info", "id", i, "error", err)
 			continue
 		}
 
-		// TODO according to docs, freeMem may lie on windows!
-		slog.Info(fmt.Sprintf("[%d] Total Mem: %d", i, totalMemory))
-		slog.Info(fmt.Sprintf("[%d] Free Mem:  %d", i, freeMemory))
-		resp.memInfo.DeviceCount++
-		resp.memInfo.TotalMemory += totalMemory
-		resp.memInfo.FreeMemory += freeMemory
-	}
-	if resp.memInfo.DeviceCount > 0 {
-		resp.Library = "rocm"
-	}
-	// Abort if all GPUs are skipped
-	if len(skip) >= count {
-		slog.Info("all detected amdgpus are skipped, falling back to CPU")
-		return
-	}
-	if len(skip) > 0 {
-		amdSetVisibleDevices(ids, skip)
+		// iGPU detection, remove this check once we can support an iGPU variant of the rocm library
+		if totalMemory < IGPUMemLimit {
+			slog.Info("amdgpu appears to be an iGPU, skipping", "gpu", i, "total", format.HumanBytes2(totalMemory))
+			continue
+		}
+
+		// TODO revisit this once ROCm v6 is available on windows.
+		// v5.7 only reports VRAM used by this process, so it's completely wrong and unusable
+		slog.Info("amdgpu memory", "gpu", i, "total", format.HumanBytes2(totalMemory))
+		slog.Info("amdgpu memory", "gpu", i, "available", format.HumanBytes2(freeMemory))
+		gpuInfo := GpuInfo{
+			Library: "rocm",
+			memInfo: memInfo{
+				TotalMemory: totalMemory,
+				FreeMemory:  freeMemory,
+			},
+			ID:             fmt.Sprintf("%d", i), // TODO this is probably wrong if we specify visible devices
+			DependencyPath: libDir,
+			MinimumMemory:  rocmMinimumMemory,
+		}
+		if major != "" {
+			gpuInfo.Major, err = strconv.Atoi(major)
+			if err != nil {
+				slog.Info("failed to parse version", "version", gfx, "error", err)
+			}
+		}
+		if minor != "" {
+			gpuInfo.Minor, err = strconv.Atoi(minor)
+			if err != nil {
+				slog.Info("failed to parse version", "version", gfx, "error", err)
+			}
+		}
+		if patch != "" {
+			gpuInfo.Patch, err = strconv.Atoi(patch)
+			if err != nil {
+				slog.Info("failed to parse version", "version", gfx, "error", err)
+			}
+		}
+		if gpuInfo.Major < RocmComputeMin {
+			slog.Warn(fmt.Sprintf("amdgpu [%s] too old gfx%d%d%d", gpuInfo.ID, gpuInfo.Major, gpuInfo.Minor, gpuInfo.Patch))
+			continue
+		}
+
+		resp = append(resp, gpuInfo)
 	}
-	UpdatePath(libDir)
+
+	return resp
 }
 
 func AMDValidateLibDir() (string, error) {
-	// On windows non-admins typically can't create links
-	// so instead of trying to rely on rpath and a link in
-	// $LibDir/rocm, we instead rely on setting PATH to point
-	// to the location of the ROCm library
-
-	// Installer payload location if we're running the installed binary
-	exe, err := os.Executable()
+	libDir, err := commonAMDValidateLibDir()
 	if err == nil {
-		rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
-		if rocmLibUsable(rocmTargetDir) {
-			slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
-			return rocmTargetDir, nil
-		}
+		return libDir, nil
 	}
 
 	// Installer payload (if we're running from some other location)
@@ -159,21 +180,6 @@ func AMDValidateLibDir() (string, error) {
 		return rocmTargetDir, nil
 	}
 
-	// Prefer explicit HIP env var
-	hipPath := os.Getenv("HIP_PATH")
-	if hipPath != "" {
-		hipLibDir := filepath.Join(hipPath, "bin")
-		if rocmLibUsable(hipLibDir) {
-			slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
-			return hipLibDir, nil
-		}
-	}
-
-	// Well known location(s)
-	if rocmLibUsable(RocmStandardLocation) {
-		return RocmStandardLocation, nil
-	}
-
 	// Should not happen on windows since we include it in the installer, but stand-alone binary might hit this
 	slog.Warn("amdgpu detected, but no compatible rocm library found.  Please install ROCm")
 	return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
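
On Windows the GFX version is sliced straight out of the GcnArchName string rather than read from sysfs, so a 6-character name yields a one-digit major version and a 7-character name a two-digit one. A standalone sketch of that slicing; splitGfx is a hypothetical helper:

package main

import "fmt"

// splitGfx mirrors the switch on len(gfx) in the Windows AMDGetGPUInfo loop;
// names outside the 6-7 character range are left unparsed, as in the diff.
func splitGfx(gfx string) (major, minor, patch string) {
	switch len(gfx) {
	case 6:
		major, minor, patch = gfx[3:4], gfx[4:5], gfx[5:]
	case 7:
		major, minor, patch = gfx[3:5], gfx[5:6], gfx[6:]
	}
	return
}

func main() {
	fmt.Println(splitGfx("gfx906"))  // 9 0 6
	fmt.Println(splitGfx("gfx1030")) // 10 3 0
}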

+ 2 - 2
gpu/assets.go

@@ -80,7 +80,7 @@ func cleanupTmpDirs() {
 		}
 		err = os.RemoveAll(d)
 		if err != nil {
-			slog.Debug(fmt.Sprintf("unable to cleanup stale tmpdir %s: %s", d, err))
+			slog.Debug("unable to cleanup stale tmpdir", "path", d, "error", err)
 		}
 	}
 }
@@ -120,7 +120,7 @@ func UpdatePath(dir string) {
 			}
 		}
 		newPath := strings.Join(append([]string{dir}, pathComponents...), ";")
-		slog.Info(fmt.Sprintf("Updating PATH to %s", newPath))
+		slog.Info("updating", "PATH", newPath)
 		os.Setenv("PATH", newPath)
 	}
 	// linux and darwin rely on rpath

+ 22 - 0
gpu/cuda_common.go

@@ -0,0 +1,22 @@
+//go:build linux || windows
+
+package gpu
+
+import (
+	"log/slog"
+	"strings"
+)
+
+func cudaGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
+	ids := []string{}
+	for _, info := range gpuInfo {
+		if info.Library != "cuda" {
+			// TODO shouldn't happen if things are wired correctly...
+			slog.Debug("cudaGetVisibleDevicesEnv skipping over non-cuda device", "library", info.Library)
+			continue
+		}
+		ids = append(ids, info.ID)
+	}
+	return "CUDA_VISIBLE_DEVICES", strings.Join(ids, ",")
+
+}

+ 82 - 145
gpu/gpu.go

@@ -16,7 +16,6 @@ import (
 	"os"
 	"path/filepath"
 	"runtime"
-	"strconv"
 	"strings"
 	"sync"
 	"unsafe"
@@ -25,8 +24,8 @@ import (
 )
 
 type handles struct {
-	nvml   *C.nvml_handle_t
-	cudart *C.cudart_handle_t
+	deviceCount int
+	cudart      *C.cudart_handle_t
 }
 
 const (
@@ -39,26 +38,10 @@ var gpuMutex sync.Mutex
 // With our current CUDA compile flags, older than 5.0 will not work properly
 var CudaComputeMin = [2]C.int{5, 0}
 
-// Possible locations for the nvidia-ml library
-var NvmlLinuxGlobs = []string{
-	"/usr/local/cuda/lib64/libnvidia-ml.so*",
-	"/usr/lib/x86_64-linux-gnu/nvidia/current/libnvidia-ml.so*",
-	"/usr/lib/x86_64-linux-gnu/libnvidia-ml.so*",
-	"/usr/lib/wsl/lib/libnvidia-ml.so*",
-	"/usr/lib/wsl/drivers/*/libnvidia-ml.so*",
-	"/opt/cuda/lib64/libnvidia-ml.so*",
-	"/usr/lib*/libnvidia-ml.so*",
-	"/usr/lib/aarch64-linux-gnu/nvidia/current/libnvidia-ml.so*",
-	"/usr/lib/aarch64-linux-gnu/libnvidia-ml.so*",
-	"/usr/local/lib*/libnvidia-ml.so*",
-
-	// TODO: are these stubs ever valid?
-	"/opt/cuda/targets/x86_64-linux/lib/stubs/libnvidia-ml.so*",
-}
+var RocmComputeMin = 9
 
-var NvmlWindowsGlobs = []string{
-	"c:\\Windows\\System32\\nvml.dll",
-}
+// TODO find a better way to detect iGPU instead of minimum memory
+const IGPUMemLimit = 1 * format.GibiByte // 512G is what they typically report, so anything less than 1G must be iGPU
 
 var CudartLinuxGlobs = []string{
 	"/usr/local/cuda/lib64/libcudart.so*",
@@ -88,26 +71,18 @@ func initGPUHandles() *handles {
 
 	// TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing
 
-	gpuHandles := &handles{nil, nil}
-	var nvmlMgmtName string
-	var nvmlMgmtPatterns []string
+	gpuHandles := &handles{}
 	var cudartMgmtName string
 	var cudartMgmtPatterns []string
 
 	tmpDir, _ := PayloadsDir()
 	switch runtime.GOOS {
 	case "windows":
-		nvmlMgmtName = "nvml.dll"
-		nvmlMgmtPatterns = make([]string, len(NvmlWindowsGlobs))
-		copy(nvmlMgmtPatterns, NvmlWindowsGlobs)
 		cudartMgmtName = "cudart64_*.dll"
 		localAppData := os.Getenv("LOCALAPPDATA")
 		cudartMgmtPatterns = []string{filepath.Join(localAppData, "Programs", "Ollama", cudartMgmtName)}
 		cudartMgmtPatterns = append(cudartMgmtPatterns, CudartWindowsGlobs...)
 	case "linux":
-		nvmlMgmtName = "libnvidia-ml.so"
-		nvmlMgmtPatterns = make([]string, len(NvmlLinuxGlobs))
-		copy(nvmlMgmtPatterns, NvmlLinuxGlobs)
 		cudartMgmtName = "libcudart.so*"
 		if tmpDir != "" {
 			// TODO - add "payloads" for subprocess
@@ -118,31 +93,21 @@ func initGPUHandles() *handles {
 		return gpuHandles
 	}
 
-	slog.Info("Detecting GPU type")
+	slog.Info("Detecting GPUs")
 	cudartLibPaths := FindGPULibs(cudartMgmtName, cudartMgmtPatterns)
 	if len(cudartLibPaths) > 0 {
-		cudart := LoadCUDARTMgmt(cudartLibPaths)
+		deviceCount, cudart, libPath := LoadCUDARTMgmt(cudartLibPaths)
 		if cudart != nil {
-			slog.Info("Nvidia GPU detected via cudart")
+			slog.Info("detected GPUs", "library", libPath, "count", deviceCount)
 			gpuHandles.cudart = cudart
-			return gpuHandles
-		}
-	}
-
-	// TODO once we build confidence, remove this and the gpu_info_nvml.[ch] files
-	nvmlLibPaths := FindGPULibs(nvmlMgmtName, nvmlMgmtPatterns)
-	if len(nvmlLibPaths) > 0 {
-		nvml := LoadNVMLMgmt(nvmlLibPaths)
-		if nvml != nil {
-			slog.Info("Nvidia GPU detected via nvidia-ml")
-			gpuHandles.nvml = nvml
+			gpuHandles.deviceCount = deviceCount
 			return gpuHandles
 		}
 	}
 	return gpuHandles
 }
 
-func GetGPUInfo() GpuInfo {
+func GetGPUInfo() GpuInfoList {
 	// TODO - consider exploring lspci (and equivalent on windows) to check for
 	// GPUs so we can report warnings if we see Nvidia/AMD but fail to load the libraries
 	gpuMutex.Lock()
@@ -150,9 +115,6 @@ func GetGPUInfo() GpuInfo {
 
 	gpuHandles := initGPUHandles()
 	defer func() {
-		if gpuHandles.nvml != nil {
-			C.nvml_release(*gpuHandles.nvml)
-		}
 		if gpuHandles.cudart != nil {
 			C.cudart_release(*gpuHandles.cudart)
 		}
@@ -165,72 +127,63 @@ func GetGPUInfo() GpuInfo {
 	}
 
 	var memInfo C.mem_info_t
-	resp := GpuInfo{}
-	if gpuHandles.nvml != nil && (cpuVariant != "" || runtime.GOARCH != "amd64") {
-		C.nvml_check_vram(*gpuHandles.nvml, &memInfo)
-		if memInfo.err != nil {
-			slog.Info(fmt.Sprintf("[nvidia-ml] error looking up NVML GPU memory: %s", C.GoString(memInfo.err)))
-			C.free(unsafe.Pointer(memInfo.err))
-		} else if memInfo.count > 0 {
-			// Verify minimum compute capability
-			var cc C.nvml_compute_capability_t
-			C.nvml_compute_capability(*gpuHandles.nvml, &cc)
-			if cc.err != nil {
-				slog.Info(fmt.Sprintf("[nvidia-ml] error looking up NVML GPU compute capability: %s", C.GoString(cc.err)))
-				C.free(unsafe.Pointer(cc.err))
-			} else if cc.major > CudaComputeMin[0] || (cc.major == CudaComputeMin[0] && cc.minor >= CudaComputeMin[1]) {
-				slog.Info(fmt.Sprintf("[nvidia-ml] NVML CUDA Compute Capability detected: %d.%d", cc.major, cc.minor))
-				resp.Library = "cuda"
-				resp.MinimumMemory = cudaMinimumMemory
-			} else {
-				slog.Info(fmt.Sprintf("[nvidia-ml] CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor))
-			}
+	resp := []GpuInfo{}
+
+	// NVIDIA first
+	for i := 0; i < gpuHandles.deviceCount; i++ {
+		// TODO once we support CPU compilation variants of GPU libraries refine this...
+		if cpuVariant == "" && runtime.GOARCH == "amd64" {
+			continue
 		}
-	} else if gpuHandles.cudart != nil && (cpuVariant != "" || runtime.GOARCH != "amd64") {
-		C.cudart_check_vram(*gpuHandles.cudart, &memInfo)
+		gpuInfo := GpuInfo{
+			Library: "cuda",
+		}
+		C.cudart_check_vram(*gpuHandles.cudart, C.int(i), &memInfo)
 		if memInfo.err != nil {
-			slog.Info(fmt.Sprintf("[cudart] error looking up CUDART GPU memory: %s", C.GoString(memInfo.err)))
+			slog.Info("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
 			C.free(unsafe.Pointer(memInfo.err))
-		} else if memInfo.count > 0 {
-			// Verify minimum compute capability
-			var cc C.cudart_compute_capability_t
-			C.cudart_compute_capability(*gpuHandles.cudart, &cc)
-			if cc.err != nil {
-				slog.Info(fmt.Sprintf("[cudart] error looking up CUDA compute capability: %s", C.GoString(cc.err)))
-				C.free(unsafe.Pointer(cc.err))
-			} else if cc.major > CudaComputeMin[0] || (cc.major == CudaComputeMin[0] && cc.minor >= CudaComputeMin[1]) {
-				slog.Info(fmt.Sprintf("[cudart] CUDART CUDA Compute Capability detected: %d.%d", cc.major, cc.minor))
-				resp.Library = "cuda"
-				resp.MinimumMemory = cudaMinimumMemory
-			} else {
-				slog.Info(fmt.Sprintf("[cudart] CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor))
-			}
+			continue
 		}
-	} else {
-		AMDGetGPUInfo(&resp)
-		if resp.Library != "" {
-			resp.MinimumMemory = rocmMinimumMemory
-			return resp
+		if memInfo.major < CudaComputeMin[0] || (memInfo.major == CudaComputeMin[0] && memInfo.minor < CudaComputeMin[1]) {
+			slog.Info(fmt.Sprintf("[%d] CUDA GPU is too old. Compute Capability detected: %d.%d", i, memInfo.major, memInfo.minor))
+			continue
 		}
+		gpuInfo.TotalMemory = uint64(memInfo.total)
+		gpuInfo.FreeMemory = uint64(memInfo.free)
+		gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
+		gpuInfo.Major = int(memInfo.major)
+		gpuInfo.Minor = int(memInfo.minor)
+		gpuInfo.MinimumMemory = cudaMinimumMemory
+
+		// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
+		resp = append(resp, gpuInfo)
 	}
-	if resp.Library == "" {
+
+	// Then AMD
+	resp = append(resp, AMDGetGPUInfo()...)
+
+	if len(resp) == 0 {
 		C.cpu_check_ram(&memInfo)
-		resp.Library = "cpu"
-		resp.Variant = cpuVariant
-	}
-	if memInfo.err != nil {
-		slog.Info(fmt.Sprintf("error looking up CPU memory: %s", C.GoString(memInfo.err)))
-		C.free(unsafe.Pointer(memInfo.err))
-		return resp
+		if memInfo.err != nil {
+			slog.Info("error looking up CPU memory", "error", C.GoString(memInfo.err))
+			C.free(unsafe.Pointer(memInfo.err))
+			return resp
+		}
+		gpuInfo := GpuInfo{
+			Library: "cpu",
+			Variant: cpuVariant,
+		}
+		gpuInfo.TotalMemory = uint64(memInfo.total)
+		gpuInfo.FreeMemory = uint64(memInfo.free)
+		gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
+
+		resp = append(resp, gpuInfo)
 	}
 
-	resp.DeviceCount = uint32(memInfo.count)
-	resp.FreeMemory = uint64(memInfo.free)
-	resp.TotalMemory = uint64(memInfo.total)
 	return resp
 }
 
-func getCPUMem() (memInfo, error) {
+func GetCPUMem() (memInfo, error) {
 	var ret memInfo
 	var info C.mem_info_t
 	C.cpu_check_ram(&info)
@@ -243,29 +196,11 @@ func getCPUMem() (memInfo, error) {
 	return ret, nil
 }
 
-func CheckVRAM() (uint64, error) {
-	userLimit := os.Getenv("OLLAMA_MAX_VRAM")
-	if userLimit != "" {
-		avail, err := strconv.ParseInt(userLimit, 10, 64)
-		if err != nil {
-			return 0, fmt.Errorf("Invalid OLLAMA_MAX_VRAM setting %s: %s", userLimit, err)
-		}
-		slog.Info(fmt.Sprintf("user override OLLAMA_MAX_VRAM=%d", avail))
-		return uint64(avail), nil
-	}
-	gpuInfo := GetGPUInfo()
-	if gpuInfo.FreeMemory > 0 && (gpuInfo.Library == "cuda" || gpuInfo.Library == "rocm") {
-		return gpuInfo.FreeMemory, nil
-	}
-
-	return 0, fmt.Errorf("no GPU detected") // TODO - better handling of CPU based memory determiniation
-}
-
 func FindGPULibs(baseLibName string, patterns []string) []string {
 	// Multiple GPU libraries may exist, and some may not work, so keep trying until we exhaust them
 	var ldPaths []string
 	gpuLibPaths := []string{}
-	slog.Info(fmt.Sprintf("Searching for GPU management library %s", baseLibName))
+	slog.Debug("Searching for GPU library", "name", baseLibName)
 
 	switch runtime.GOOS {
 	case "windows":
@@ -283,7 +218,7 @@ func FindGPULibs(baseLibName string, patterns []string) []string {
 		}
 		patterns = append(patterns, filepath.Join(d, baseLibName+"*"))
 	}
-	slog.Debug(fmt.Sprintf("gpu management search paths: %v", patterns))
+	slog.Debug("gpu library search", "globs", patterns)
 	for _, pattern := range patterns {
 		// Ignore glob discovery errors
 		matches, _ := filepath.Glob(pattern)
@@ -311,28 +246,11 @@ func FindGPULibs(baseLibName string, patterns []string) []string {
 			}
 		}
 	}
-	slog.Info(fmt.Sprintf("Discovered GPU libraries: %v", gpuLibPaths))
+	slog.Debug("discovered GPU libraries", "paths", gpuLibPaths)
 	return gpuLibPaths
 }
 
-func LoadNVMLMgmt(nvmlLibPaths []string) *C.nvml_handle_t {
-	var resp C.nvml_init_resp_t
-	resp.ch.verbose = getVerboseState()
-	for _, libPath := range nvmlLibPaths {
-		lib := C.CString(libPath)
-		defer C.free(unsafe.Pointer(lib))
-		C.nvml_init(lib, &resp)
-		if resp.err != nil {
-			slog.Info(fmt.Sprintf("Unable to load NVML management library %s: %s", libPath, C.GoString(resp.err)))
-			C.free(unsafe.Pointer(resp.err))
-		} else {
-			return &resp.ch
-		}
-	}
-	return nil
-}
-
-func LoadCUDARTMgmt(cudartLibPaths []string) *C.cudart_handle_t {
+func LoadCUDARTMgmt(cudartLibPaths []string) (int, *C.cudart_handle_t, string) {
 	var resp C.cudart_init_resp_t
 	resp.ch.verbose = getVerboseState()
 	for _, libPath := range cudartLibPaths {
@@ -340,13 +258,13 @@ func LoadCUDARTMgmt(cudartLibPaths []string) *C.cudart_handle_t {
 		defer C.free(unsafe.Pointer(lib))
 		C.cudart_init(lib, &resp)
 		if resp.err != nil {
-			slog.Info(fmt.Sprintf("Unable to load cudart CUDA management library %s: %s", libPath, C.GoString(resp.err)))
+			slog.Debug("Unable to load cudart", "library", libPath, "error", C.GoString(resp.err))
 			C.free(unsafe.Pointer(resp.err))
 		} else {
-			return &resp.ch
+			return int(resp.num_devices), &resp.ch, libPath
 		}
 	}
-	return nil
+	return 0, nil, ""
 }
 
 func getVerboseState() C.uint16_t {
@@ -355,3 +273,22 @@ func getVerboseState() C.uint16_t {
 	}
 	return C.uint16_t(0)
 }
+
+// Given the list of GPUs this instantiation is targeted for,
+// figure out the visible devices environment variable
+//
+// If different libraries are detected, the first one is what we use
+func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
+	if len(l) == 0 {
+		return "", ""
+	}
+	switch l[0].Library {
+	case "cuda":
+		return cudaGetVisibleDevicesEnv(l)
+	case "rocm":
+		return rocmGetVisibleDevicesEnv(l)
+	default:
+		slog.Debug("no filter required for library " + l[0].Library)
+		return "", ""
+	}
+}
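
GetGPUInfo now returns a GpuInfoList, and GetVisibleDevicesEnv picks the appropriate filter variable for whichever library the scheduled GPUs use. An illustrative sketch of wiring that into a spawned runner; the ./runner command is a placeholder:

package main

import (
	"fmt"
	"os"
	"os/exec"

	"github.com/ollama/ollama/gpu"
)

func main() {
	gpus := gpu.GetGPUInfo() // one GpuInfo per detected device

	// Returns e.g. ("CUDA_VISIBLE_DEVICES", "GPU-d110a105-...") for NVIDIA,
	// ("HIP_VISIBLE_DEVICES", "0,1") for ROCm, or ("", "") when no filter applies.
	key, val := gpus.GetVisibleDevicesEnv()

	cmd := exec.Command("./runner") // placeholder for the spawned runner binary
	cmd.Env = os.Environ()
	if key != "" {
		cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", key, val))
	}
	_ = cmd // cmd.Start() would launch the runner once fully configured
}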

+ 23 - 34
gpu/gpu_darwin.go

@@ -9,52 +9,41 @@ package gpu
 */
 import "C"
 import (
-	"fmt"
-	"log/slog"
-	"os"
 	"runtime"
-	"strconv"
 )
 
-// CheckVRAM returns the free VRAM in bytes on Linux machines with NVIDIA GPUs
-func CheckVRAM() (uint64, error) {
-	userLimit := os.Getenv("OLLAMA_MAX_VRAM")
-	if userLimit != "" {
-		avail, err := strconv.ParseInt(userLimit, 10, 64)
-		if err != nil {
-			return 0, fmt.Errorf("Invalid OLLAMA_MAX_VRAM setting %s: %s", userLimit, err)
-		}
-		slog.Info(fmt.Sprintf("user override OLLAMA_MAX_VRAM=%d", avail))
-		return uint64(avail), nil
-	}
-
-	if runtime.GOARCH == "amd64" {
-		// gpu not supported, this may not be metal
-		return 0, nil
-	}
-
-	return uint64(C.getRecommendedMaxVRAM()), nil
-}
-
-func GetGPUInfo() GpuInfo {
-	mem, _ := getCPUMem()
+func GetGPUInfo() GpuInfoList {
+	mem, _ := GetCPUMem()
 	if runtime.GOARCH == "amd64" {
-		return GpuInfo{
-			Library: "cpu",
-			Variant: GetCPUVariant(),
-			memInfo: mem,
+		return []GpuInfo{
+			{
+				Library: "cpu",
+				Variant: GetCPUVariant(),
+				memInfo: mem,
+			},
 		}
 	}
-	return GpuInfo{
+	info := GpuInfo{
 		Library: "metal",
-		memInfo: mem,
+		ID:      "0",
 	}
+	info.TotalMemory = uint64(C.getRecommendedMaxVRAM())
+
+	// TODO is there a way to gather actual allocated video memory? (currentAllocatedSize doesn't work)
+	info.FreeMemory = info.TotalMemory
+
+	info.MinimumMemory = 0
+	return []GpuInfo{info}
 }
 
-func getCPUMem() (memInfo, error) {
+func GetCPUMem() (memInfo, error) {
 	return memInfo{
 		TotalMemory: uint64(C.getPhysicalMemory()),
 		FreeMemory:  0,
-		DeviceCount: 1,
 	}, nil
 }
+
+func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
+	// No-op on darwin
+	return "", ""
+}

+ 8 - 4
gpu/gpu_info.h

@@ -38,12 +38,17 @@
 extern "C" {
 #endif
 
+#define GPU_ID_LEN 64
+
 typedef struct mem_info {
+  char *err;  // If non-nil, caller responsible for freeing
+  char gpu_id[GPU_ID_LEN];
   uint64_t total;
   uint64_t free;
-  unsigned int count;
-  int igpu_index; // If >= 0, we detected an integrated GPU to ignore
-  char *err;  // If non-nill, caller responsible for freeing
+
+  // Compute Capability
+  int major; 
+  int minor;
 } mem_info_t;
 
 void cpu_check_ram(mem_info_t *resp);
@@ -52,7 +57,6 @@ void cpu_check_ram(mem_info_t *resp);
 }
 #endif
 
-#include "gpu_info_nvml.h"
 #include "gpu_info_cudart.h"
 
 #endif  // __GPU_INFO_H__

+ 6 - 2
gpu/gpu_info_cpu.c

@@ -8,9 +8,11 @@ void cpu_check_ram(mem_info_t *resp) {
   MEMORYSTATUSEX info;
   info.dwLength = sizeof(info);
   if (GlobalMemoryStatusEx(&info) != 0) {
-    resp->count = 1;
     resp->total = info.ullTotalPhys;
     resp->free = info.ullAvailPhys;
+    resp->major = 0;
+    resp->minor = 0;
+    snprintf(&resp->gpu_id[0], GPU_ID_LEN, "0");
   } else {
     resp->err = LOAD_ERR();
   }
@@ -27,9 +29,11 @@ void cpu_check_ram(mem_info_t *resp) {
   if (sysinfo(&info) != 0) {
     resp->err = strdup(strerror(errno));
   } else {
-    resp->count = 1;
     resp->total = info.totalram * info.mem_unit;
     resp->free = info.freeram * info.mem_unit;
+    resp->major = 0;
+    resp->minor = 0;
+    snprintf(&resp->gpu_id[0], GPU_ID_LEN, "0");
   }
   return;
 }

+ 63 - 82
gpu/gpu_info_cudart.c

@@ -6,6 +6,7 @@
 void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
   cudartReturn_t ret;
   resp->err = NULL;
+  resp->num_devices = 0;
   const int buflen = 256;
   char buf[buflen + 1];
   int i;
@@ -21,6 +22,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
       {"cudaGetDeviceCount", (void *)&resp->ch.cudaGetDeviceCount},
       {"cudaDeviceGetAttribute", (void *)&resp->ch.cudaDeviceGetAttribute},
       {"cudaDriverGetVersion", (void *)&resp->ch.cudaDriverGetVersion},
+      {"cudaGetDeviceProperties", (void *)&resp->ch.cudaGetDeviceProperties},
       {NULL, NULL},
   };
 
@@ -36,13 +38,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
     return;
   }
 
-  // TODO once we've squashed the remaining corner cases remove this log
-  LOG(resp->ch.verbose, "wiring cudart library functions in %s\n", cudart_lib_path);
-  
   for (i = 0; l[i].s != NULL; i++) {
-    // TODO once we've squashed the remaining corner cases remove this log
-    LOG(resp->ch.verbose, "dlsym: %s\n", l[i].s);
-
     *l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
     if (!l[i].p) {
       char *msg = LOAD_ERR();
@@ -63,7 +59,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
     UNLOAD_LIBRARY(resp->ch.handle);
     resp->ch.handle = NULL;
     if (ret == CUDA_ERROR_INSUFFICIENT_DRIVER) {
-      resp->err = strdup("your nvidia driver is too old or missing, please upgrade to run ollama");
+      resp->err = strdup("your nvidia driver is too old or missing.  If you have a CUDA GPU please upgrade to run ollama");
       return;
     }
     snprintf(buf, buflen, "cudart init failure: %d", ret);
@@ -85,110 +81,95 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
     driverVersion.minor = (version - (driverVersion.major * 1000)) / 10;
     LOG(resp->ch.verbose, "CUDA driver version: %d-%d\n", driverVersion.major, driverVersion.minor);
   }
+
+  ret = (*resp->ch.cudaGetDeviceCount)(&resp->num_devices);
+  if (ret != CUDART_SUCCESS) {
+    LOG(resp->ch.verbose, "cudaGetDeviceCount err: %d\n", ret);
+    UNLOAD_LIBRARY(resp->ch.handle);
+    resp->ch.handle = NULL;
+    snprintf(buf, buflen, "unable to get device count: %d", ret);
+    resp->err = strdup(buf);
+    return;
+  }
 }
 
 
-void cudart_check_vram(cudart_handle_t h, mem_info_t *resp) {
+void cudart_check_vram(cudart_handle_t h, int i, mem_info_t *resp) {
   resp->err = NULL;
   cudartMemory_t memInfo = {0,0,0};
   cudartReturn_t ret;
   const int buflen = 256;
   char buf[buflen + 1];
-  int i;
 
   if (h.handle == NULL) {
     resp->err = strdup("cudart handle isn't initialized");
     return;
   }
 
-  // cudaGetDeviceCount takes int type, resp-> count is uint
-  int deviceCount;
-  ret = (*h.cudaGetDeviceCount)(&deviceCount);
+  ret = (*h.cudaSetDevice)(i);
   if (ret != CUDART_SUCCESS) {
-    snprintf(buf, buflen, "unable to get device count: %d", ret);
+    snprintf(buf, buflen, "cudart device failed to initialize");
     resp->err = strdup(buf);
     return;
-  } else {
-    resp->count = (unsigned int)deviceCount;
   }
 
-  resp->total = 0;
-  resp->free = 0;
-  for (i = 0; i < resp-> count; i++) {  
-    ret = (*h.cudaSetDevice)(i);
-    if (ret != CUDART_SUCCESS) {
-      snprintf(buf, buflen, "cudart device failed to initialize");
-      resp->err = strdup(buf);
-      return;
+  cudaDeviceProp_t props;
+  ret = (*h.cudaGetDeviceProperties)(&props, i);
+  if (ret != CUDART_SUCCESS) {
+    LOG(h.verbose, "[%d] device properties lookup failure: %d\n", i, ret);
+    snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i);
+    resp->major = 0;
+    resp->minor = 0;
+  } else {
+    int allNull = 1;
+    for (int j = 0; j < 16; j++) {
+      if (props.uuid.bytes[j] != 0) {
+        allNull = 0;
+        break;
+      }
     }
-    ret = (*h.cudaMemGetInfo)(&memInfo.free, &memInfo.total);
-    if (ret != CUDART_SUCCESS) {
-      snprintf(buf, buflen, "cudart device memory info lookup failure %d", ret);
-      resp->err = strdup(buf);
-      return;
+    if (allNull != 0) {
+      snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i);
+    } else {
+      // GPU-d110a105-ac29-1d54-7b49-9c90440f215b
+      snprintf(&resp->gpu_id[0], GPU_ID_LEN,
+          "GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
+          props.uuid.bytes[0],
+          props.uuid.bytes[1],
+          props.uuid.bytes[2],
+          props.uuid.bytes[3],
+          props.uuid.bytes[4],
+          props.uuid.bytes[5],
+          props.uuid.bytes[6],
+          props.uuid.bytes[7],
+          props.uuid.bytes[8],
+          props.uuid.bytes[9],
+          props.uuid.bytes[10],
+          props.uuid.bytes[11],
+          props.uuid.bytes[12],
+          props.uuid.bytes[13],
+          props.uuid.bytes[14],
+          props.uuid.bytes[15]
+        );
     }
+    resp->major = props.major;
+    resp->minor = props.minor;
 
-    LOG(h.verbose, "[%d] CUDA totalMem %lu\n", i, memInfo.total);
-    LOG(h.verbose, "[%d] CUDA freeMem %lu\n", i, memInfo.free);
-
-    resp->total += memInfo.total;
-    resp->free += memInfo.free;
-  }
-}
-
-void cudart_compute_capability(cudart_handle_t h, cudart_compute_capability_t *resp) {
-  resp->err = NULL;
-  resp->major = 0;
-  resp->minor = 0;
-  int major = 0;
-  int minor = 0;
-  cudartReturn_t ret;
-  const int buflen = 256;
-  char buf[buflen + 1];
-  int i;
-
-  if (h.handle == NULL) {
-    resp->err = strdup("cudart handle not initialized");
-    return;
+    // TODO add other useful properties from props
   }
-
-  int devices;
-  ret = (*h.cudaGetDeviceCount)(&devices);
+  ret = (*h.cudaMemGetInfo)(&memInfo.free, &memInfo.total);
   if (ret != CUDART_SUCCESS) {
-    snprintf(buf, buflen, "unable to get cudart device count: %d", ret);
+    snprintf(buf, buflen, "cudart device memory info lookup failure %d", ret);
     resp->err = strdup(buf);
     return;
   }
 
-  for (i = 0; i < devices; i++) {
-    ret = (*h.cudaSetDevice)(i);
-    if (ret != CUDART_SUCCESS) {
-      snprintf(buf, buflen, "cudart device failed to initialize");
-      resp->err = strdup(buf);
-      return;
-    }
+  resp->total = memInfo.total;
+  resp->free = memInfo.free;
 
-    ret = (*h.cudaDeviceGetAttribute)(&major, cudartDevAttrComputeCapabilityMajor, i);
-    if (ret != CUDART_SUCCESS) {
-      snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret);
-      resp->err = strdup(buf);
-      return;
-    }
-    ret = (*h.cudaDeviceGetAttribute)(&minor, cudartDevAttrComputeCapabilityMinor, i);
-    if (ret != CUDART_SUCCESS) {
-      snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret);
-      resp->err = strdup(buf);
-      return;
-    }
-      
-    // Report the lowest major.minor we detect as that limits our compatibility
-    if (resp->major == 0 || resp->major > major ) {
-      resp->major = major;
-      resp->minor = minor;
-    } else if ( resp->major == major && resp->minor > minor ) {
-      resp->minor = minor;
-    }
-  }
+  LOG(h.verbose, "[%s] CUDA totalMem %lu\n", resp->gpu_id, resp->total);
+  LOG(h.verbose, "[%s] CUDA freeMem %lu\n", resp->gpu_id, resp->free);
+  LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor);
 }
 
 void cudart_release(cudart_handle_t h) {

+ 96 - 10
gpu/gpu_info_cudart.h

@@ -6,7 +6,8 @@
 // Just enough typedef's to dlopen/dlsym for memory information
 typedef enum cudartReturn_enum {
   CUDART_SUCCESS = 0,
-  CUDART_UNSUPPORTED = 1,
+  CUDA_ERROR_INVALID_VALUE = 1,
+  CUDA_ERROR_MEMORY_ALLOCATION = 2,
   CUDA_ERROR_INSUFFICIENT_DRIVER = 35,
   // Other values omitted for now...
 } cudartReturn_t;
@@ -14,6 +15,11 @@ typedef enum cudartReturn_enum {
 typedef enum cudartDeviceAttr_enum {
   cudartDevAttrComputeCapabilityMajor = 75,
   cudartDevAttrComputeCapabilityMinor = 76,
+
+  // TODO - not yet wired up but may be useful for Jetson or other
+  // integrated GPU scenarios with shared memory
+  cudaDevAttrIntegrated = 18
+
 } cudartDeviceAttr_t;
 
 typedef void *cudartDevice_t;  // Opaque is sufficient
@@ -28,6 +34,92 @@ typedef struct cudartDriverVersion {
   int minor;
 } cudartDriverVersion_t;
 
+typedef struct cudaUUID {
+    unsigned char bytes[16];
+} cudaUUID_t;
+typedef struct cudaDeviceProp {
+    char         name[256];                  /**< ASCII string identifying device */
+    cudaUUID_t   uuid;                       /**< 16-byte unique identifier */
+    char         luid[8];                    /**< 8-byte locally unique identifier. Value is undefined on TCC and non-Windows platforms */
+    unsigned int luidDeviceNodeMask;         /**< LUID device node mask. Value is undefined on TCC and non-Windows platforms */
+    size_t       totalGlobalMem;             /**< Global memory available on device in bytes */
+    size_t       sharedMemPerBlock;          /**< Shared memory available per block in bytes */
+    int          regsPerBlock;               /**< 32-bit registers available per block */
+    int          warpSize;                   /**< Warp size in threads */
+    size_t       memPitch;                   /**< Maximum pitch in bytes allowed by memory copies */
+    int          maxThreadsPerBlock;         /**< Maximum number of threads per block */
+    int          maxThreadsDim[3];           /**< Maximum size of each dimension of a block */
+    int          maxGridSize[3];             /**< Maximum size of each dimension of a grid */
+    int          clockRate;                  /**< Clock frequency in kilohertz */
+    size_t       totalConstMem;              /**< Constant memory available on device in bytes */
+    int          major;                      /**< Major compute capability */
+    int          minor;                      /**< Minor compute capability */
+    size_t       textureAlignment;           /**< Alignment requirement for textures */
+    size_t       texturePitchAlignment;      /**< Pitch alignment requirement for texture references bound to pitched memory */
+    int          deviceOverlap;              /**< Device can concurrently copy memory and execute a kernel. Deprecated. Use instead asyncEngineCount. */
+    int          multiProcessorCount;        /**< Number of multiprocessors on device */
+    int          kernelExecTimeoutEnabled;   /**< Specified whether there is a run time limit on kernels */
+    int          integrated;                 /**< Device is integrated as opposed to discrete */
+    int          canMapHostMemory;           /**< Device can map host memory with cudaHostAlloc/cudaHostGetDevicePointer */
+    int          computeMode;                /**< Compute mode (See ::cudaComputeMode) */
+    int          maxTexture1D;               /**< Maximum 1D texture size */
+    int          maxTexture1DMipmap;         /**< Maximum 1D mipmapped texture size */
+    int          maxTexture1DLinear;         /**< Deprecated, do not use. Use cudaDeviceGetTexture1DLinearMaxWidth() or cuDeviceGetTexture1DLinearMaxWidth() instead. */
+    int          maxTexture2D[2];            /**< Maximum 2D texture dimensions */
+    int          maxTexture2DMipmap[2];      /**< Maximum 2D mipmapped texture dimensions */
+    int          maxTexture2DLinear[3];      /**< Maximum dimensions (width, height, pitch) for 2D textures bound to pitched memory */
+    int          maxTexture2DGather[2];      /**< Maximum 2D texture dimensions if texture gather operations have to be performed */
+    int          maxTexture3D[3];            /**< Maximum 3D texture dimensions */
+    int          maxTexture3DAlt[3];         /**< Maximum alternate 3D texture dimensions */
+    int          maxTextureCubemap;          /**< Maximum Cubemap texture dimensions */
+    int          maxTexture1DLayered[2];     /**< Maximum 1D layered texture dimensions */
+    int          maxTexture2DLayered[3];     /**< Maximum 2D layered texture dimensions */
+    int          maxTextureCubemapLayered[2];/**< Maximum Cubemap layered texture dimensions */
+    int          maxSurface1D;               /**< Maximum 1D surface size */
+    int          maxSurface2D[2];            /**< Maximum 2D surface dimensions */
+    int          maxSurface3D[3];            /**< Maximum 3D surface dimensions */
+    int          maxSurface1DLayered[2];     /**< Maximum 1D layered surface dimensions */
+    int          maxSurface2DLayered[3];     /**< Maximum 2D layered surface dimensions */
+    int          maxSurfaceCubemap;          /**< Maximum Cubemap surface dimensions */
+    int          maxSurfaceCubemapLayered[2];/**< Maximum Cubemap layered surface dimensions */
+    size_t       surfaceAlignment;           /**< Alignment requirements for surfaces */
+    int          concurrentKernels;          /**< Device can possibly execute multiple kernels concurrently */
+    int          ECCEnabled;                 /**< Device has ECC support enabled */
+    int          pciBusID;                   /**< PCI bus ID of the device */
+    int          pciDeviceID;                /**< PCI device ID of the device */
+    int          pciDomainID;                /**< PCI domain ID of the device */
+    int          tccDriver;                  /**< 1 if device is a Tesla device using TCC driver, 0 otherwise */
+    int          asyncEngineCount;           /**< Number of asynchronous engines */
+    int          unifiedAddressing;          /**< Device shares a unified address space with the host */
+    int          memoryClockRate;            /**< Peak memory clock frequency in kilohertz */
+    int          memoryBusWidth;             /**< Global memory bus width in bits */
+    int          l2CacheSize;                /**< Size of L2 cache in bytes */
+    int          persistingL2CacheMaxSize;   /**< Device's maximum l2 persisting lines capacity setting in bytes */
+    int          maxThreadsPerMultiProcessor;/**< Maximum resident threads per multiprocessor */
+    int          streamPrioritiesSupported;  /**< Device supports stream priorities */
+    int          globalL1CacheSupported;     /**< Device supports caching globals in L1 */
+    int          localL1CacheSupported;      /**< Device supports caching locals in L1 */
+    size_t       sharedMemPerMultiprocessor; /**< Shared memory available per multiprocessor in bytes */
+    int          regsPerMultiprocessor;      /**< 32-bit registers available per multiprocessor */
+    int          managedMemory;              /**< Device supports allocating managed memory on this system */
+    int          isMultiGpuBoard;            /**< Device is on a multi-GPU board */
+    int          multiGpuBoardGroupID;       /**< Unique identifier for a group of devices on the same multi-GPU board */
+    int          hostNativeAtomicSupported;  /**< Link between the device and the host supports native atomic operations */
+    int          singleToDoublePrecisionPerfRatio; /**< Ratio of single precision performance (in floating-point operations per second) to double precision performance */
+    int          pageableMemoryAccess;       /**< Device supports coherently accessing pageable memory without calling cudaHostRegister on it */
+    int          concurrentManagedAccess;    /**< Device can coherently access managed memory concurrently with the CPU */
+    int          computePreemptionSupported; /**< Device supports Compute Preemption */
+    int          canUseHostPointerForRegisteredMem; /**< Device can access host registered memory at the same virtual address as the CPU */
+    int          cooperativeLaunch;          /**< Device supports launching cooperative kernels via ::cudaLaunchCooperativeKernel */
+    int          cooperativeMultiDeviceLaunch; /**< Deprecated, cudaLaunchCooperativeKernelMultiDevice is deprecated. */
+    size_t       sharedMemPerBlockOptin;     /**< Per device maximum shared memory per block usable by special opt in */
+    int          pageableMemoryAccessUsesHostPageTables; /**< Device accesses pageable memory via the host's page tables */
+    int          directManagedMemAccessFromHost; /**< Host can directly access managed memory on the device without migration. */
+    int          maxBlocksPerMultiProcessor; /**< Maximum number of resident blocks per multiprocessor */
+    int          accessPolicyMaxWindowSize;  /**< The maximum value of ::cudaAccessPolicyWindow::num_bytes. */
+    size_t       reservedSharedMemPerBlock;  /**< Shared memory reserved by CUDA driver per block in bytes */
+  } cudaDeviceProp_t;
+
 typedef struct cudart_handle {
   void *handle;
   uint16_t verbose;
@@ -38,23 +130,17 @@ typedef struct cudart_handle {
   cudartReturn_t (*cudaGetDeviceCount)(int *);
   cudartReturn_t (*cudaDeviceGetAttribute)(int* value, cudartDeviceAttr_t attr, int device);
   cudartReturn_t (*cudaDriverGetVersion) (int *driverVersion);
+  cudartReturn_t (*cudaGetDeviceProperties) (cudaDeviceProp_t* prop, int device);
 } cudart_handle_t;
 
 typedef struct cudart_init_resp {
   char *err;  // If err is non-null handle is invalid
   cudart_handle_t ch;
+  int num_devices;
 } cudart_init_resp_t;
 
-typedef struct cudart_compute_capability {
-  char *err;
-  int major;
-  int minor;
-} cudart_compute_capability_t;
-
-
 void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp);
-void cudart_check_vram(cudart_handle_t ch, mem_info_t *resp);
-void cudart_compute_capability(cudart_handle_t th, cudart_compute_capability_t *cc);
+void cudart_check_vram(cudart_handle_t ch, int device_id, mem_info_t *resp);
 void cudart_release(cudart_handle_t ch);
 
 #endif  // __GPU_INFO_CUDART_H__

+ 0 - 221
gpu/gpu_info_nvml.c

@@ -1,221 +0,0 @@
-#ifndef __APPLE__  // TODO - maybe consider nvidia support on intel macs?
-
-#include <string.h>
-
-#include "gpu_info_nvml.h"
-
-void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp) {
-  nvmlReturn_t ret;
-  resp->err = NULL;
-  const int buflen = 256;
-  char buf[buflen + 1];
-  int i;
-
-  struct lookup {
-    char *s;
-    void **p;
-  } l[] = {
-      {"nvmlInit_v2", (void *)&resp->ch.nvmlInit_v2},
-      {"nvmlShutdown", (void *)&resp->ch.nvmlShutdown},
-      {"nvmlDeviceGetHandleByIndex", (void *)&resp->ch.nvmlDeviceGetHandleByIndex},
-      {"nvmlDeviceGetMemoryInfo", (void *)&resp->ch.nvmlDeviceGetMemoryInfo},
-      {"nvmlDeviceGetCount_v2", (void *)&resp->ch.nvmlDeviceGetCount_v2},
-      {"nvmlDeviceGetCudaComputeCapability", (void *)&resp->ch.nvmlDeviceGetCudaComputeCapability},
-      {"nvmlSystemGetDriverVersion", (void *)&resp->ch.nvmlSystemGetDriverVersion},
-      {"nvmlDeviceGetName", (void *)&resp->ch.nvmlDeviceGetName},
-      {"nvmlDeviceGetSerial", (void *)&resp->ch.nvmlDeviceGetSerial},
-      {"nvmlDeviceGetVbiosVersion", (void *)&resp->ch.nvmlDeviceGetVbiosVersion},
-      {"nvmlDeviceGetBoardPartNumber", (void *)&resp->ch.nvmlDeviceGetBoardPartNumber},
-      {"nvmlDeviceGetBrand", (void *)&resp->ch.nvmlDeviceGetBrand},
-      {NULL, NULL},
-  };
-
-  resp->ch.handle = LOAD_LIBRARY(nvml_lib_path, RTLD_LAZY);
-  if (!resp->ch.handle) {
-    char *msg = LOAD_ERR();
-    LOG(resp->ch.verbose, "library %s load err: %s\n", nvml_lib_path, msg);
-    snprintf(buf, buflen,
-             "Unable to load %s library to query for Nvidia GPUs: %s",
-             nvml_lib_path, msg);
-    free(msg);
-    resp->err = strdup(buf);
-    return;
-  }
-
-  // TODO once we've squashed the remaining corner cases remove this log
-  LOG(resp->ch.verbose, "wiring nvidia management library functions in %s\n", nvml_lib_path);
-  
-  for (i = 0; l[i].s != NULL; i++) {
-    // TODO once we've squashed the remaining corner cases remove this log
-    LOG(resp->ch.verbose, "dlsym: %s\n", l[i].s);
-
-    *l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
-    if (!l[i].p) {
-      resp->ch.handle = NULL;
-      char *msg = LOAD_ERR();
-      LOG(resp->ch.verbose, "dlerr: %s\n", msg);
-      UNLOAD_LIBRARY(resp->ch.handle);
-      snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s,
-               msg);
-      free(msg);
-      resp->err = strdup(buf);
-      return;
-    }
-  }
-
-  ret = (*resp->ch.nvmlInit_v2)();
-  if (ret != NVML_SUCCESS) {
-    LOG(resp->ch.verbose, "nvmlInit_v2 err: %d\n", ret);
-    UNLOAD_LIBRARY(resp->ch.handle);
-    resp->ch.handle = NULL;
-    snprintf(buf, buflen, "nvml vram init failure: %d", ret);
-    resp->err = strdup(buf);
-    return;
-  }
-
-  // Report driver version if we're in verbose mode, ignore errors
-  ret = (*resp->ch.nvmlSystemGetDriverVersion)(buf, buflen);
-  if (ret != NVML_SUCCESS) {
-    LOG(resp->ch.verbose, "nvmlSystemGetDriverVersion failed: %d\n", ret);
-  } else {
-    LOG(resp->ch.verbose, "CUDA driver version: %s\n", buf);
-  }
-}
-
-void nvml_check_vram(nvml_handle_t h, mem_info_t *resp) {
-  resp->err = NULL;
-  nvmlDevice_t device;
-  nvmlMemory_t memInfo = {0};
-  nvmlReturn_t ret;
-  const int buflen = 256;
-  char buf[buflen + 1];
-  int i;
-
-  if (h.handle == NULL) {
-    resp->err = strdup("nvml handle isn't initialized");
-    return;
-  }
-
-  ret = (*h.nvmlDeviceGetCount_v2)(&resp->count);
-  if (ret != NVML_SUCCESS) {
-    snprintf(buf, buflen, "unable to get device count: %d", ret);
-    resp->err = strdup(buf);
-    return;
-  }
-
-  resp->total = 0;
-  resp->free = 0;
-  for (i = 0; i < resp->count; i++) {
-    ret = (*h.nvmlDeviceGetHandleByIndex)(i, &device);
-    if (ret != NVML_SUCCESS) {
-      snprintf(buf, buflen, "unable to get device handle %d: %d", i, ret);
-      resp->err = strdup(buf);
-      return;
-    }
-
-    ret = (*h.nvmlDeviceGetMemoryInfo)(device, &memInfo);
-    if (ret != NVML_SUCCESS) {
-      snprintf(buf, buflen, "device memory info lookup failure %d: %d", i, ret);
-      resp->err = strdup(buf);
-      return;
-    }
-    if (h.verbose) {
-      nvmlBrandType_t brand = 0;
-      // When in verbose mode, report more information about
-      // the card we discover, but don't fail on error
-      ret = (*h.nvmlDeviceGetName)(device, buf, buflen);
-      if (ret != NVML_SUCCESS) {
-        LOG(h.verbose, "nvmlDeviceGetName failed: %d\n", ret);
-      } else {
-        LOG(h.verbose, "[%d] CUDA device name: %s\n", i, buf);
-      }
-      ret = (*h.nvmlDeviceGetBoardPartNumber)(device, buf, buflen);
-      if (ret != NVML_SUCCESS) {
-        LOG(h.verbose, "nvmlDeviceGetBoardPartNumber failed: %d\n", ret);
-      } else {
-        LOG(h.verbose, "[%d] CUDA part number: %s\n", i, buf);
-      }
-      ret = (*h.nvmlDeviceGetSerial)(device, buf, buflen);
-      if (ret != NVML_SUCCESS) {
-        LOG(h.verbose, "nvmlDeviceGetSerial failed: %d\n", ret);
-      } else {
-        LOG(h.verbose, "[%d] CUDA S/N: %s\n", i, buf);
-      }
-      ret = (*h.nvmlDeviceGetVbiosVersion)(device, buf, buflen);
-      if (ret != NVML_SUCCESS) {
-        LOG(h.verbose, "nvmlDeviceGetVbiosVersion failed: %d\n", ret);
-      } else {
-        LOG(h.verbose, "[%d] CUDA vbios version: %s\n", i, buf);
-      }
-      ret = (*h.nvmlDeviceGetBrand)(device, &brand);
-      if (ret != NVML_SUCCESS) {
-        LOG(h.verbose, "nvmlDeviceGetBrand failed: %d\n", ret);
-      } else {
-        LOG(h.verbose, "[%d] CUDA brand: %d\n", i, brand);
-      }
-    }
-
-    LOG(h.verbose, "[%d] CUDA totalMem %ld\n", i, memInfo.total);
-    LOG(h.verbose, "[%d] CUDA freeMem %ld\n", i, memInfo.free);
-
-    resp->total += memInfo.total;
-    resp->free += memInfo.free;
-  }
-}
-
-void nvml_compute_capability(nvml_handle_t h, nvml_compute_capability_t *resp) {
-  resp->err = NULL;
-  resp->major = 0;
-  resp->minor = 0;
-  nvmlDevice_t device;
-  int major = 0;
-  int minor = 0;
-  nvmlReturn_t ret;
-  const int buflen = 256;
-  char buf[buflen + 1];
-  int i;
-
-  if (h.handle == NULL) {
-    resp->err = strdup("nvml handle not initialized");
-    return;
-  }
-
-  unsigned int devices;
-  ret = (*h.nvmlDeviceGetCount_v2)(&devices);
-  if (ret != NVML_SUCCESS) {
-    snprintf(buf, buflen, "unable to get device count: %d", ret);
-    resp->err = strdup(buf);
-    return;
-  }
-
-  for (i = 0; i < devices; i++) {
-    ret = (*h.nvmlDeviceGetHandleByIndex)(i, &device);
-    if (ret != NVML_SUCCESS) {
-      snprintf(buf, buflen, "unable to get device handle %d: %d", i, ret);
-      resp->err = strdup(buf);
-      return;
-    }
-
-    ret = (*h.nvmlDeviceGetCudaComputeCapability)(device, &major, &minor);
-    if (ret != NVML_SUCCESS) {
-      snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret);
-      resp->err = strdup(buf);
-      return;
-    }
-    // Report the lowest major.minor we detect as that limits our compatibility
-    if (resp->major == 0 || resp->major > major ) {
-      resp->major = major;
-      resp->minor = minor;
-    } else if ( resp->major == major && resp->minor > minor ) {
-      resp->minor = minor;
-    }
-  }
-}
-
-void nvml_release(nvml_handle_t h) {
-  LOG(h.verbose, "releasing nvml library\n");
-  UNLOAD_LIBRARY(h.handle);
-  h.handle = NULL;
-}
-
-#endif  // __APPLE__

+ 0 - 57
gpu/gpu_info_nvml.h

@@ -1,57 +0,0 @@
-#ifndef __APPLE__
-#ifndef __GPU_INFO_NVML_H__
-#define __GPU_INFO_NVML_H__
-#include "gpu_info.h"
-
-// Just enough typedef's to dlopen/dlsym for memory information
-typedef enum nvmlReturn_enum {
-  NVML_SUCCESS = 0,
-  // Other values omitted for now...
-} nvmlReturn_t;
-typedef void *nvmlDevice_t;  // Opaque is sufficient
-typedef struct nvmlMemory_st {
-  unsigned long long total;
-  unsigned long long free;
-  unsigned long long used;
-} nvmlMemory_t;
-
-typedef enum nvmlBrandType_enum
-{
-    NVML_BRAND_UNKNOWN          = 0,
-} nvmlBrandType_t;
-
-typedef struct nvml_handle {
-  void *handle;
-  uint16_t verbose;
-  nvmlReturn_t (*nvmlInit_v2)(void);
-  nvmlReturn_t (*nvmlShutdown)(void);
-  nvmlReturn_t (*nvmlDeviceGetHandleByIndex)(unsigned int, nvmlDevice_t *);
-  nvmlReturn_t (*nvmlDeviceGetMemoryInfo)(nvmlDevice_t, nvmlMemory_t *);
-  nvmlReturn_t (*nvmlDeviceGetCount_v2)(unsigned int *);
-  nvmlReturn_t (*nvmlDeviceGetCudaComputeCapability)(nvmlDevice_t, int* major, int* minor);
-  nvmlReturn_t (*nvmlSystemGetDriverVersion) (char* version, unsigned int  length);
-  nvmlReturn_t (*nvmlDeviceGetName) (nvmlDevice_t device, char* name, unsigned int  length);
-  nvmlReturn_t (*nvmlDeviceGetSerial) (nvmlDevice_t device, char* serial, unsigned int  length);
-  nvmlReturn_t (*nvmlDeviceGetVbiosVersion) (nvmlDevice_t device, char* version, unsigned int  length);
-  nvmlReturn_t (*nvmlDeviceGetBoardPartNumber) (nvmlDevice_t device, char* partNumber, unsigned int  length);
-  nvmlReturn_t (*nvmlDeviceGetBrand) (nvmlDevice_t device, nvmlBrandType_t* type);
-} nvml_handle_t;
-
-typedef struct nvml_init_resp {
-  char *err;  // If err is non-null handle is invalid
-  nvml_handle_t ch;
-} nvml_init_resp_t;
-
-typedef struct nvml_compute_capability {
-  char *err;
-  int major;
-  int minor;
-} nvml_compute_capability_t;
-
-void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp);
-void nvml_check_vram(nvml_handle_t ch, mem_info_t *resp);
-void nvml_compute_capability(nvml_handle_t ch, nvml_compute_capability_t *cc);
-void nvml_release(nvml_handle_t ch);
-
-#endif  // __GPU_INFO_NVML_H__
-#endif  // __APPLE__

+ 6 - 13
gpu/gpu_test.go

@@ -9,23 +9,16 @@ import (
 
 func TestBasicGetGPUInfo(t *testing.T) {
 	info := GetGPUInfo()
-	assert.Contains(t, "cuda rocm cpu metal", info.Library)
-
-	switch runtime.GOOS {
-	case "darwin":
-		// TODO - remove this once MacOS returns some size for CPU
-		return
-	case "linux", "windows":
-		assert.Greater(t, info.TotalMemory, uint64(0))
-		assert.Greater(t, info.FreeMemory, uint64(0))
-		assert.Greater(t, info.DeviceCount, uint32(0))
-	default:
-		return
+	assert.Greater(t, len(info), 0)
+	assert.Contains(t, "cuda rocm cpu metal", info[0].Library)
+	if info[0].Library != "cpu" {
+		assert.Greater(t, info[0].TotalMemory, uint64(0))
+		assert.Greater(t, info[0].FreeMemory, uint64(0))
 	}
 }
 
 func TestCPUMemInfo(t *testing.T) {
-	info, err := getCPUMem()
+	info, err := GetCPUMem()
 	assert.NoError(t, err)
 	switch runtime.GOOS {
 	case "darwin":

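With GetGPUInfo now returning a slice of per-device entries rather than one aggregated struct, callers iterate the list. A minimal usage sketch, assuming the exported gpu.GetGPUInfo call and the GpuInfo fields shown elsewhere in this diff:

```go
package main

import (
	"fmt"

	"github.com/ollama/ollama/gpu"
)

func main() {
	// One entry per discovered device, or a single "cpu" entry as a fallback.
	for _, info := range gpu.GetGPUInfo() {
		fmt.Printf("%s %s: %d MiB free of %d MiB\n",
			info.Library, info.ID,
			info.FreeMemory/1024/1024, info.TotalMemory/1024/1024)
	}
}
```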
+ 43 - 6
gpu/types.go

@@ -3,7 +3,6 @@ package gpu
 type memInfo struct {
 	TotalMemory uint64 `json:"total_memory,omitempty"`
 	FreeMemory  uint64 `json:"free_memory,omitempty"`
-	DeviceCount uint32 `json:"device_count,omitempty"`
 }
 
 // Beginning of an `ollama info` command
@@ -17,11 +16,49 @@ type GpuInfo struct {
 	// MinimumMemory represents the minimum memory required to use the GPU
 	MinimumMemory uint64 `json:"-"`
 
-	// TODO add other useful attributes about the card here for discovery information
+	// Any extra PATH/LD_LIBRARY_PATH dependencies required for the Library to operate properly
+	DependencyPath string `json:"lib_path,omitempty"`
+
+	// GPU information
+	ID    string `json:"gpu_id"`          // string to use for selection of this specific GPU
+	Name  string `json:"name"`            // user friendly name if available
+	Major int    `json:"major,omitempty"` // Major compatibility version (CC or gfx)
+	Minor int    `json:"minor,omitempty"` // Minor compatibility version (CC or gfx)
+	Patch int    `json:"patch,omitempty"` // Patch compatibility only matters on AMD
+
+	// TODO other performance capability info to help in scheduling decisions
 }
 
-type Version struct {
-	Major uint
-	Minor uint
-	Patch uint
+type GpuInfoList []GpuInfo
+
+// Split up the set of gpu infos by Library and variant
+func (l GpuInfoList) ByLibrary() []GpuInfoList {
+	resp := []GpuInfoList{}
+	libs := []string{}
+	for _, info := range l {
+		found := false
+		requested := info.Library
+		if info.Variant != "" {
+			requested += "_" + info.Variant
+		}
+		for i, lib := range libs {
+			if lib == requested {
+				resp[i] = append(resp[i], info)
+				found = true
+				break
+			}
+		}
+		if !found {
+			libs = append(libs, info.Library)
+			resp = append(resp, []GpuInfo{info})
+		}
+	}
+	return resp
 }
+
+// Sort by Free Space
+type ByFreeMemory []GpuInfo
+
+func (a ByFreeMemory) Len() int           { return len(a) }
+func (a ByFreeMemory) Swap(i, j int)      { a[i], a[j] = a[j], a[i] }
+func (a ByFreeMemory) Less(i, j int) bool { return a[i].FreeMemory < a[j].FreeMemory }

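The new GpuInfoList helpers are meant to compose: ByLibrary groups devices by backend library (and variant), while ByFreeMemory provides an ascending sort over free VRAM. A short sketch of how a scheduler might combine them; schedulePreference is a hypothetical helper written as if inside the gpu package and is not part of this change:

```go
package gpu

import "sort"

// schedulePreference is an illustrative sketch (not part of this change):
// group discovered GPUs by library, then order each group so the device
// with the most free VRAM is considered first.
func schedulePreference(all GpuInfoList) []GpuInfoList {
	groups := all.ByLibrary()
	for _, g := range groups {
		// ByFreeMemory sorts ascending, so reverse it to put the freest first.
		sort.Sort(sort.Reverse(ByFreeMemory(g)))
	}
	return groups
}
```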
+ 1 - 2
integration/basic_test.go

@@ -4,7 +4,6 @@ package integration
 
 import (
 	"context"
-	"net/http"
 	"testing"
 	"time"
 
@@ -24,5 +23,5 @@ func TestOrcaMiniBlueSky(t *testing.T) {
 			"seed":        123,
 		},
 	}
-	GenerateTestHelper(ctx, t, &http.Client{}, req, []string{"rayleigh", "scattering"})
+	GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"})
 }

+ 225 - 0
integration/concurrency_test.go

@@ -0,0 +1,225 @@
+//go:build integration
+
+package integration
+
+import (
+	"context"
+	"log/slog"
+	"os"
+	"strconv"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/ollama/ollama/api"
+	"github.com/stretchr/testify/require"
+)
+
+func TestMultiModelConcurrency(t *testing.T) {
+	var (
+		req = [2]api.GenerateRequest{
+			{
+				Model:  "orca-mini",
+				Prompt: "why is the ocean blue?",
+				Stream: &stream,
+				Options: map[string]interface{}{
+					"seed":        42,
+					"temperature": 0.0,
+				},
+			}, {
+				Model:  "tinydolphin",
+				Prompt: "what is the origin of the us thanksgiving holiday?",
+				Stream: &stream,
+				Options: map[string]interface{}{
+					"seed":        42,
+					"temperature": 0.0,
+				},
+			},
+		}
+		resp = [2][]string{
+			[]string{"sunlight"},
+			[]string{"england", "english", "massachusetts", "pilgrims"},
+		}
+	)
+	var wg sync.WaitGroup
+	wg.Add(len(req))
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
+	defer cancel()
+	for i := 0; i < len(req); i++ {
+		go func(i int) {
+			defer wg.Done()
+			GenerateTestHelper(ctx, t, req[i], resp[i])
+		}(i)
+	}
+	wg.Wait()
+}
+
+func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) // GTX 750 2G card takes ~9 minutes
+	defer cancel()
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
+
+	req, resp := GenerateRequests()
+	// Get the server running (if applicable) and warm the model up with a single initial request
+	DoGenerate(ctx, t, client, req[0], resp[0], 60*time.Second, 5*time.Second)
+
+	var wg sync.WaitGroup
+	wg.Add(len(req))
+	for i := 0; i < len(req); i++ {
+		go func(i int) {
+			defer wg.Done()
+			for j := 0; j < 5; j++ {
+				slog.Info("Starting", "req", i, "iter", j)
+				// On slower GPUs it can take a while to process the concurrent requests
+				// so we allow a much longer initial timeout
+				DoGenerate(ctx, t, client, req[i], resp[i], 90*time.Second, 5*time.Second)
+			}
+		}(i)
+	}
+	wg.Wait()
+}
+
+// Stress the system if we know how much VRAM it has, and attempt to load more models than will fit
+func TestMultiModelStress(t *testing.T) {
+	vram := os.Getenv("OLLAMA_MAX_VRAM")
+	if vram == "" {
+		t.Skip("OLLAMA_MAX_VRAM not specified, can't pick the right models for the stress test")
+	}
+	max, err := strconv.ParseUint(vram, 10, 64)
+	require.NoError(t, err)
+	const MB = uint64(1024 * 1024)
+	type model struct {
+		name string
+		size uint64 // Approximate amount of VRAM they typically use when fully loaded in VRAM
+	}
+
+	smallModels := []model{
+		{
+			name: "orca-mini",
+			size: 2992 * MB,
+		},
+		{
+			name: "phi",
+			size: 2616 * MB,
+		},
+		{
+			name: "gemma:2b",
+			size: 2364 * MB,
+		},
+		{
+			name: "stable-code:3b",
+			size: 2608 * MB,
+		},
+		{
+			name: "starcoder2:3b",
+			size: 2166 * MB,
+		},
+	}
+	mediumModels := []model{
+		{
+			name: "llama2",
+			size: 5118 * MB,
+		},
+		{
+			name: "mistral",
+			size: 4620 * MB,
+		},
+		{
+			name: "orca-mini:7b",
+			size: 5118 * MB,
+		},
+		{
+			name: "dolphin-mistral",
+			size: 4620 * MB,
+		},
+		{
+			name: "gemma:7b",
+			size: 5000 * MB,
+		},
+		// TODO - uncomment this once #3565 is merged and this is rebased on it
+		// {
+		// 	name: "codellama:7b",
+		// 	size: 5118 * MB,
+		// },
+	}
+
+	// These seem to be too slow to be useful...
+	// largeModels := []model{
+	// 	{
+	// 		name: "llama2:13b",
+	// 		size: 7400 * MB,
+	// 	},
+	// 	{
+	// 		name: "codellama:13b",
+	// 		size: 7400 * MB,
+	// 	},
+	// 	{
+	// 		name: "orca-mini:13b",
+	// 		size: 7400 * MB,
+	// 	},
+	// 	{
+	// 		name: "gemma:7b",
+	// 		size: 5000 * MB,
+	// 	},
+	// 	{
+	// 		name: "starcoder2:15b",
+	// 		size: 9100 * MB,
+	// 	},
+	// }
+
+	var chosenModels []model
+	switch {
+	case max < 10000*MB:
+		slog.Info("selecting small models")
+		chosenModels = smallModels
+	// case max < 30000*MB:
+	default:
+		slog.Info("selecting medium models")
+		chosenModels = mediumModels
+		// default:
+		// 	slog.Info("selecting large models")
+		// 	chosenModels = largeModels
+	}
+
+	req, resp := GenerateRequests()
+
+	for i := range req {
+		if i > len(chosenModels) {
+			break
+		}
+		req[i].Model = chosenModels[i].name
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) // TODO baseline -- 10m too short
+	defer cancel()
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
+
+	// Make sure all the models are pulled before we get started
+	for _, r := range req {
+		require.NoError(t, PullIfMissing(ctx, client, r.Model))
+	}
+
+	var wg sync.WaitGroup
+	consumed := uint64(256 * MB) // Assume some baseline usage
+	for i := 0; i < len(req); i++ {
+		// Always get at least 2 models, but don't overshoot VRAM too much or we'll take too long
+		if i > 1 && consumed > max {
+			slog.Info("achieved target vram exhaustion", "count", i, "vramMB", max/1024/1024, "modelsMB", consumed/1024/1024)
+			break
+		}
+		consumed += chosenModels[i].size
+		slog.Info("target vram", "count", i, "vramMB", max/1024/1024, "modelsMB", consumed/1024/1024)
+
+		wg.Add(1)
+		go func(i int) {
+			defer wg.Done()
+			for j := 0; j < 3; j++ {
+				slog.Info("Starting", "req", i, "iter", j, "model", req[i].Model)
+				DoGenerate(ctx, t, client, req[i], resp[i], 90*time.Second, 5*time.Second)
+			}
+		}(i)
+	}
+	wg.Wait()
+}

+ 1 - 2
integration/context_test.go

@@ -4,7 +4,6 @@ package integration
 
 import (
 	"context"
-	"net/http"
 	"testing"
 	"time"
 
@@ -25,5 +24,5 @@ func TestContextExhaustion(t *testing.T) {
 			"num_ctx":     128,
 		},
 	}
-	GenerateTestHelper(ctx, t, &http.Client{}, req, []string{"once", "upon", "lived"})
+	GenerateTestHelper(ctx, t, req, []string{"once", "upon", "lived"})
 }

+ 3 - 3
integration/llm_image_test.go

@@ -5,7 +5,6 @@ package integration
 import (
 	"context"
 	"encoding/base64"
-	"net/http"
 	"testing"
 	"time"
 
@@ -29,10 +28,11 @@ func TestIntegrationMultimodal(t *testing.T) {
 		},
 	}
 
-	resp := "the ollamas"
+	// Note: sometimes it returns "the ollamas", sometimes "the ollams"
+	resp := "the ollam"
 	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
 	defer cancel()
-	GenerateTestHelper(ctx, t, &http.Client{}, req, []string{resp})
+	GenerateTestHelper(ctx, t, req, []string{resp})
 }
 
 const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb

+ 1 - 23
integration/llm_test.go

@@ -4,8 +4,6 @@ package integration
 
 import (
 	"context"
-	"net/http"
-	"sync"
 	"testing"
 	"time"
 
@@ -45,25 +43,5 @@ var (
 func TestIntegrationSimpleOrcaMini(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
 	defer cancel()
-	GenerateTestHelper(ctx, t, &http.Client{}, req[0], resp[0])
+	GenerateTestHelper(ctx, t, req[0], resp[0])
 }
-
-// TODO
-// The server always loads a new runner and closes the old one, which forces serial execution
-// At present this test case fails with concurrency problems.  Eventually we should try to
-// get true concurrency working with n_parallel support in the backend
-func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
-	var wg sync.WaitGroup
-	wg.Add(len(req))
-	ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
-	defer cancel()
-	for i := 0; i < len(req); i++ {
-		go func(i int) {
-			defer wg.Done()
-			GenerateTestHelper(ctx, t, &http.Client{}, req[i], resp[i])
-		}(i)
-	}
-	wg.Wait()
-}
-
-// TODO - create a parallel test with 2 different models once we support concurrency

+ 171 - 98
integration/utils_test.go

@@ -5,13 +5,14 @@ package integration
 import (
 	"bytes"
 	"context"
-	"encoding/json"
+	"errors"
 	"fmt"
 	"io"
 	"log/slog"
 	"math/rand"
 	"net"
 	"net/http"
+	"net/url"
 	"os"
 	"path/filepath"
 	"runtime"
@@ -23,9 +24,13 @@ import (
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/app/lifecycle"
-	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
+func Init() {
+	lifecycle.InitLogging()
+}
+
 func FindPort() string {
 	port := 0
 	if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
@@ -41,7 +46,7 @@ func FindPort() string {
 	return strconv.Itoa(port)
 }
 
-func GetTestEndpoint() (string, string) {
+func GetTestEndpoint() (*api.Client, string) {
 	defaultPort := "11434"
 	ollamaHost := os.Getenv("OLLAMA_HOST")
 
@@ -67,16 +72,20 @@ func GetTestEndpoint() (string, string) {
 		port = FindPort()
 	}
 
-	url := fmt.Sprintf("%s:%s", host, port)
-	slog.Info("server connection", "url", url)
-	return scheme, url
+	slog.Info("server connection", "host", host, "port", port)
+
+	return api.NewClient(
+		&url.URL{
+			Scheme: scheme,
+			Host:   net.JoinHostPort(host, port),
+		},
+		http.DefaultClient), fmt.Sprintf("%s:%s", host, port)
 }
 
-// TODO make fanicier, grab logs, etc.
 var serverMutex sync.Mutex
 var serverReady bool
 
-func StartServer(ctx context.Context, ollamaHost string) error {
+func startServer(ctx context.Context, ollamaHost string) error {
 	// Make sure the server has been built
 	CLIName, err := filepath.Abs("../ollama")
 	if err != nil {
@@ -125,67 +134,76 @@ func StartServer(ctx context.Context, ollamaHost string) error {
 	return nil
 }
 
-func PullIfMissing(ctx context.Context, client *http.Client, scheme, testEndpoint, modelName string) error {
+func PullIfMissing(ctx context.Context, client *api.Client, modelName string) error {
 	slog.Info("checking status of model", "model", modelName)
 	showReq := &api.ShowRequest{Name: modelName}
-	requestJSON, err := json.Marshal(showReq)
-	if err != nil {
-		return err
-	}
-
-	req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/show", bytes.NewReader(requestJSON))
-	if err != nil {
-		return err
-	}
 
-	// Make the request with the HTTP client
-	response, err := client.Do(req.WithContext(ctx))
-	if err != nil {
+	showCtx, cancel := context.WithDeadlineCause(
+		ctx,
+		time.Now().Add(5*time.Second),
+		fmt.Errorf("show for existing model %s took too long", modelName),
+	)
+	defer cancel()
+	_, err := client.Show(showCtx, showReq)
+	var statusError api.StatusError
+	switch {
+	case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound:
+		break
+	case err != nil:
 		return err
-	}
-	defer response.Body.Close()
-	if response.StatusCode == 200 {
+	default:
 		slog.Info("model already present", "model", modelName)
 		return nil
 	}
-	slog.Info("model missing", "status", response.StatusCode)
+	slog.Info("model missing", "model", modelName)
+
+	stallDuration := 30 * time.Second // This includes checksum verification, which can take a while on larger models
+	stallTimer := time.NewTimer(stallDuration)
+	fn := func(resp api.ProgressResponse) error {
+		// fmt.Print(".")
+		if !stallTimer.Reset(stallDuration) {
+			return fmt.Errorf("stall was detected, aborting status reporting")
+		}
+		return nil
+	}
 
+	stream := true
 	pullReq := &api.PullRequest{Name: modelName, Stream: &stream}
-	requestJSON, err = json.Marshal(pullReq)
-	if err != nil {
-		return err
-	}
 
-	req, err = http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/pull", bytes.NewReader(requestJSON))
-	if err != nil {
-		return err
-	}
-	slog.Info("pulling", "model", modelName)
+	var pullError error
 
-	response, err = client.Do(req.WithContext(ctx))
-	if err != nil {
-		return err
-	}
-	defer response.Body.Close()
-	if response.StatusCode != 200 {
-		return fmt.Errorf("failed to pull model") // TODO more details perhaps
+	done := make(chan int)
+	go func() {
+		pullError = client.Pull(ctx, pullReq, fn)
+		done <- 0
+	}()
+
+	select {
+	case <-stallTimer.C:
+		return fmt.Errorf("download stalled")
+	case <-done:
+		return pullError
 	}
-	slog.Info("model pulled", "model", modelName)
-	return nil
 }
 
 var serverProcMutex sync.Mutex
 
-func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client, genReq api.GenerateRequest, anyResp []string) {
-
-	// TODO maybe stuff in an init routine?
-	lifecycle.InitLogging()
-
-	requestJSON, err := json.Marshal(genReq)
-	if err != nil {
-		t.Fatalf("Error serializing request: %v", err)
+// Returns a Client, the testEndpoint, and a cleanup function; fails the test on errors
+// Starts the server if needed
+func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, string, func()) {
+	client, testEndpoint := GetTestEndpoint()
+	if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
+		serverProcMutex.Lock()
+		fp, err := os.CreateTemp("", "ollama-server-*.log")
+		if err != nil {
+			t.Fatalf("failed to generate log file: %s", err)
+		}
+		lifecycle.ServerLogFile = fp.Name()
+		fp.Close()
+		require.NoError(t, startServer(ctx, testEndpoint))
 	}
-	defer func() {
+
+	return client, testEndpoint, func() {
 		if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
 			defer serverProcMutex.Unlock()
 			if t.Failed() {
@@ -203,63 +221,118 @@ func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client,
 				os.Stderr.Write(data)
 				slog.Warn("END OF SERVER")
 			}
-			err = os.Remove(lifecycle.ServerLogFile)
+			err := os.Remove(lifecycle.ServerLogFile)
 			if err != nil && !os.IsNotExist(err) {
 				slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err)
 			}
 		}
-	}()
-	scheme, testEndpoint := GetTestEndpoint()
-
-	if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
-		serverProcMutex.Lock()
-		fp, err := os.CreateTemp("", "ollama-server-*.log")
-		if err != nil {
-			t.Fatalf("failed to generate log file: %s", err)
-		}
-		lifecycle.ServerLogFile = fp.Name()
-		fp.Close()
-		assert.NoError(t, StartServer(ctx, testEndpoint))
 	}
+}
 
-	err = PullIfMissing(ctx, client, scheme, testEndpoint, genReq.Model)
-	if err != nil {
-		t.Fatalf("Error pulling model: %v", err)
-	}
+func GenerateTestHelper(ctx context.Context, t *testing.T, genReq api.GenerateRequest, anyResp []string) {
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
+	require.NoError(t, PullIfMissing(ctx, client, genReq.Model))
+	DoGenerate(ctx, t, client, genReq, anyResp, 30*time.Second, 10*time.Second)
+}
 
-	// Make the request and get the response
-	req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/generate", bytes.NewReader(requestJSON))
-	if err != nil {
-		t.Fatalf("Error creating request: %v", err)
+func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq api.GenerateRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) {
+	stallTimer := time.NewTimer(initialTimeout)
+	var buf bytes.Buffer
+	fn := func(response api.GenerateResponse) error {
+		// fmt.Print(".")
+		buf.Write([]byte(response.Response))
+		if !stallTimer.Reset(streamTimeout) {
+			return fmt.Errorf("stall was detected while streaming response, aborting")
+		}
+		return nil
 	}
 
-	// Set the content type for the request
-	req.Header.Set("Content-Type", "application/json")
+	stream := true
+	genReq.Stream = &stream
+	done := make(chan int)
+	var genErr error
+	go func() {
+		genErr = client.Generate(ctx, &genReq, fn)
+		done <- 0
+	}()
 
-	// Make the request with the HTTP client
-	response, err := client.Do(req.WithContext(ctx))
-	if err != nil {
-		t.Fatalf("Error making request: %v", err)
-	}
-	defer response.Body.Close()
-	body, err := io.ReadAll(response.Body)
-	assert.NoError(t, err)
-	assert.Equal(t, response.StatusCode, 200, string(body))
-
-	// Verify the response is valid JSON
-	var payload api.GenerateResponse
-	err = json.Unmarshal(body, &payload)
-	if err != nil {
-		assert.NoError(t, err, body)
+	select {
+	case <-stallTimer.C:
+		if buf.Len() == 0 {
+			t.Errorf("generate never started.  Timed out after :%s", initialTimeout.String())
+		} else {
+			t.Errorf("generate stalled.  Response so far:%s", buf.String())
+		}
+	case <-done:
+		require.NoError(t, genErr, "failed with %s request prompt %s ", genReq.Model, genReq.Prompt)
+		// Verify the response contains the expected data
+		response := buf.String()
+		atLeastOne := false
+		for _, resp := range anyResp {
+			if strings.Contains(strings.ToLower(response), resp) {
+				atLeastOne = true
+				break
+			}
+		}
+		require.True(t, atLeastOne, "none of %v found in %s", anyResp, response)
+		slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response)
+	case <-ctx.Done():
+		t.Error("outer test context done while waiting for generate")
 	}
+}
 
-	// Verify the response contains the expected data
-	atLeastOne := false
-	for _, resp := range anyResp {
-		if strings.Contains(strings.ToLower(payload.Response), resp) {
-			atLeastOne = true
-			break
+// Generate a set of requests
+// By default each request uses orca-mini as the model
+func GenerateRequests() ([]api.GenerateRequest, [][]string) {
+	return []api.GenerateRequest{
+			{
+				Model:  "orca-mini",
+				Prompt: "why is the ocean blue?",
+				Stream: &stream,
+				Options: map[string]interface{}{
+					"seed":        42,
+					"temperature": 0.0,
+				},
+			}, {
+				Model:  "orca-mini",
+				Prompt: "why is the color of dirt brown?",
+				Stream: &stream,
+				Options: map[string]interface{}{
+					"seed":        42,
+					"temperature": 0.0,
+				},
+			}, {
+				Model:  "orca-mini",
+				Prompt: "what is the origin of the us thanksgiving holiday?",
+				Stream: &stream,
+				Options: map[string]interface{}{
+					"seed":        42,
+					"temperature": 0.0,
+				},
+			}, {
+				Model:  "orca-mini",
+				Prompt: "what is the origin of independence day?",
+				Stream: &stream,
+				Options: map[string]interface{}{
+					"seed":        42,
+					"temperature": 0.0,
+				},
+			}, {
+				Model:  "orca-mini",
+				Prompt: "what is the composition of air?",
+				Stream: &stream,
+				Options: map[string]interface{}{
+					"seed":        42,
+					"temperature": 0.0,
+				},
+			},
+		},
+		[][]string{
+			[]string{"sunlight"},
+			[]string{"soil", "organic", "earth", "black", "tan"},
+			[]string{"england", "english", "massachusetts", "pilgrims"},
+			[]string{"fourth", "july", "declaration", "independence"},
+			[]string{"nitrogen", "oxygen", "carbon", "dioxide"},
 		}
-	}
-	assert.True(t, atLeastOne, "none of %v found in %s", anyResp, payload.Response)
 }

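The reworked helpers above replace the hand-rolled HTTP calls with api.Client and detect hangs by re-arming a timer from every streaming callback (both PullIfMissing and DoGenerate use the idea). Below is a stripped-down sketch of that stall-detection pattern in isolation; it uses the more defensive Stop/drain/Reset idiom instead of relying on Timer.Reset's return value, so it is a variant of, not a quote from, the test code:

```go
package main

import (
	"fmt"
	"time"
)

// runWithStallDetection mirrors the pattern used by DoGenerate/PullIfMissing:
// the worker signals progress by re-arming a timer, and the caller treats a
// timer expiry as a stall.
func runWithStallDetection(work func(progress func()) error, stall time.Duration) error {
	timer := time.NewTimer(stall)
	done := make(chan error, 1)
	go func() {
		done <- work(func() {
			// Drain the timer if it already fired, then re-arm it.
			if !timer.Stop() {
				select {
				case <-timer.C:
				default:
				}
			}
			timer.Reset(stall)
		})
	}()
	select {
	case <-timer.C:
		return fmt.Errorf("stall detected: no progress within %s", stall)
	case err := <-done:
		return err
	}
}

func main() {
	err := runWithStallDetection(func(progress func()) error {
		for i := 0; i < 3; i++ {
			time.Sleep(100 * time.Millisecond)
			progress() // e.g. one streamed token or one pull status update
		}
		return nil
	}, time.Second)
	fmt.Println("result:", err)
}
```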
+ 162 - 0
llm/memory.go

@@ -0,0 +1,162 @@
+package llm
+
+import (
+	"fmt"
+	"log/slog"
+	"strings"
+
+	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/format"
+	"github.com/ollama/ollama/gpu"
+)
+
+// This algorithm looks for a complete fit to determine if we need to unload other models
+func PredictServerFit(allGpus gpu.GpuInfoList, ggml *GGML, adapters, projectors []string, opts api.Options) (bool, uint64) {
+	var estimatedVRAM uint64
+	if opts.NumCtx > int(ggml.KV().ContextLength()) {
+		slog.Warn("requested context length is greater than model max context length", "requested", opts.NumCtx, "model", ggml.KV().ContextLength())
+		opts.NumCtx = int(ggml.KV().ContextLength())
+	}
+
+	if opts.NumCtx < 4 {
+		opts.NumCtx = 4
+	}
+
+	// Split up the GPUs by type and try them
+	for _, gpus := range allGpus.ByLibrary() {
+		var layerCount int
+		layerCount, estimatedVRAM = EstimateGPULayers(gpus, ggml, projectors, opts)
+		if opts.NumGPU < 0 {
+			if layerCount > 0 && layerCount >= int(ggml.KV().BlockCount()+1) {
+				return true, estimatedVRAM
+			}
+		} else {
+			if layerCount > 0 && layerCount >= opts.NumGPU {
+				return true, estimatedVRAM
+			}
+		}
+	}
+	return false, estimatedVRAM
+}
+
+// Given a model and one or more GPU targets, predict how many layers and bytes we can load
+// The GPUs provided must all be the same Library
+func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts api.Options) (int, uint64) {
+	if gpus[0].Library == "cpu" {
+		return 0, 0
+	}
+	var memoryAvailable uint64
+	for _, info := range gpus {
+		memoryAvailable += info.FreeMemory
+	}
+	slog.Debug("evaluating", "library", gpus[0].Library, "gpu_count", len(gpus), "available", format.HumanBytes2(memoryAvailable))
+
+	// TODO - this is probably wrong, first GPU vs secondaries will have different overheads
+	memoryMinimum := gpus[0].MinimumMemory
+
+	for _, projector := range projectors {
+		memoryMinimum += projectorMemoryRequirements(projector)
+
+		// multimodal models require at least 2048 context
+		opts.NumCtx = max(opts.NumCtx, 2048)
+	}
+
+	// fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv
+	var kv uint64 = 2 * 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * ggml.KV().EmbeddingLength() / ggml.KV().HeadCount() * ggml.KV().HeadCountKV()
+
+	graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)))
+	if graphPartialOffload == 0 {
+		graphPartialOffload = ggml.KV().GQA() * kv / 6
+	}
+
+	if graphFullOffload == 0 {
+		graphFullOffload = graphPartialOffload
+	}
+
+	graphFullOffload *= uint64(len(gpus))
+	graphPartialOffload *= uint64(len(gpus))
+
+	// memoryRequiredTotal represents the memory required for full GPU offloading (all layers)
+	memoryRequiredTotal := memoryMinimum + graphFullOffload
+
+	// memoryRequiredPartial represents the memory required for partial GPU offloading (n > 0, n < layers)
+	memoryRequiredPartial := memoryMinimum + graphPartialOffload
+
+	if memoryRequiredPartial > memoryAvailable {
+		slog.Debug("insufficient VRAM to load any model layers")
+		return 0, 0
+	}
+
+	var layerCount int
+	layers := ggml.Tensors().Layers()
+	for i := 0; i < int(ggml.KV().BlockCount()); i++ {
+		memoryLayer := layers[fmt.Sprintf("blk.%d", i)].size()
+
+		// KV is proportional to the number of layers
+		memoryLayer += kv / ggml.KV().BlockCount()
+
+		memoryRequiredTotal += memoryLayer
+		if memoryAvailable > memoryRequiredPartial+memoryLayer {
+			memoryRequiredPartial += memoryLayer
+			layerCount++
+		}
+	}
+
+	var memoryLayerOutput uint64
+	for k, v := range layers {
+		if !strings.HasPrefix(k, "blk.") {
+			memoryLayerOutput += v.size()
+		}
+	}
+
+	memoryRequiredTotal += memoryLayerOutput
+
+	if memoryAvailable > memoryRequiredTotal {
+		layerCount = int(ggml.KV().BlockCount()) + 1
+		memoryRequiredPartial = memoryRequiredTotal
+	}
+
+	memoryWeights := memoryRequiredTotal - memoryMinimum - graphFullOffload - kv
+
+	slog.Info(
+		"offload to gpu",
+		slog.Group(
+			"layers",
+			// actual number of layers offloaded
+			"real", opts.NumGPU,
+			// estimated number of layers that can be offloaded
+			"estimate", layerCount,
+		),
+		slog.Group(
+			"memory",
+			// memory available for offloading
+			"available", format.HumanBytes2(memoryAvailable),
+			slog.Group(
+				"required",
+				// memory required for full offloading
+				"full", format.HumanBytes2(memoryRequiredTotal),
+				// memory required to offload layers.estimate layers
+				"partial", format.HumanBytes2(memoryRequiredPartial),
+				// memory of KV cache
+				"kv", format.HumanBytes2(kv),
+			),
+			slog.Group(
+				"weights",
+				// memory of the weights
+				"total", format.HumanBytes2(memoryWeights),
+				// memory of repeating layers
+				"repeating", format.HumanBytes2(memoryWeights-memoryLayerOutput),
+				// memory of non-repeating layers
+				"nonrepeating", format.HumanBytes2(memoryLayerOutput),
+			),
+			slog.Group(
+				"graph",
+				// memory of graph when fully offloaded
+				"full", format.HumanBytes2(graphFullOffload),
+				// memory of graph when not fully offloaded
+				"partial", format.HumanBytes2(graphPartialOffload),
+			),
+		),
+	)
+	return layerCount, uint64(memoryRequiredPartial)
+}

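EstimateGPULayers sizes the fp16 KV cache with the formula in the comment above (2 bytes each for K and V, scaled by context length, layer count, and the head ratio) and then charges kv/BlockCount to every offloaded layer. A worked example with assumed, representative 7B-class parameters (illustrative numbers only, not taken from any particular model):

```go
package main

import "fmt"

func main() {
	// Hypothetical 7B-class model parameters; in the real code these come
	// from the GGUF metadata via ggml.KV().
	var (
		numCtx    uint64 = 2048 // opts.NumCtx
		blocks    uint64 = 32   // BlockCount (n_layer)
		embedding uint64 = 4096 // EmbeddingLength (n_embd)
		heads     uint64 = 32   // HeadCount (n_head)
		headsKV   uint64 = 32   // HeadCountKV (n_head_kv)
	)

	// fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv
	kv := 2 * 2 * numCtx * blocks * embedding / heads * headsKV

	fmt.Printf("estimated fp16 KV cache: %d bytes (%.1f GiB)\n", kv, float64(kv)/(1<<30))
	// With these numbers the estimate is 1073741824 bytes, i.e. 1.0 GiB,
	// of which kv/BlockCount (32 MiB) is attributed to each offloaded layer.
}
```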
+ 18 - 0
llm/payload.go

@@ -9,6 +9,7 @@ import (
 	"log/slog"
 	"os"
 	"path/filepath"
+	"runtime"
 	"strings"
 
 	"golang.org/x/exp/slices"
@@ -138,6 +139,23 @@ func serversForGpu(info gpu.GpuInfo) []string {
 	return servers
 }
 
+// Return the optimal server for this CPU architecture
+func serverForCpu() string {
+	if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
+		return "metal"
+	}
+	variant := gpu.GetCPUVariant()
+	availableServers := availableServers()
+	if variant != "" {
+		for cmp := range availableServers {
+			if cmp == "cpu_"+variant {
+				return cmp
+			}
+		}
+	}
+	return "cpu"
+}
+
 // extract extracts the embedded files to the target directory
 func extractFiles(targetDir string, glob string) error {
 	files, err := fs.Glob(libEmbed, glob)

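The llm/server.go changes below wire request-level parallelism into the runner: --parallel is passed to the llama.cpp server and a weighted semaphore sized by OLLAMA_NUM_PARALLEL (default 1) is stored on llmServer. The sketch below shows that parsing plus a plausible acquire/release around a request; where the semaphore is actually acquired is not visible in this hunk, so treat that placement as an assumption:

```go
package main

import (
	"context"
	"fmt"
	"os"
	"strconv"

	"golang.org/x/sync/semaphore"
)

func main() {
	// Default to a single in-flight request unless OLLAMA_NUM_PARALLEL says otherwise.
	numParallel := 1
	if onp := os.Getenv("OLLAMA_NUM_PARALLEL"); onp != "" {
		n, err := strconv.Atoi(onp)
		if err != nil || n <= 0 {
			fmt.Fprintf(os.Stderr, "invalid OLLAMA_NUM_PARALLEL=%s, must be greater than zero\n", onp)
			os.Exit(1)
		}
		numParallel = n
	}

	// Each in-flight completion would acquire a slot before talking to the
	// runner, bounding concurrency at numParallel.
	sem := semaphore.NewWeighted(int64(numParallel))
	ctx := context.Background()
	if err := sem.Acquire(ctx, 1); err == nil {
		defer sem.Release(1)
		fmt.Println("slot acquired; the completion request would be forwarded here")
	}
}
```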
+ 156 - 143
llm/server.go

@@ -21,21 +21,43 @@ import (
 	"strings"
 	"time"
 
+	"golang.org/x/sync/semaphore"
+
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/gpu"
 )
 
-// LlamaServer is an instance of the llama.cpp server
-type LlamaServer struct {
+type LlamaServer interface {
+	Ping(ctx context.Context) error
+	WaitUntilRunning(ctx context.Context) error
+	Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
+	Embedding(ctx context.Context, prompt string) ([]float64, error)
+	Tokenize(ctx context.Context, content string) ([]int, error)
+	Detokenize(ctx context.Context, tokens []int) (string, error)
+	Close() error
+	EstimatedVRAM() uint64
+}
+
+// llmServer is an instance of the llama.cpp server
+type llmServer struct {
 	port    int
 	cmd     *exec.Cmd
 	done    chan error // Channel to signal when the process exits
 	status  *StatusWriter
 	options api.Options
+
+	// TODO - this should be broken down by GPU
+	estimatedVRAM uint64 // Estimated usage of VRAM by the loaded model
+
+	sem *semaphore.Weighted
 }
 
-func NewLlamaServer(model string, adapters, projectors []string, opts api.Options) (*LlamaServer, error) {
+func LoadModel(model string) (*GGML, error) {
+	if _, err := os.Stat(model); err != nil {
+		return nil, err
+	}
+
 	f, err := os.Open(model)
 	if err != nil {
 		return nil, err
@@ -43,10 +65,13 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
 	defer f.Close()
 
 	ggml, _, err := DecodeGGML(f)
-	if err != nil {
-		return nil, err
-	}
+	return ggml, err
+}
 
+// NewLlamaServer will run a server for the given GPUs
+// The gpu list must be a single family.
+func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, projectors []string, opts api.Options) (LlamaServer, error) {
+	var err error
 	if opts.NumCtx > int(ggml.KV().ContextLength()) {
 		slog.Warn("requested context length is greater than model max context length", "requested", opts.NumCtx, "model", ggml.KV().ContextLength())
 		opts.NumCtx = int(ggml.KV().ContextLength())
@@ -56,130 +81,50 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
 		opts.NumCtx = 4
 	}
 
-	memoryAvailable, _ := gpu.CheckVRAM()
-	info := gpu.GetGPUInfo()
-
-	memoryMinimum := info.MinimumMemory
-	for _, projector := range projectors {
-		memoryMinimum += projectorMemoryRequirements(projector)
-
-		// multimodal models require at least 2048 context
-		opts.NumCtx = max(opts.NumCtx, 2048)
-	}
+	cpuRunner := ""
+	var estimatedVRAM uint64
+	var systemMemory uint64
+	if (len(gpus) == 1 && gpus[0].Library == "cpu") || opts.NumGPU == 0 {
 
-	// fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv
-	var kv uint64 = 2 * 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * ggml.KV().EmbeddingLength() / ggml.KV().HeadCount() * ggml.KV().HeadCountKV()
+		// TODO evaluate system memory to see if we should block the load, or force an unload of another CPU runner
 
-	graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)))
-	if graphPartialOffload == 0 {
-		graphPartialOffload = ggml.KV().GQA() * kv / 6
-	}
-
-	if graphFullOffload == 0 {
-		graphFullOffload = graphPartialOffload
-	}
-
-	graphFullOffload *= uint64(info.DeviceCount)
-	graphPartialOffload *= uint64(info.DeviceCount)
-
-	// memoryRequiredTotal represents the memory required for full GPU offloading (all layers)
-	memoryRequiredTotal := memoryMinimum + graphFullOffload
-
-	// memoryRequiredPartial represents the memory required for partial GPU offloading (n > 0, n < layers)
-	memoryRequiredPartial := memoryMinimum + graphPartialOffload
-
-	if info.Library != "metal" {
-		if memoryRequiredPartial > memoryAvailable {
-			info.Library = "cpu"
-		}
-	}
-
-	var layerCount int
-	layers := ggml.Tensors().Layers()
-	for i := 0; i < int(ggml.KV().BlockCount()); i++ {
-		memoryLayer := layers[fmt.Sprintf("blk.%d", i)].size()
-
-		// KV is proportional to the number of layers
-		memoryLayer += kv / ggml.KV().BlockCount()
-
-		memoryRequiredTotal += memoryLayer
-		if memoryAvailable > memoryRequiredPartial+memoryLayer {
-			memoryRequiredPartial += memoryLayer
-			layerCount++
+		cpuRunner = serverForCpu()
+	} else {
+		if gpus[0].Library == "metal" {
+			memInfo, err := gpu.GetCPUMem()
+			if err != nil {
+				slog.Error("failed to lookup system memory", "error", err)
+			} else {
+				systemMemory = memInfo.TotalMemory
+				slog.Debug("system memory", "total", format.HumanBytes2(systemMemory))
+			}
 		}
-	}
-
-	var memoryLayerOutput uint64
-	for k, v := range layers {
-		if !strings.HasPrefix(k, "blk.") {
-			memoryLayerOutput += v.size()
+		var layers int
+		layers, estimatedVRAM = EstimateGPULayers(gpus, ggml, projectors, opts)
+
+		if gpus[0].Library == "metal" && estimatedVRAM > systemMemory {
+			// disable partial offloading when model is greater than total system memory as this
+			// can lead to locking up the system
+			opts.NumGPU = 0
+		} else if opts.NumGPU < 0 && layers > 0 && gpus[0].Library != "cpu" {
+			opts.NumGPU = layers
 		}
 	}
 
-	memoryRequiredTotal += memoryLayerOutput
-
-	if info.Library == "metal" && memoryRequiredTotal > info.TotalMemory {
-		// disable partial offloading when model is greater than total system memory
-		opts.NumGPU = 0
-	} else if memoryAvailable > memoryRequiredTotal {
-		layerCount = int(ggml.KV().BlockCount()) + 1
-		memoryRequiredPartial = memoryRequiredTotal
-	}
-
-	if opts.NumGPU < 0 {
-		opts.NumGPU = layerCount
-	}
-
-	memoryWeights := memoryRequiredTotal - memoryMinimum - graphFullOffload - kv
-
-	slog.Info(
-		"offload to gpu",
-		slog.Group(
-			"layers",
-			// actual number of layers offloaded
-			"real", opts.NumGPU,
-			// estimated number of layers that can be offloaded
-			"estimate", layerCount,
-		),
-		slog.Group(
-			"memory",
-			// memory available for offloading
-			"available", format.HumanBytes2(memoryAvailable),
-			slog.Group(
-				"required",
-				// memory required for full offloading
-				"full", format.HumanBytes2(memoryRequiredTotal),
-				// memory required to offload layers.estimate layers
-				"partial", format.HumanBytes2(memoryRequiredPartial),
-				// memory of KV cache
-				"kv", format.HumanBytes2(kv),
-			),
-			slog.Group(
-				"weights",
-				// memory of the weights
-				"total", format.HumanBytes2(memoryWeights),
-				// memory of repeating layers
-				"repeating", format.HumanBytes2(memoryWeights-memoryLayerOutput),
-				// memory of non-repeating layers
-				"nonrepeating", format.HumanBytes2(memoryLayerOutput),
-			),
-			slog.Group(
-				"graph",
-				// memory of graph when fully offloaded
-				"full", format.HumanBytes2(graphFullOffload),
-				// memory of graph when not fully offloaded
-				"partial", format.HumanBytes2(graphPartialOffload),
-			),
-		),
-	)
+	// Loop through potential servers
+	finalErr := fmt.Errorf("no suitable llama servers found")
 
 	if len(adapters) > 1 {
 		return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
 	}
 
 	availableServers := availableServers()
-	servers := serversForGpu(info)
-
+	var servers []string
+	if cpuRunner != "" {
+		servers = []string{cpuRunner}
+	} else {
+		servers = serversForGpu(gpus[0]) // All GPUs in the list are matching Library and Variant
+	}
 	demandLib := strings.Trim(os.Getenv("OLLAMA_LLM_LIBRARY"), "\"' ")
 	if demandLib != "" {
 		serverPath := availableServers[demandLib]
@@ -192,7 +137,7 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
 	}
 
 	if len(servers) == 0 {
-		return nil, fmt.Errorf("no servers found for %v", info)
+		return nil, fmt.Errorf("no servers found for %v", gpus)
 	}
 
 	params := []string{
@@ -249,8 +194,18 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
 		params = append(params, "--numa")
 	}
 
-	// Loop through potential servers
-	var finalErr error
+	// "--cont-batching", // TODO - doesn't seem to have any noticeable perf change for multiple requests
+	numParallel := 1
+	if onp := os.Getenv("OLLAMA_NUM_PARALLEL"); onp != "" {
+		numParallel, err = strconv.Atoi(onp)
+		if err != nil || numParallel <= 0 {
+			err = fmt.Errorf("invalid OLLAMA_NUM_PARALLEL=%s must be greater than zero - %w", onp, err)
+			slog.Error("misconfiguration", "error", err)
+			return nil, err
+		}
+	}
+	params = append(params, "--parallel", fmt.Sprintf("%d", numParallel))
+
 	for i := 0; i < len(servers); i++ {
 		dir := availableServers[servers[i]]
 
@@ -275,30 +230,49 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
 		}
 		// append the server directory to LD_LIBRARY_PATH/PATH
 		libraryPaths := []string{dir}
+
 		if libraryPath, ok := os.LookupEnv(pathEnv); ok {
 			// Append our runner directory to the path
 			// This will favor system libraries over our bundled library dependencies
 			libraryPaths = append(filepath.SplitList(libraryPath), libraryPaths...)
 		}
 
+		// Note: we always put the dependency path first
+		// since this was the exact version we verified for AMD GPUs
+		// and we favor what the user had in their path
+		if gpus[0].DependencyPath != "" {
+			// TODO refine for multi-gpu support
+			libraryPaths = append([]string{gpus[0].DependencyPath}, libraryPaths...)
+		}
+
 		server := filepath.Join(dir, "ollama_llama_server")
 		if runtime.GOOS == "windows" {
 			server = server + ".exe"
 		}
 
-		s := &LlamaServer{
-			port:    port,
-			cmd:     exec.Command(server, finalParams...),
-			status:  NewStatusWriter(os.Stderr),
-			options: opts,
+		s := &llmServer{
+			port:          port,
+			cmd:           exec.Command(server, finalParams...),
+			status:        NewStatusWriter(os.Stderr),
+			options:       opts,
+			estimatedVRAM: estimatedVRAM,
+			sem:           semaphore.NewWeighted(int64(numParallel)),
 		}
+
 		libEnv := fmt.Sprintf("%s=%s", pathEnv, strings.Join(libraryPaths, string(filepath.ListSeparator)))
-		slog.Debug(libEnv)
 		s.cmd.Env = append(os.Environ(), libEnv)
 		s.cmd.Stdout = os.Stdout
 		s.cmd.Stderr = s.status
 
+		// TODO - multiple GPU selection logic...
+		key, val := gpu.GpuInfoList(gpus).GetVisibleDevicesEnv()
+		if key != "" {
+			s.cmd.Env = append(s.cmd.Env, key+"="+val)
+		}
+
 		slog.Info("starting llama server", "cmd", s.cmd.String())
+		// Log at debug as the environment is inherited and might contain sensitive information
+		slog.Debug("subprocess", "environment", s.cmd.Env)
 
 		if err = s.cmd.Start(); err != nil {
 			msg := ""
@@ -316,6 +290,13 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
 			_ = s.cmd.Wait()
 		}()
 
+		// TODO - make sure this is all wired up correctly
+		// if err = s.WaitUntilRunning(); err != nil {
+		// 	slog.Error("error starting llama server", "server", servers[i], "error", err)
+		// 	s.Close()
+		// 	finalErr = err
+		// 	continue
+		// }
 		return s, nil
 	}
 
@@ -353,6 +334,21 @@ const ( // iota is reset to 0
 	ServerStatusError
 )
 
+func (s ServerStatus) ToString() string {
+	switch s {
+	case ServerStatusReady:
+		return "llm server ready"
+	case ServerStatusNoSlotsAvaialble:
+		return "llm busy - no slots available"
+	case ServerStatusLoadingModel:
+		return "llm server loading model"
+	case ServerStatusNotResponding:
+		return "llm server not responding"
+	default:
+		return "llm server error"
+	}
+}
+
 type ServerStatusResp struct {
 	Status          string `json:"status"`
 	SlotsIdle       int    `json:"slots_idle"`
@@ -360,7 +356,7 @@ type ServerStatusResp struct {
 	Error           string `json:"error"`
 }
 
-func (s *LlamaServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
+func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
 	// Fail fast if it's exited
 	if s.cmd.ProcessState != nil {
 		msg := ""
@@ -407,7 +403,7 @@ func (s *LlamaServer) getServerStatus(ctx context.Context) (ServerStatus, error)
 	}
 }
 
-func (s *LlamaServer) Ping(ctx context.Context) error {
+func (s *llmServer) Ping(ctx context.Context) error {
 	_, err := s.getServerStatus(ctx)
 	if err != nil {
 		slog.Debug("server unhealthy", "error", err)
@@ -416,7 +412,7 @@ func (s *LlamaServer) Ping(ctx context.Context) error {
 	return nil
 }
 
-func (s *LlamaServer) WaitUntilRunning() error {
+func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 	start := time.Now()
 	// TODO we need to wire up a better way to detect hangs during model load and startup of the server
 	expiresAt := time.Now().Add(10 * time.Minute) // be generous with timeout, large models can take a while to load
@@ -427,6 +423,9 @@ func (s *LlamaServer) WaitUntilRunning() error {
 	var lastStatus ServerStatus = -1
 	for {
 		select {
+		case <-ctx.Done():
+			slog.Info("context expired before server started")
+			return fmt.Errorf("timed out waiting for llama runner to start")
 		case err := <-s.done:
 			msg := ""
 			if s.status != nil && s.status.LastErrMsg != "" {
@@ -450,9 +449,9 @@ func (s *LlamaServer) WaitUntilRunning() error {
 				return fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg)
 			}
 
-			ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
+			c, cancel := context.WithTimeout(ctx, 200*time.Millisecond)
 			defer cancel()
-			status, err := s.getServerStatus(ctx)
+			status, err := s.getServerStatus(c)
 			if err != nil && lastStatus != status {
 				slog.Debug("server not yet available", "error", err)
 				lastStatus = status
@@ -538,7 +537,12 @@ type CompletionResponse struct {
 	EvalDuration       time.Duration
 }
 
-func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
+func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
+	if err := s.sem.Acquire(ctx, 1); err != nil {
+		slog.Error("Failed to acquire semaphore", "error", err)
+		return err
+	}
+	defer s.sem.Release(1)
 	request := map[string]any{
 		"prompt":            req.Prompt,
 		"stream":            true,
@@ -569,7 +573,7 @@ func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn
 	if err != nil {
 		return err
 	} else if status != ServerStatusReady {
-		return fmt.Errorf("unexpected server status: %d", status)
+		return fmt.Errorf("unexpected server status: %s", status.ToString())
 	}
 
 	if req.Format == "json" {
@@ -716,13 +720,18 @@ type EmbeddingResponse struct {
 	Embedding []float64 `json:"embedding"`
 }
 
-func (s *LlamaServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
+func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
+	if err := s.sem.Acquire(ctx, 1); err != nil {
+		slog.Error("Failed to acquire semaphore", "error", err)
+		return nil, err
+	}
+	defer s.sem.Release(1)
 	// Make sure the server is ready
 	status, err := s.getServerStatus(ctx)
 	if err != nil {
 		return nil, err
 	} else if status != ServerStatusReady {
-		return nil, fmt.Errorf("unexpected server status: %d", status)
+		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
 	}
 
 	data, err := json.Marshal(TokenizeRequest{Content: prompt})
@@ -768,13 +777,13 @@ type TokenizeResponse struct {
 	Tokens []int `json:"tokens"`
 }
 
-func (s *LlamaServer) Tokenize(ctx context.Context, content string) ([]int, error) {
+func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error) {
 	// Make sure the server is ready
 	status, err := s.getServerStatus(ctx)
 	if err != nil {
 		return nil, err
-	} else if status != ServerStatusReady {
-		return nil, fmt.Errorf("unexpected server status: %d", status)
+	} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvaialble {
+		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
 	}
 
 	data, err := json.Marshal(TokenizeRequest{Content: content})
@@ -820,13 +829,13 @@ type DetokenizeResponse struct {
 	Content string `json:"content"`
 }
 
-func (s *LlamaServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
+func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
 	// Make sure the server is ready
 	status, err := s.getServerStatus(ctx)
 	if err != nil {
 		return "", err
-	} else if status != ServerStatusReady {
-		return "", fmt.Errorf("unexpected server status: %d", status)
+	} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvaialble {
+		return "", fmt.Errorf("unexpected server status: %s", status.ToString())
 	}
 
 	data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
@@ -864,7 +873,7 @@ func (s *LlamaServer) Detokenize(ctx context.Context, tokens []int) (string, err
 	return decoded.Content, nil
 }
 
-func (s *LlamaServer) Close() error {
+func (s *llmServer) Close() error {
 	if s.cmd != nil {
 		slog.Debug("stopping llama server")
 		return s.cmd.Process.Kill()
@@ -873,6 +882,10 @@ func (s *LlamaServer) Close() error {
 	return nil
 }
 
+func (s *llmServer) EstimatedVRAM() uint64 {
+	return s.estimatedVRAM
+}
+
 func parseDurationMs(ms float64) time.Duration {
 	dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
 	if err != nil {

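The runner now gates concurrent Completion and Embedding calls with a weighted semaphore sized by --parallel (OLLAMA_NUM_PARALLEL). A minimal, self-contained sketch of that gating pattern follows; the capacity of 2 and the sleep standing in for the real completion call are illustrative assumptions, not part of the change.

package main

import (
	"context"
	"fmt"
	"sync"
	"time"

	"golang.org/x/sync/semaphore"
)

func main() {
	sem := semaphore.NewWeighted(2) // e.g. OLLAMA_NUM_PARALLEL=2 (assumed value)
	var wg sync.WaitGroup
	for i := 0; i < 4; i++ {
		wg.Add(1)
		go func(n int) {
			defer wg.Done()
			// Block until a slot frees up, as Completion/Embedding do before forwarding a request.
			if err := sem.Acquire(context.Background(), 1); err != nil {
				fmt.Println("request", n, "not scheduled:", err)
				return
			}
			defer sem.Release(1)
			fmt.Println("request", n, "running")
			time.Sleep(10 * time.Millisecond) // stand-in for the actual llama server call
		}(i)
	}
	wg.Wait()
}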
+ 61 - 141
server/routes.go

@@ -15,11 +15,8 @@ import (
 	"os"
 	"os/signal"
 	"path/filepath"
-	"reflect"
-	"runtime"
 	"strconv"
 	"strings"
-	"sync"
 	"syscall"
 	"time"
 
@@ -38,7 +35,8 @@ import (
 var mode string = gin.DebugMode
 
 type Server struct {
-	addr net.Addr
+	addr  net.Addr
+	sched *Scheduler
 }
 
 func init() {
@@ -53,88 +51,8 @@ func init() {
 	gin.SetMode(mode)
 }
 
-var loaded struct {
-	mu sync.Mutex
-
-	llama *llm.LlamaServer
-
-	expireTimer *time.Timer
-
-	model      string
-	adapters   []string
-	projectors []string
-	*api.Options
-}
-
 var defaultSessionDuration = 5 * time.Minute
 
-func unload() {
-	if loaded.llama != nil {
-		loaded.llama.Close()
-	}
-
-	loaded.llama = nil
-	loaded.model = ""
-	loaded.adapters = nil
-	loaded.projectors = nil
-	loaded.Options = nil
-}
-
-// load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
-func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.Duration) error {
-	ctx, cancel := context.WithTimeout(c, 10*time.Second)
-	defer cancel()
-
-	needLoad := loaded.llama == nil || // is there a model loaded?
-		loaded.model != model.ModelPath || // has the base model changed?
-		!reflect.DeepEqual(loaded.adapters, model.AdapterPaths) || // have the adapters changed?
-		!reflect.DeepEqual(loaded.projectors, model.ProjectorPaths) || // have the adapters changed?
-		!reflect.DeepEqual(loaded.Options.Runner, opts.Runner) || // have the runner options changed?
-		loaded.llama.Ping(ctx) != nil
-
-	if needLoad {
-		if loaded.llama != nil {
-			slog.Info("changing loaded model")
-			unload()
-		}
-
-		llama, err := llm.NewLlamaServer(model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts)
-		if err != nil {
-			// some older models are not compatible with newer versions of llama.cpp
-			// show a generalized compatibility error until there is a better way to
-			// check for model compatibility
-			if errors.Is(llm.ErrUnsupportedFormat, err) || strings.Contains(err.Error(), "failed to load model") {
-				err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, model.ShortName)
-			}
-
-			return err
-		}
-
-		loaded.model = model.ModelPath
-		loaded.adapters = model.AdapterPaths
-		loaded.projectors = model.ProjectorPaths
-		loaded.llama = llama
-		loaded.Options = &opts
-
-		if err = llama.WaitUntilRunning(); err != nil {
-			slog.Error("error loading llama server", "error", err)
-			unload()
-			return err
-		}
-	}
-
-	if loaded.expireTimer == nil {
-		loaded.expireTimer = time.AfterFunc(sessionDuration, func() {
-			loaded.mu.Lock()
-			defer loaded.mu.Unlock()
-			unload()
-		})
-	}
-
-	loaded.expireTimer.Reset(sessionDuration)
-	return nil
-}
-
 func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
 	opts := api.DefaultOptions()
 	if err := opts.FromMap(model.Options); err != nil {
@@ -154,9 +72,7 @@ func isSupportedImageType(image []byte) bool {
 	return slices.Contains(allowedTypes, contentType)
 }
 
-func GenerateHandler(c *gin.Context) {
-	loaded.mu.Lock()
-	defer loaded.mu.Unlock()
+func (s *Server) GenerateHandler(c *gin.Context) {
 
 	checkpointStart := time.Now()
 	var req api.GenerateRequest
@@ -224,7 +140,11 @@ func GenerateHandler(c *gin.Context) {
 		sessionDuration = req.KeepAlive.Duration
 	}
 
-	if err := load(c, model, opts, sessionDuration); err != nil {
+	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
+	var runner *runnerRef
+	select {
+	case runner = <-rCh:
+	case err = <-eCh:
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
@@ -275,7 +195,7 @@ func GenerateHandler(c *gin.Context) {
 
 		sb.Reset()
 		if req.Context != nil {
-			prev, err := loaded.llama.Detokenize(c.Request.Context(), req.Context)
+			prev, err := runner.llama.Detokenize(c.Request.Context(), req.Context)
 			if err != nil {
 				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 				return
@@ -297,9 +217,6 @@ func GenerateHandler(c *gin.Context) {
 		defer close(ch)
 
 		fn := func(r llm.CompletionResponse) {
-			// Update model expiration
-			loaded.expireTimer.Reset(sessionDuration)
-
 			// Build up the full response
 			if _, err := generated.WriteString(r.Content); err != nil {
 				ch <- gin.H{"error": err.Error()}
@@ -331,7 +248,7 @@ func GenerateHandler(c *gin.Context) {
 					}
 
 					// TODO (jmorganca): encode() should not strip special tokens
-					tokens, err := loaded.llama.Tokenize(c.Request.Context(), p)
+					tokens, err := runner.llama.Tokenize(c.Request.Context(), p)
 					if err != nil {
 						ch <- gin.H{"error": err.Error()}
 						return
@@ -359,7 +276,7 @@ func GenerateHandler(c *gin.Context) {
 			Images:  images,
 			Options: opts,
 		}
-		if err := loaded.llama.Completion(c.Request.Context(), req, fn); err != nil {
+		if err := runner.llama.Completion(c.Request.Context(), req, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()
@@ -421,10 +338,7 @@ func getDefaultSessionDuration() time.Duration {
 	return defaultSessionDuration
 }
 
-func EmbeddingsHandler(c *gin.Context) {
-	loaded.mu.Lock()
-	defer loaded.mu.Unlock()
-
+func (s *Server) EmbeddingsHandler(c *gin.Context) {
 	var req api.EmbeddingRequest
 	err := c.ShouldBindJSON(&req)
 	switch {
@@ -469,7 +383,11 @@ func EmbeddingsHandler(c *gin.Context) {
 		sessionDuration = req.KeepAlive.Duration
 	}
 
-	if err := load(c, model, opts, sessionDuration); err != nil {
+	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
+	var runner *runnerRef
+	select {
+	case runner = <-rCh:
+	case err = <-eCh:
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
@@ -480,7 +398,7 @@ func EmbeddingsHandler(c *gin.Context) {
 		return
 	}
 
-	embedding, err := loaded.llama.Embedding(c.Request.Context(), req.Prompt)
+	embedding, err := runner.llama.Embedding(c.Request.Context(), req.Prompt)
 	if err != nil {
 		slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
 		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
@@ -493,7 +411,7 @@ func EmbeddingsHandler(c *gin.Context) {
 	c.JSON(http.StatusOK, resp)
 }
 
-func PullModelHandler(c *gin.Context) {
+func (s *Server) PullModelHandler(c *gin.Context) {
 	var req api.PullRequest
 	err := c.ShouldBindJSON(&req)
 	switch {
@@ -542,7 +460,7 @@ func PullModelHandler(c *gin.Context) {
 	streamResponse(c, ch)
 }
 
-func PushModelHandler(c *gin.Context) {
+func (s *Server) PushModelHandler(c *gin.Context) {
 	var req api.PushRequest
 	err := c.ShouldBindJSON(&req)
 	switch {
@@ -591,7 +509,7 @@ func PushModelHandler(c *gin.Context) {
 	streamResponse(c, ch)
 }
 
-func CreateModelHandler(c *gin.Context) {
+func (s *Server) CreateModelHandler(c *gin.Context) {
 	var req api.CreateRequest
 	err := c.ShouldBindJSON(&req)
 	switch {
@@ -664,7 +582,7 @@ func CreateModelHandler(c *gin.Context) {
 	streamResponse(c, ch)
 }
 
-func DeleteModelHandler(c *gin.Context) {
+func (s *Server) DeleteModelHandler(c *gin.Context) {
 	var req api.DeleteRequest
 	err := c.ShouldBindJSON(&req)
 	switch {
@@ -709,7 +627,7 @@ func DeleteModelHandler(c *gin.Context) {
 	c.JSON(http.StatusOK, nil)
 }
 
-func ShowModelHandler(c *gin.Context) {
+func (s *Server) ShowModelHandler(c *gin.Context) {
 	var req api.ShowRequest
 	err := c.ShouldBindJSON(&req)
 	switch {
@@ -809,7 +727,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 	return resp, nil
 }
 
-func ListModelsHandler(c *gin.Context) {
+func (s *Server) ListModelsHandler(c *gin.Context) {
 	models := make([]api.ModelResponse, 0)
 	manifestsPath, err := GetManifestPath()
 	if err != nil {
@@ -869,7 +787,7 @@ func ListModelsHandler(c *gin.Context) {
 	c.JSON(http.StatusOK, api.ListResponse{Models: models})
 }
 
-func CopyModelHandler(c *gin.Context) {
+func (s *Server) CopyModelHandler(c *gin.Context) {
 	var req api.CopyRequest
 	err := c.ShouldBindJSON(&req)
 	switch {
@@ -901,7 +819,7 @@ func CopyModelHandler(c *gin.Context) {
 	}
 }
 
-func HeadBlobHandler(c *gin.Context) {
+func (s *Server) HeadBlobHandler(c *gin.Context) {
 	path, err := GetBlobsPath(c.Param("digest"))
 	if err != nil {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@@ -916,7 +834,7 @@ func HeadBlobHandler(c *gin.Context) {
 	c.Status(http.StatusOK)
 }
 
-func CreateBlobHandler(c *gin.Context) {
+func (s *Server) CreateBlobHandler(c *gin.Context) {
 	path, err := GetBlobsPath(c.Param("digest"))
 	if err != nil {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@@ -1063,27 +981,27 @@ func (s *Server) GenerateRoutes() http.Handler {
 		allowedHostsMiddleware(s.addr),
 	)
 
-	r.POST("/api/pull", PullModelHandler)
-	r.POST("/api/generate", GenerateHandler)
-	r.POST("/api/chat", ChatHandler)
-	r.POST("/api/embeddings", EmbeddingsHandler)
-	r.POST("/api/create", CreateModelHandler)
-	r.POST("/api/push", PushModelHandler)
-	r.POST("/api/copy", CopyModelHandler)
-	r.DELETE("/api/delete", DeleteModelHandler)
-	r.POST("/api/show", ShowModelHandler)
-	r.POST("/api/blobs/:digest", CreateBlobHandler)
-	r.HEAD("/api/blobs/:digest", HeadBlobHandler)
+	r.POST("/api/pull", s.PullModelHandler)
+	r.POST("/api/generate", s.GenerateHandler)
+	r.POST("/api/chat", s.ChatHandler)
+	r.POST("/api/embeddings", s.EmbeddingsHandler)
+	r.POST("/api/create", s.CreateModelHandler)
+	r.POST("/api/push", s.PushModelHandler)
+	r.POST("/api/copy", s.CopyModelHandler)
+	r.DELETE("/api/delete", s.DeleteModelHandler)
+	r.POST("/api/show", s.ShowModelHandler)
+	r.POST("/api/blobs/:digest", s.CreateBlobHandler)
+	r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
 
 	// Compatibility endpoints
-	r.POST("/v1/chat/completions", openai.Middleware(), ChatHandler)
+	r.POST("/v1/chat/completions", openai.Middleware(), s.ChatHandler)
 
 	for _, method := range []string{http.MethodGet, http.MethodHead} {
 		r.Handle(method, "/", func(c *gin.Context) {
 			c.String(http.StatusOK, "Ollama is running")
 		})
 
-		r.Handle(method, "/api/tags", ListModelsHandler)
+		r.Handle(method, "/api/tags", s.ListModelsHandler)
 		r.Handle(method, "/api/version", func(c *gin.Context) {
 			c.JSON(http.StatusOK, gin.H{"version": version.Version})
 		})
@@ -1137,7 +1055,9 @@ func Serve(ln net.Listener) error {
 		}
 	}
 
-	s := &Server{addr: ln.Addr()}
+	ctx, done := context.WithCancel(context.Background())
+	sched := InitScheduler(ctx)
+	s := &Server{addr: ln.Addr(), sched: sched}
 	r := s.GenerateRoutes()
 
 	slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version))
@@ -1150,7 +1070,8 @@ func Serve(ln net.Listener) error {
 	signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
 	go func() {
 		<-signals
-		unload()
+		done()
+		sched.unloadAllRunners()
 		gpu.Cleanup()
 		os.Exit(0)
 	}()
@@ -1158,12 +1079,12 @@ func Serve(ln net.Listener) error {
 	if err := llm.Init(); err != nil {
 		return fmt.Errorf("unable to initialize llm library %w", err)
 	}
-	if runtime.GOOS == "linux" { // TODO - windows too
-		// check compatibility to log warnings
-		if _, err := gpu.CheckVRAM(); err != nil {
-			slog.Info(err.Error())
-		}
-	}
+
+	s.sched.Run(ctx)
+
+	// At startup we retrieve GPU information so we can get log messages before loading a model
+	// This will log warnings to the log in case we have problems with detected GPUs
+	_ = gpu.GetGPUInfo()
 
 	return srvr.Serve(ln)
 }
@@ -1219,9 +1140,9 @@ func streamResponse(c *gin.Context, ch chan any) {
 }
 
 // ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
-func chatPrompt(ctx context.Context, template string, messages []api.Message, numCtx int) (string, error) {
+func chatPrompt(ctx context.Context, runner *runnerRef, template string, messages []api.Message, numCtx int) (string, error) {
 	encode := func(s string) ([]int, error) {
-		return loaded.llama.Tokenize(ctx, s)
+		return runner.llama.Tokenize(ctx, s)
 	}
 
 	prompt, err := ChatPrompt(template, messages, numCtx, encode)
@@ -1232,10 +1153,7 @@ func chatPrompt(ctx context.Context, template string, messages []api.Message, nu
 	return prompt, nil
 }
 
-func ChatHandler(c *gin.Context) {
-	loaded.mu.Lock()
-	defer loaded.mu.Unlock()
-
+func (s *Server) ChatHandler(c *gin.Context) {
 	checkpointStart := time.Now()
 
 	var req api.ChatRequest
@@ -1292,7 +1210,11 @@ func ChatHandler(c *gin.Context) {
 		sessionDuration = req.KeepAlive.Duration
 	}
 
-	if err := load(c, model, opts, sessionDuration); err != nil {
+	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
+	var runner *runnerRef
+	select {
+	case runner = <-rCh:
+	case err = <-eCh:
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
@@ -1309,7 +1231,7 @@ func ChatHandler(c *gin.Context) {
 		}, req.Messages...)
 	}
 
-	prompt, err := chatPrompt(c.Request.Context(), model.Template, req.Messages, opts.NumCtx)
+	prompt, err := chatPrompt(c.Request.Context(), runner, model.Template, req.Messages, opts.NumCtx)
 	if err != nil {
 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
@@ -1352,8 +1274,6 @@ func ChatHandler(c *gin.Context) {
 		defer close(ch)
 
 		fn := func(r llm.CompletionResponse) {
-			// Update model expiration
-			loaded.expireTimer.Reset(sessionDuration)
 
 			resp := api.ChatResponse{
 				Model:     req.Model,
@@ -1376,7 +1296,7 @@ func ChatHandler(c *gin.Context) {
 			ch <- resp
 		}
 
-		if err := loaded.llama.Completion(c.Request.Context(), llm.CompletionRequest{
+		if err := runner.llama.Completion(c.Request.Context(), llm.CompletionRequest{
 			Prompt:  prompt,
 			Format:  req.Format,
 			Images:  images,

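With the handlers routed through the scheduler, overlapping API calls are served concurrently up to the configured limits. A small client-side sketch using only the standard library; the model name, prompt, and default port 11434 are assumptions for illustration.

package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
	"sync"
)

func main() {
	// Hypothetical request body; adjust the model name to one that is actually pulled.
	body := []byte(`{"model":"llama2","prompt":"why is the sky blue?","stream":false}`)
	var wg sync.WaitGroup
	for i := 0; i < 2; i++ {
		wg.Add(1)
		go func(n int) {
			defer wg.Done()
			resp, err := http.Post("http://127.0.0.1:11434/api/generate", "application/json", bytes.NewReader(body))
			if err != nil {
				fmt.Println("request", n, "error:", err)
				return
			}
			defer resp.Body.Close()
			out, _ := io.ReadAll(resp.Body)
			fmt.Println("request", n, "status:", resp.Status, "bytes:", len(out))
		}(i)
	}
	wg.Wait()
}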
+ 525 - 0
server/sched.go

@@ -0,0 +1,525 @@
+package server
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"log/slog"
+	"os"
+	"reflect"
+	"sort"
+	"strconv"
+	"strings"
+	"sync"
+	"time"
+
+	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/format"
+	"github.com/ollama/ollama/gpu"
+	"github.com/ollama/ollama/llm"
+	"golang.org/x/exp/slices"
+)
+
+type LlmRequest struct {
+	ctx             context.Context //nolint:containedctx
+	model           *Model
+	ggml            *llm.GGML // TODO - how large is this, and do we need to free it after we've finished loading?
+	opts            api.Options
+	sessionDuration time.Duration
+	successCh       chan *runnerRef
+	errCh           chan error
+}
+
+type Scheduler struct {
+	pendingReqCh  chan *LlmRequest
+	finishedReqCh chan *LlmRequest
+	expiredCh     chan *runnerRef
+	unloadedCh    chan interface{}
+
+	loaded   map[string]*runnerRef
+	loadedMu sync.Mutex
+
+	loadFn      func(req *LlmRequest, gpus gpu.GpuInfoList)
+	newServerFn func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error)
+	getGpuFn    func() gpu.GpuInfoList
+}
+
+// TODO set this to zero after a release or two, to enable multiple models by default
+var loadedMax = 1          // Maximum runners; < 1 maps to as many as will fit in VRAM (unlimited for CPU runners)
+var maxQueuedRequests = 10 // TODO configurable
+
+func InitScheduler(ctx context.Context) *Scheduler {
+	maxRunners := os.Getenv("OLLAMA_MAX_LOADED_MODELS")
+	if maxRunners != "" {
+		m, err := strconv.Atoi(maxRunners)
+		if err != nil {
+			slog.Error("invalid setting", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "error", err)
+		} else {
+			loadedMax = m
+		}
+	}
+
+	sched := &Scheduler{
+		pendingReqCh:  make(chan *LlmRequest, maxQueuedRequests),
+		finishedReqCh: make(chan *LlmRequest, maxQueuedRequests),
+		expiredCh:     make(chan *runnerRef, maxQueuedRequests),
+		unloadedCh:    make(chan interface{}, maxQueuedRequests),
+		loaded:        make(map[string]*runnerRef),
+		newServerFn:   llm.NewLlamaServer,
+		getGpuFn:      gpu.GetGPUInfo,
+	}
+	sched.loadFn = sched.load
+	return sched
+}
+
+// context must be canceled to decrement ref count and release the runner
+func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration time.Duration) (chan *runnerRef, chan error) {
+	ggml, err := llm.LoadModel(model.ModelPath)
+	req := &LlmRequest{
+		ctx:             c,
+		model:           model,
+		ggml:            ggml,
+		opts:            opts,
+		sessionDuration: sessionDuration,
+		successCh:       make(chan *runnerRef),
+		errCh:           make(chan error, 1),
+	}
+	if err != nil {
+		req.errCh <- err
+		return req.successCh, req.errCh
+	}
+	select {
+	case s.pendingReqCh <- req:
+	default:
+		req.errCh <- fmt.Errorf("server busy, please try again - maximum pending requests exceeded")
+	}
+	return req.successCh, req.errCh
+}
+
+// Returns immediately, spawning goroutines for the scheduler, which will shut down when ctx is done
+func (s *Scheduler) Run(ctx context.Context) {
+	slog.Debug("starting llm scheduler")
+	go func() {
+		s.processPending(ctx)
+	}()
+
+	go func() {
+		s.processCompleted(ctx)
+	}()
+}
+
+func (s *Scheduler) processPending(ctx context.Context) {
+	for {
+		select {
+		case <-ctx.Done():
+			slog.Debug("shutting down scheduler pending loop")
+			return
+		case pending := <-s.pendingReqCh:
+			// Block other requests until we get this pending request running
+			for {
+				var runnerToExpire *runnerRef
+				s.loadedMu.Lock()
+				runner := s.loaded[pending.model.ModelPath]
+				loadedCount := len(s.loaded)
+				s.loadedMu.Unlock()
+				if runner != nil {
+					if runner.needsReload(ctx, pending) {
+						runnerToExpire = runner
+					} else {
+						// Runner is usable, return it
+						pending.useLoadedRunner(runner, s.finishedReqCh)
+						break
+					}
+				} else if loadedCount == 0 {
+					slog.Debug("loading first model", "model", pending.model.ModelPath)
+					gpus := s.getGpuFn()
+					g := pickBestFitGPUs(pending, gpus)
+					if g != nil {
+						gpus = g
+					}
+					s.loadFn(pending, gpus)
+					break
+				} else if loadedMax > 0 && loadedCount >= loadedMax {
+					slog.Debug("max runners achieved, unloading one to make room", "runner_count", loadedCount)
+					runnerToExpire = s.findRunnerToUnload(pending)
+				} else {
+					// More than one loaded model, so we have to see if the new one fits
+					// Get a refreshed GPU list
+					gpus := s.getGpuFn()
+					// Update free memory from currently loaded models
+					s.updateFreeSpace(gpus)
+					gpus = pickBestFitGPUs(pending, gpus)
+					if gpus != nil {
+						slog.Debug("new model fits with existing models, loading")
+						s.loadFn(pending, gpus)
+						break
+					}
+					runnerToExpire = s.findRunnerToUnload(pending)
+				}
+
+				if runnerToExpire == nil {
+					// Shouldn't happen
+					slog.Error("runner to expire was nil!")
+					continue
+				}
+				// Trigger an expiration to unload once it's done
+				runnerToExpire.refMu.Lock()
+				slog.Debug("resetting model to expire immediately to make room", "model", runnerToExpire.model, "refCount", runnerToExpire.refCount)
+				if runnerToExpire.expireTimer != nil {
+					runnerToExpire.expireTimer.Stop()
+					runnerToExpire.expireTimer = nil
+				}
+				runnerToExpire.sessionDuration = 0
+				if runnerToExpire.refCount <= 0 {
+					s.expiredCh <- runnerToExpire
+				}
+				runnerToExpire.refMu.Unlock()
+				// Wait for the unload to happen
+				// Note: at this point we're queueing up all incoming requests, even if they were for
+				// a different model that's loaded and not scheduled to be removed.
+				slog.Debug("waiting for pending requests to complete and unload to occur", "model", runnerToExpire.model)
+				select {
+				case <-ctx.Done():
+					slog.Debug("shutting down scheduler pending loop")
+					return
+				case <-s.unloadedCh:
+					slog.Debug("unload completed", "model", runnerToExpire.model)
+					continue
+				}
+			}
+		case <-s.unloadedCh:
+			// An unload request when there are no pending requests can be ignored
+			slog.Debug("ignoring unload event with no pending requests")
+		}
+	}
+}
+
+func (s *Scheduler) processCompleted(ctx context.Context) {
+	// Process completed requests, expired timers, and unloading models
+	for {
+		select {
+		case <-ctx.Done():
+			slog.Debug("shutting down scheduler completed loop")
+			return
+		case finished := <-s.finishedReqCh:
+			s.loadedMu.Lock()
+			runner := s.loaded[finished.model.ModelPath]
+			s.loadedMu.Unlock()
+			if runner == nil {
+				slog.Error("finished request signal received after model unloaded", "model", finished.model.ModelPath)
+				continue
+			}
+			runner.refMu.Lock()
+			runner.refCount--
+			if runner.refCount <= 0 {
+				if runner.sessionDuration <= 0 {
+					slog.Debug("runner with zero duration has gone idle, expiring to unload", "model", runner.model)
+					if runner.expireTimer != nil {
+						runner.expireTimer.Stop()
+						runner.expireTimer = nil
+					}
+					s.expiredCh <- runner
+				} else if runner.expireTimer == nil {
+					slog.Debug("runner with non-zero duration has gone idle, adding timer", "model", runner.model, "duration", runner.sessionDuration)
+					runner.expireTimer = time.AfterFunc(runner.sessionDuration, func() {
+						slog.Debug("timer expired, expiring to unload", "model", runner.model)
+						runner.refMu.Lock()
+						defer runner.refMu.Unlock()
+						if runner.expireTimer != nil {
+							runner.expireTimer.Stop()
+						}
+						s.expiredCh <- runner
+					})
+				} else {
+					slog.Debug("runner with non-zero duration has gone idle, resetting timer", "model", runner.model, "duration", runner.sessionDuration)
+					runner.expireTimer.Reset(runner.sessionDuration)
+				}
+			}
+			slog.Debug("after processing request finished event", "model", runner.model, "refCount", runner.refCount)
+			runner.refMu.Unlock()
+		case runner := <-s.expiredCh:
+			slog.Debug("runner expired event received", "model", runner.model)
+			runner.refMu.Lock()
+			if runner.refCount > 0 {
+				// Shouldn't happen, but safeguard to ensure no leaked runners
+				slog.Debug("expired event with positive ref count, retrying", "model", runner.model, "refCount", runner.refCount)
+				go func(runner *runnerRef) {
+					// We can't unload yet, but want to as soon as the current request completes
+					// So queue up another expired event
+					time.Sleep(10 * time.Millisecond)
+					s.expiredCh <- runner
+				}(runner)
+				runner.refMu.Unlock()
+				continue
+			}
+
+			slog.Debug("got lock to unload", "model", runner.model)
+			runner.unload()
+			s.loadedMu.Lock()
+			delete(s.loaded, runner.model)
+			s.loadedMu.Unlock()
+			slog.Debug("runner released", "model", runner.model)
+			runner.refMu.Unlock()
+			slog.Debug("sending an unloaded event", "model", runner.model)
+			s.unloadedCh <- struct{}{}
+		}
+	}
+}
+
+// Complete the pending request and send the runner back to the requester
+// Wires up a finished event after the request context is completed
+// Updates session duration, and resets expiration timer
+func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *LlmRequest) {
+	runner.refMu.Lock()
+	defer runner.refMu.Unlock()
+	runner.refCount++
+	runner.sessionDuration = pending.sessionDuration
+	pending.successCh <- runner
+	go func() {
+		<-pending.ctx.Done()
+		slog.Debug("context for request finished")
+		finished <- pending
+	}()
+}
+
+func (s *Scheduler) load(req *LlmRequest, gpus gpu.GpuInfoList) {
+	llama, err := s.newServerFn(gpus, req.model.ModelPath, req.ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts)
+	if err != nil {
+		// some older models are not compatible with newer versions of llama.cpp
+		// show a generalized compatibility error until there is a better way to
+		// check for model compatibility
+		if errors.Is(llm.ErrUnsupportedFormat, err) || strings.Contains(err.Error(), "failed to load model") {
+			err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, req.model.ShortName)
+		}
+		slog.Info("NewLlamaServer failed", "model", req.model.ModelPath, "error", err)
+		req.errCh <- err
+		return
+	}
+	runner := &runnerRef{}
+	runner.model = req.model.ModelPath
+	runner.adapters = req.model.AdapterPaths
+	runner.projectors = req.model.ProjectorPaths
+	runner.llama = llama
+	runner.Options = &req.opts
+	runner.sessionDuration = req.sessionDuration
+	runner.gpus = gpus
+	runner.estimatedVRAM = llama.EstimatedVRAM()
+	runner.loading = true
+	runner.refCount = 1
+	runner.refMu.Lock()
+	s.loadedMu.Lock()
+	s.loaded[req.model.ModelPath] = runner
+	slog.Info("loaded runners", "count", len(s.loaded))
+	s.loadedMu.Unlock()
+
+	go func() {
+		defer runner.refMu.Unlock()
+		if err = llama.WaitUntilRunning(req.ctx); err != nil {
+			slog.Error("error loading llama server", "error", err)
+			runner.refCount--
+			req.errCh <- err
+			slog.Debug("triggering expiration for failed load", "model", runner.model)
+			s.expiredCh <- runner
+			return
+		}
+		slog.Debug("finished setting up runner", "model", req.model.ModelPath)
+		runner.loading = false
+		go func() {
+			<-req.ctx.Done()
+			slog.Debug("context for request finished")
+			s.finishedReqCh <- req
+		}()
+		req.successCh <- runner
+	}()
+}
+
+func (s *Scheduler) updateFreeSpace(allGpus gpu.GpuInfoList) {
+	type predKey struct {
+		Library string
+		ID      string
+	}
+	predMap := map[predKey]uint64{} // Sum up the total predicted usage per GPU for all runners
+	s.loadedMu.Lock()
+	for _, r := range s.loaded {
+		r.refMu.Lock()
+		gpuIDs := make([]string, 0, len(r.gpus))
+		if r.llama != nil {
+
+			// TODO this should be broken down by GPU instead of assuming uniform spread
+			estimatedVRAMPerGPU := r.llama.EstimatedVRAM() / uint64(len(r.gpus))
+			for _, gpu := range r.gpus {
+				gpuIDs = append(gpuIDs, gpu.ID)
+			}
+			for _, gpu := range allGpus {
+				if slices.Contains(gpuIDs, gpu.ID) {
+					predMap[predKey{gpu.Library, gpu.ID}] += estimatedVRAMPerGPU
+				}
+			}
+		} else {
+			slog.Warn("unexpected nil runner reference, memory prediction may be incorrect")
+		}
+		r.refMu.Unlock()
+	}
+	s.loadedMu.Unlock()
+
+	// Now that we've summed up all the GPU usage predictions across all the loaded runners, update the gpu list
+	for i := range allGpus {
+		if p, ok := predMap[predKey{allGpus[i].Library, allGpus[i].ID}]; ok {
+			slog.Debug("gpu reported", "gpu", allGpus[i].ID, "library", allGpus[i].Library, "available", format.HumanBytes2(allGpus[i].FreeMemory))
+			if p > allGpus[i].TotalMemory {
+				// Shouldn't happen
+				slog.Warn("predicted usage exceeds VRAM", "gpu", allGpus[i].ID, "totalMemory", allGpus[i].TotalMemory, "predicted", p)
+				allGpus[i].FreeMemory = 0
+			} else if (allGpus[i].TotalMemory - p) < allGpus[i].FreeMemory { // predicted free is smaller than reported free, use it
+				// TODO maybe we should just always trust our numbers, since cuda's free memory reporting is laggy
+				// and we might unload models we didn't actually need to. The risk is if some other GPU intensive app is loaded
+				// after we start our first runner, then we'll never account for that, so picking the smallest free value seems prudent.
+				allGpus[i].FreeMemory = allGpus[i].TotalMemory - p
+			}
+			slog.Info("updated VRAM", "gpu", allGpus[i].ID, "library", allGpus[i].Library, "total", format.HumanBytes2(allGpus[i].TotalMemory), "available", format.HumanBytes2(allGpus[i].FreeMemory))
+		}
+	}
+}
+
+type runnerRef struct {
+	refMu sync.Mutex
+	// refCond   sync.Cond // Signaled on transition from 1 -> 0 refCount
+	refCount uint // prevent unloading if > 0
+	// unloading bool      // set to true when we are trying to unload the runner
+
+	llama         llm.LlamaServer
+	loading       bool            // True only during initial load, then false forever
+	gpus          gpu.GpuInfoList // Recorded at time of provisioning
+	estimatedVRAM uint64
+
+	sessionDuration time.Duration
+	expireTimer     *time.Timer
+
+	model      string
+	adapters   []string
+	projectors []string
+	*api.Options
+}
+
+// The refMu must already be held when calling unload
+func (runner *runnerRef) unload() {
+	if runner.llama != nil {
+		runner.llama.Close()
+	}
+	runner.llama = nil
+	runner.adapters = nil
+	runner.projectors = nil
+	runner.Options = nil
+	runner.gpus = nil
+}
+
+func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool {
+	slog.Debug("evaluating already loaded", "model", req.model.ModelPath)
+	runner.refMu.Lock()
+	defer runner.refMu.Unlock()
+	// Ignore the NumGPU settings for comparison
+	optsExisting := runner.Options.Runner
+	optsExisting.NumGPU = -1
+	optsNew := req.opts.Runner
+	optsNew.NumGPU = -1
+	timeout := 10 * time.Second
+	if runner.loading {
+		timeout = 2 * time.Minute // Initial load can take a long time for big models on slow systems...
+	}
+	ctx, cancel := context.WithTimeout(ctx, timeout) // BUG -
+	defer cancel()
+	if !reflect.DeepEqual(runner.adapters, req.model.AdapterPaths) || // have the adapters changed?
+		!reflect.DeepEqual(runner.projectors, req.model.ProjectorPaths) || // have the projectors changed?
+		!reflect.DeepEqual(optsExisting, optsNew) || // have the runner options changed?
+		runner.llama.Ping(ctx) != nil {
+		return true
+	}
+	return false
+}
+
+type ByDuration []*runnerRef
+
+func (a ByDuration) Len() int      { return len(a) }
+func (a ByDuration) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
+func (a ByDuration) Less(i, j int) bool {
+	// cast to uint64 so negative durations (never unload) become the largest values and sort last
+	return uint64(a[i].sessionDuration) < uint64(a[j].sessionDuration)
+}
+
+// TODO - future consideration to pick runners based on size
+// type BySize []*runnerRef
+// func (a BySize) Len() int           { return len(a) }
+// func (a BySize) Swap(i, j int)      { a[i], a[j] = a[j], a[i] }
+// func (a BySize) Less(i, j int) bool { return a[i].estimatedVRAM < a[j].estimatedVRAM }
+
+// pickBestFitGPUs will try to find the optimal placement of the model in the available GPUs where the model fully fits
+// If the model cannot fully fit within the available GPU(s), nil is returned
+func pickBestFitGPUs(req *LlmRequest, gpus gpu.GpuInfoList) gpu.GpuInfoList {
+	var estimatedVRAM uint64
+	for _, gl := range gpus.ByLibrary() {
+		var ok bool
+		sgl := append(make(gpu.GpuInfoList, 0, len(gl)), gl...)
+
+		// TODO - potentially sort by performance capability, existing models loaded, etc.
+		// Note: at present, this will favor more VRAM over faster GPU speed in mixed setups
+		sort.Sort(sort.Reverse(gpu.ByFreeMemory(sgl)))
+
+		// First attempt to fit the model into a single GPU
+		for _, g := range sgl {
+			if ok, estimatedVRAM = llm.PredictServerFit([]gpu.GpuInfo{g}, req.ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
+				slog.Debug("new model will fit in available VRAM in single GPU, loading", "model", req.model.ModelPath, "gpu", g.ID, "available", g.FreeMemory, "required", format.HumanBytes2(estimatedVRAM))
+				return []gpu.GpuInfo{g}
+			}
+		}
+
+		// TODO future refinements
+		// - if multiple Libraries, see if any single GPU in any Library will fit
+		// - try subsets of GPUs instead of just falling back to 1 or all in a family
+
+		// Now try all the GPUs
+		if ok, estimatedVRAM = llm.PredictServerFit(gl, req.ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
+			slog.Debug("new model will fit in available VRAM, loading", "model", req.model.ModelPath, "library", gl[0].Library, "required", format.HumanBytes2(estimatedVRAM))
+			return gl
+		}
+	}
+	return nil
+}
+
+// findRunnerToUnload finds a runner to unload to make room for a new model
+func (s *Scheduler) findRunnerToUnload(req *LlmRequest) *runnerRef {
+	s.loadedMu.Lock()
+	runnerList := make([]*runnerRef, 0, len(s.loaded))
+	for _, r := range s.loaded {
+		runnerList = append(runnerList, r)
+	}
+	s.loadedMu.Unlock()
+
+	// In the future we can enhance the algorithm to be smarter about picking the optimal runner to unload
+	// e.g., if we have multiple options, will one make room for the request?
+	sort.Sort(ByDuration(runnerList))
+
+	// First try to find a runner that's already idle
+	for _, runner := range runnerList {
+		runner.refMu.Lock()
+		rc := runner.refCount
+		runner.refMu.Unlock()
+		if rc == 0 {
+			slog.Debug("found an idle runner to unload")
+			return runner
+		}
+	}
+	// None appear idle, just wait for the one with the shortest duration
+	slog.Debug("no idle runners, picking the shortest duration", "count", len(runnerList))
+	return runnerList[0]
+}
+
+func (s *Scheduler) unloadAllRunners() {
+	s.loadedMu.Lock()
+	defer s.loadedMu.Unlock()
+	for model, runner := range s.loaded {
+		if runner.llama != nil {
+			slog.Debug("shutting down runner", "model", model)
+			runner.llama.Close()
+		}
+	}
+}
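One subtle detail above is the uint64 cast in ByDuration: negative session durations mean "never unload", and the cast turns them into the largest values so those runners sort last and are evicted only as a last resort. A self-contained sketch with made-up durations:

package main

import (
	"fmt"
	"sort"
	"time"
)

func main() {
	// -1ns stands in for a "never unload" runner; the values are made up.
	durations := []time.Duration{30 * time.Second, -1, 5 * time.Minute}
	sort.Slice(durations, func(i, j int) bool {
		// Same comparison as ByDuration.Less: uint64(-1ns) wraps around to the maximum value.
		return uint64(durations[i]) < uint64(durations[j])
	})
	fmt.Println(durations) // [30s 5m0s -1ns]
}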

+ 553 - 0
server/sched_test.go

@@ -0,0 +1,553 @@
+package server
+
+import (
+	"bytes"
+	"context"
+	"encoding/binary"
+	"fmt"
+	"log/slog"
+	"os"
+	"testing"
+	"time"
+
+	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/app/lifecycle"
+	"github.com/ollama/ollama/format"
+	"github.com/ollama/ollama/gpu"
+	"github.com/ollama/ollama/llm"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func init() {
+	os.Setenv("OLLAMA_DEBUG", "1")
+	lifecycle.InitLogging()
+}
+
+func TestInitScheduler(t *testing.T) {
+	ctx, done := context.WithCancel(context.Background())
+	defer done()
+	initialMax := loadedMax
+	s := InitScheduler(ctx)
+	require.Equal(t, initialMax, loadedMax)
+	require.NotNil(t, s.loaded)
+
+	os.Setenv("OLLAMA_MAX_LOADED_MODELS", "blue")
+	s = InitScheduler(ctx)
+	require.Equal(t, initialMax, loadedMax)
+	require.NotNil(t, s.loaded)
+
+	os.Setenv("OLLAMA_MAX_LOADED_MODELS", "0")
+	s = InitScheduler(ctx)
+	require.Equal(t, 0, loadedMax)
+	require.NotNil(t, s.loaded)
+}
+
+func TestLoad(t *testing.T) {
+	ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
+	defer done()
+	s := InitScheduler(ctx)
+	req := &LlmRequest{
+		ctx:             ctx,
+		model:           &Model{ModelPath: "foo"},
+		successCh:       make(chan *runnerRef, 1),
+		errCh:           make(chan error, 1),
+		sessionDuration: 2,
+	}
+	// Fail to load model first
+	s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) {
+		return nil, fmt.Errorf("something failed to load model blah")
+	}
+	gpus := gpu.GpuInfoList{}
+	s.load(req, gpus)
+	require.Len(t, req.successCh, 0)
+	require.Len(t, req.errCh, 1)
+	require.Len(t, s.loaded, 0)
+	err := <-req.errCh
+	require.Contains(t, err.Error(), "this model may be incompatible")
+
+	server := &mockLlm{estimatedVRAM: 10}
+	s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) {
+		return server, nil
+	}
+	s.load(req, gpus)
+	select {
+	case err := <-req.errCh:
+		require.NoError(t, err)
+	case resp := <-req.successCh:
+		require.Equal(t, uint64(10), resp.estimatedVRAM)
+		require.Equal(t, uint(1), resp.refCount)
+		require.Len(t, s.loaded, 1)
+	}
+
+	req.model.ModelPath = "dummy_model_path"
+	server.waitResp = fmt.Errorf("wait failure")
+	s.load(req, gpus)
+	select {
+	case err := <-req.errCh:
+		require.Contains(t, err.Error(), "wait failure")
+	case resp := <-req.successCh:
+		t.Errorf("unexpected success %v", resp)
+	}
+	runner := s.loaded["dummy_model_path"]
+	require.NotNil(t, runner)
+	require.Equal(t, uint(0), runner.refCount)
+	require.Len(t, s.expiredCh, 1)
+}
+
+type bundle struct {
+	ctx     context.Context //nolint:containedctx
+	ctxDone func()
+	srv     *mockLlm
+	req     *LlmRequest
+}
+
+func (scenario *bundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) {
+	return scenario.srv, nil
+}
+
+func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedVRAM uint64) *bundle {
+	scenario := &bundle{}
+	scenario.ctx, scenario.ctxDone = context.WithCancel(ctx)
+	t.Helper()
+
+	f, err := os.CreateTemp(t.TempDir(), modelName)
+	assert.Nil(t, err)
+	defer f.Close()
+
+	gguf := llm.NewGGUFV3(binary.LittleEndian)
+	err = gguf.Encode(f, llm.KV{
+		"general.architecture":          "llama",
+		"general.name":                  "name",
+		"llama.context_length":          uint32(32),
+		"llama.embedding_length":        uint32(4096),
+		"llama.block_count":             uint32(1),
+		"llama.attention.head_count":    uint32(32),
+		"llama.attention.head_count_kv": uint32(32),
+		"tokenizer.ggml.tokens":         []string{" "},
+		"tokenizer.ggml.scores":         []float32{0},
+		"tokenizer.ggml.token_type":     []int32{0},
+	}, []llm.Tensor{
+		{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
+	})
+	assert.Nil(t, err)
+	fname := f.Name()
+	model := &Model{Name: modelName, ModelPath: fname}
+	ggml, err := llm.LoadModel(model.ModelPath)
+	require.NoError(t, err)
+	scenario.req = &LlmRequest{
+		ctx:             scenario.ctx,
+		model:           model,
+		ggml:            ggml,
+		sessionDuration: 5 * time.Millisecond,
+		successCh:       make(chan *runnerRef, 1),
+		errCh:           make(chan error, 1),
+	}
+	scenario.srv = &mockLlm{estimatedVRAM: estimatedVRAM}
+	return scenario
+}
+
+func TestRequests(t *testing.T) {
+	ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
+	defer done()
+
+	// Same model, same request
+	scenario1a := newScenario(t, ctx, "ollama-model-1", 10)
+	scenario1a.req.sessionDuration = 0
+	scenario1b := newScenario(t, ctx, "ollama-model-1", 11)
+	scenario1b.req.model = scenario1a.req.model
+	scenario1b.req.ggml = scenario1a.req.ggml
+	scenario1b.req.sessionDuration = 0
+
+	// simple reload of same model
+	scenario2a := newScenario(t, ctx, "ollama-model-1", 20)
+	scenario2a.req.model = scenario1a.req.model
+	scenario2a.req.ggml = scenario1a.req.ggml
+
+	// Multiple loaded models
+	scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte)
+	scenario3b := newScenario(t, ctx, "ollama-model-3b", 24*format.GigaByte)
+	scenario3c := newScenario(t, ctx, "ollama-model-3c", 30) // Needs prior unloaded
+
+	s := InitScheduler(ctx)
+	s.getGpuFn = func() gpu.GpuInfoList {
+		g := gpu.GpuInfo{Library: "metal"}
+		g.TotalMemory = 24 * format.GigaByte
+		g.FreeMemory = 12 * format.GigaByte
+		return []gpu.GpuInfo{g}
+	}
+	s.newServerFn = scenario1a.newServer
+	slog.Info("scenario1a")
+	s.pendingReqCh <- scenario1a.req
+	require.Len(t, s.pendingReqCh, 1)
+	s.Run(ctx)
+	select {
+	case resp := <-scenario1a.req.successCh:
+		require.Equal(t, resp.llama, scenario1a.srv)
+		require.Len(t, s.pendingReqCh, 0)
+		require.Len(t, scenario1a.req.errCh, 0)
+	case <-ctx.Done():
+		t.Errorf("timeout")
+	}
+
+	// Same runner as first request due to not needing a reload
+	s.newServerFn = scenario1b.newServer
+	slog.Info("scenario1b")
+	s.pendingReqCh <- scenario1b.req
+	select {
+	case resp := <-scenario1b.req.successCh:
+		require.Equal(t, resp.llama, scenario1a.srv)
+		require.Len(t, s.pendingReqCh, 0)
+		require.Len(t, scenario1b.req.errCh, 0)
+	case <-ctx.Done():
+		t.Errorf("timeout")
+	}
+
+	// Trigger a reload
+	s.newServerFn = scenario2a.newServer
+	scenario2a.req.model.AdapterPaths = []string{"new"}
+	slog.Info("scenario2a")
+	s.pendingReqCh <- scenario2a.req
+	// finish first two requests, so model can reload
+	time.Sleep(1 * time.Millisecond)
+	scenario1a.ctxDone()
+	scenario1b.ctxDone()
+	select {
+	case resp := <-scenario2a.req.successCh:
+		require.Equal(t, resp.llama, scenario2a.srv)
+		require.Len(t, s.pendingReqCh, 0)
+		require.Len(t, scenario2a.req.errCh, 0)
+	case <-ctx.Done():
+		t.Errorf("timeout")
+	}
+
+	loadedMax = 1
+	s.newServerFn = scenario3a.newServer
+	slog.Info("scenario3a")
+	s.pendingReqCh <- scenario3a.req
+	// finish prior request, so new model can load
+	time.Sleep(1 * time.Millisecond)
+	scenario2a.ctxDone()
+	select {
+	case resp := <-scenario3a.req.successCh:
+		require.Equal(t, resp.llama, scenario3a.srv)
+		require.Len(t, s.pendingReqCh, 0)
+		require.Len(t, scenario3a.req.errCh, 0)
+	case <-ctx.Done():
+		t.Errorf("timeout")
+	}
+	require.Len(t, s.loaded, 1)
+
+	loadedMax = 0
+	s.newServerFn = scenario3b.newServer
+	slog.Info("scenario3b")
+	s.pendingReqCh <- scenario3b.req
+	select {
+	case resp := <-scenario3b.req.successCh:
+		require.Equal(t, resp.llama, scenario3b.srv)
+		require.Len(t, s.pendingReqCh, 0)
+		require.Len(t, scenario3b.req.errCh, 0)
+	case <-ctx.Done():
+		t.Errorf("timeout")
+	}
+	require.Len(t, s.loaded, 2)
+
+	// Try to load a model that won't fit
+	s.newServerFn = scenario3c.newServer
+	slog.Info("scenario3c")
+	require.Len(t, s.loaded, 2)
+	scenario3a.ctxDone() // Won't help since this one isn't big enough to make room
+	time.Sleep(2 * time.Millisecond)
+	s.pendingReqCh <- scenario3c.req
+	// finish prior request, so new model can load
+	time.Sleep(6 * time.Millisecond)
+	require.Len(t, s.loaded, 1)
+	scenario3b.ctxDone()
+	select {
+	case resp := <-scenario3c.req.successCh:
+		require.Equal(t, resp.llama, scenario3c.srv)
+		require.Len(t, s.pendingReqCh, 0)
+		require.Len(t, scenario3c.req.errCh, 0)
+	case <-ctx.Done():
+		t.Errorf("timeout")
+	}
+	require.Len(t, s.loaded, 1)
+}
+
+func TestGetRunner(t *testing.T) {
+	ctx, done := context.WithTimeout(context.Background(), 20*time.Millisecond)
+	defer done()
+
+	// Same model, same request
+	scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
+	scenario1a.req.sessionDuration = 0
+	scenario1b := newScenario(t, ctx, "ollama-model-1b", 10)
+	scenario1b.req.sessionDuration = 0
+	scenario1c := newScenario(t, ctx, "ollama-model-1c", 10)
+	scenario1c.req.sessionDuration = 0
+	maxQueuedRequests = 1
+	s := InitScheduler(ctx)
+	s.getGpuFn = func() gpu.GpuInfoList {
+		g := gpu.GpuInfo{Library: "metal"}
+		g.TotalMemory = 24 * format.GigaByte
+		g.FreeMemory = 12 * format.GigaByte
+		return []gpu.GpuInfo{g}
+	}
+	s.newServerFn = scenario1a.newServer
+	slog.Info("scenario1a")
+	successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration)
+	require.Len(t, s.pendingReqCh, 1)
+	slog.Info("scenario1b")
+	successCh1b, errCh1b := s.GetRunner(scenario1b.ctx, scenario1b.req.model, scenario1b.req.opts, scenario1b.req.sessionDuration)
+	require.Len(t, s.pendingReqCh, 1)
+	require.Len(t, successCh1b, 0)
+	require.Len(t, errCh1b, 1)
+	err := <-errCh1b
+	require.Contains(t, err.Error(), "server busy")
+	s.Run(ctx)
+	select {
+	case resp := <-successCh1a:
+		require.Equal(t, resp.llama, scenario1a.srv)
+		require.Len(t, s.pendingReqCh, 0)
+		require.Len(t, errCh1a, 0)
+	case <-ctx.Done():
+		t.Errorf("timeout")
+	}
+	scenario1a.ctxDone()
+	require.Len(t, s.loaded, 1)
+
+	scenario1c.req.model.ModelPath = "bad path"
+	slog.Info("scenario1c")
+	successCh1c, errCh1c := s.GetRunner(scenario1c.ctx, scenario1c.req.model, scenario1c.req.opts, scenario1c.req.sessionDuration)
+	require.Len(t, s.pendingReqCh, 0)
+	require.Len(t, successCh1c, 0)
+	require.Len(t, errCh1c, 1)
+	err = <-errCh1c
+	require.Contains(t, err.Error(), "bad path")
+	scenario1b.ctxDone()
+
+	time.Sleep(5 * time.Millisecond)
+	require.Len(t, s.loaded, 0)
+}
+
+// TODO - add one scenario that triggers the bogus finished event with positive ref count
+func TestPrematureExpired(t *testing.T) {
+	ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
+	defer done()
+
+	// Same model, same request
+	scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
+	s := InitScheduler(ctx)
+	s.getGpuFn = func() gpu.GpuInfoList {
+		g := gpu.GpuInfo{Library: "metal"}
+		g.TotalMemory = 24 * format.GigaByte
+		g.FreeMemory = 12 * format.GigaByte
+		return []gpu.GpuInfo{g}
+	}
+	s.newServerFn = scenario1a.newServer
+	successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration)
+	require.Len(t, s.pendingReqCh, 1)
+	s.Run(ctx)
+	select {
+	case resp := <-successCh1a:
+		require.Equal(t, resp.llama, scenario1a.srv)
+		require.Len(t, s.pendingReqCh, 0)
+		require.Len(t, errCh1a, 0)
+		require.Len(t, s.loaded, 1)
+		slog.Info("sending premature expired event now")
+		s.expiredCh <- resp // Shouldn't happen in real life, but make sure it's safe
+	case <-ctx.Done():
+		t.Errorf("timeout")
+	}
+	time.Sleep(scenario1a.req.sessionDuration)
+	scenario1a.ctxDone()
+	time.Sleep(20 * time.Millisecond)
+	require.LessOrEqual(t, len(s.finishedReqCh), 1)
+	time.Sleep(10 * time.Millisecond)
+	require.Len(t, s.finishedReqCh, 0)
+	require.Len(t, s.loaded, 0)
+
+	// also shouldn't happen in real life
+	s.finishedReqCh <- scenario1a.req
+	time.Sleep(5 * time.Millisecond)
+}
+
+func TestUseLoadedRunner(t *testing.T) {
+	ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
+	req := &LlmRequest{
+		ctx:             ctx,
+		successCh:       make(chan *runnerRef, 1),
+		sessionDuration: 2,
+	}
+	finished := make(chan *LlmRequest)
+	llm1 := &mockLlm{}
+	r1 := &runnerRef{llama: llm1, sessionDuration: 1}
+	req.useLoadedRunner(r1, finished)
+	require.Equal(t, uint(1), r1.refCount)
+	require.Equal(t, time.Duration(2), r1.sessionDuration)
+	select {
+	case success := <-req.successCh:
+		require.Equal(t, r1, success)
+	case <-ctx.Done():
+		t.Errorf("timeout")
+	}
+	done()
+	fin := <-finished
+	require.Equal(t, req, fin)
+}
+
+func TestUpdateFreeSpace(t *testing.T) {
+	ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
+	defer done()
+	gpus := gpu.GpuInfoList{
+		{
+			Library: "a",
+			ID:      "1",
+		},
+		{
+			Library: "a",
+			ID:      "2",
+		},
+	}
+	gpus[0].TotalMemory = 1000
+	gpus[0].FreeMemory = 900
+	gpus[1].TotalMemory = 2000
+	gpus[1].FreeMemory = 1900
+	llm1 := &mockLlm{estimatedVRAM: 100}
+	llm2 := &mockLlm{estimatedVRAM: 200}
+	r1 := &runnerRef{llama: llm1, gpus: gpus}
+	r2 := &runnerRef{llama: llm2, gpus: gpus}
+
+	s := InitScheduler(ctx)
+	s.loaded["a"] = r1
+	s.loaded["b"] = r2
+
+	s.updateFreeSpace(gpus)
+	require.Equal(t, uint64(850), gpus[0].FreeMemory)
+	require.Equal(t, uint64(1850), gpus[1].FreeMemory)
+
+}
+
+func TestFindRunnerToUnload(t *testing.T) {
+	ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
+	defer done()
+	req := &LlmRequest{ctx: ctx}
+	r1 := &runnerRef{refCount: 1, sessionDuration: 1}
+	r2 := &runnerRef{sessionDuration: 2}
+
+	s := InitScheduler(ctx)
+	s.loaded["a"] = r1
+	s.loaded["b"] = r2
+
+	resp := s.findRunnerToUnload(req)
+	require.Equal(t, r2, resp)
+	r2.refCount = 1
+	resp = s.findRunnerToUnload(req)
+	require.Equal(t, r1, resp)
+
+}
+
+func TestNeedsReload(t *testing.T) {
+	ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
+	defer done()
+
+	llm := &mockLlm{}
+	runner := &runnerRef{
+		adapters:   []string{"adapter1"},
+		projectors: []string{"projector1"},
+		Options:    &api.Options{},
+		llama:      llm,
+	}
+	req := &LlmRequest{
+		model: &Model{
+			AdapterPaths:   []string{"adapter2"},
+			ProjectorPaths: []string{"projector2"},
+		},
+		opts: api.Options{},
+	}
+	resp := runner.needsReload(ctx, req)
+	require.True(t, resp)
+	req.model.AdapterPaths = runner.adapters
+	resp = runner.needsReload(ctx, req)
+	require.True(t, resp)
+	req.model.ProjectorPaths = runner.projectors
+	runner.loading = true
+	req.opts.NumBatch = 1234
+	resp = runner.needsReload(ctx, req)
+	require.True(t, resp)
+	req.opts.NumBatch = runner.Options.NumBatch
+	llm.pingResp = fmt.Errorf("foo")
+	resp = runner.needsReload(ctx, req)
+	require.True(t, resp)
+	llm.pingResp = nil
+	resp = runner.needsReload(ctx, req)
+	require.False(t, resp)
+	req.opts.NumGPU = 99
+	resp = runner.needsReload(ctx, req)
+	require.False(t, resp)
+}
+
+func TestUnloadAllRunners(t *testing.T) {
+	ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
+	defer done()
+
+	llm1 := &mockLlm{}
+	llm2 := &mockLlm{}
+	s := InitScheduler(ctx)
+	s.unloadAllRunners()
+
+	r1 := &runnerRef{llama: llm1}
+	r2 := &runnerRef{llama: llm2}
+
+	s.loaded["a"] = r1
+	s.loaded["b"] = r2
+	s.unloadAllRunners()
+
+	require.True(t, llm1.closeCalled)
+	require.True(t, llm2.closeCalled)
+}
+
+func TestUnload(t *testing.T) {
+	llm1 := &mockLlm{}
+	r1 := &runnerRef{llama: llm1}
+	r2 := &runnerRef{adapters: []string{"A"}}
+	r1.unload()
+	require.True(t, llm1.closeCalled)
+	r2.unload()
+	require.Nil(t, r2.adapters)
+}
+
+type mockLlm struct {
+	pingResp          error
+	waitResp          error
+	completionResp    error
+	embeddingResp     []float64
+	embeddingRespErr  error
+	tokenizeResp      []int
+	tokenizeRespErr   error
+	detokenizeResp    string
+	detokenizeRespErr error
+	closeResp         error
+	closeCalled       bool
+	estimatedVRAM     uint64
+}
+
+func (s *mockLlm) Ping(ctx context.Context) error             { return s.pingResp }
+func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitResp }
+func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
+	return s.completionResp
+}
+func (s *mockLlm) Embedding(ctx context.Context, prompt string) ([]float64, error) {
+	return s.embeddingResp, s.embeddingRespErr
+}
+func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
+	return s.tokenizeResp, s.tokenizeRespErr
+}
+func (s *mockLlm) Detokenize(ctx context.Context, tokens []int) (string, error) {
+	return s.detokenizeResp, s.detonekizeRespErr
+	return s.detokenizeResp, s.detokenizeRespErr
+func (s *mockLlm) Close() error {
+	s.closeCalled = true
+	return s.closeResp
+}
+func (s *mockLlm) EstimatedVRAM() uint64 { return s.estimatedVRAM }
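The mock above has to track the llm.LlamaServer interface by hand; a compile-time assertion such as the following is a common Go idiom for that (not part of this change) and would fail the build if the mock and the interface drift apart.

var _ llm.LlamaServer = (*mockLlm)(nil)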