
Request and model concurrency

This change adds support for multiple concurrent requests, as well as
loading multiple models by spawning multiple runners. The defaults are
currently 1 concurrent request per model and only 1 loaded model at a
time, but these can be adjusted by setting OLLAMA_NUM_PARALLEL and
OLLAMA_MAX_LOADED_MODELS.
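As a rough, hypothetical sketch (not code from this change), the two settings could be read along these lines; the helper name and parsing are illustrative only, with the defaults mirroring the description above:

```go
package main

import (
	"fmt"
	"os"
	"strconv"
)

// envInt is a hypothetical helper: read a positive integer setting,
// falling back to the stated default when unset or malformed.
func envInt(name string, def int) int {
	if v := os.Getenv(name); v != "" {
		if n, err := strconv.Atoi(v); err == nil && n > 0 {
			return n
		}
	}
	return def
}

func main() {
	numParallel := envInt("OLLAMA_NUM_PARALLEL", 1)           // concurrent requests per model
	maxLoadedModels := envInt("OLLAMA_MAX_LOADED_MODELS", 1)  // models loaded at once
	fmt.Println("parallel:", numParallel, "max loaded:", maxLoadedModels)
}
```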
Daniel Hiltgen 1 year ago
parent
commit
34b9db5afc

+ 7 - 0
api/client.go

@@ -91,6 +91,13 @@ func ClientFromEnvironment() (*Client, error) {
 	}, nil
 }

+func NewClient(base *url.URL, http *http.Client) *Client {
+	return &Client{
+		base: base,
+		http: http,
+	}
+}
+
 func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error {
 	var reqBody io.Reader
 	var data []byte

+ 1 - 0
format/bytes.go

@@ -15,6 +15,7 @@ const (
 
 
 	KibiByte = Byte * 1024
 	MebiByte = KibiByte * 1024
+	GibiByte = MebiByte * 1024
 )

 func HumanBytes(b int64) string {
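For context, a minimal sketch of how the new constant composes with the existing ones in this package (the printed human-readable form is illustrative, not asserted output):

```go
package main

import (
	"fmt"

	"github.com/ollama/ollama/format"
)

func main() {
	// GibiByte is the new constant added above: 1024 * 1024 * 1024 bytes.
	limit := int64(1 * format.GibiByte)
	fmt.Println(limit)                    // 1073741824
	fmt.Println(format.HumanBytes(limit)) // human-readable form of the same value
}
```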

+ 56 - 14
gpu/amd_common.go

@@ -7,7 +7,7 @@ import (
 	"log/slog"
 	"os"
 	"path/filepath"
-	"strconv"
+	"runtime"
 	"strings"
 )

@@ -35,22 +35,64 @@ func GetSupportedGFX(libDir string) ([]string, error) {
 	return ret, nil
 }

-func amdSetVisibleDevices(ids []int, skip map[int]interface{}) {
-	// Set the visible devices if not already set
-	// TODO - does sort order matter?
-	devices := []string{}
-	for i := range ids {
-		if _, skipped := skip[i]; skipped {
+func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
+	ids := []string{}
+	for _, info := range gpuInfo {
+		if info.Library != "rocm" {
+			// TODO shouldn't happen if things are wired correctly...
+			slog.Debug("rocmGetVisibleDevicesEnv skipping over non-rocm device", "library", info.Library)
 			continue
 		}
-		devices = append(devices, strconv.Itoa(i))
+		ids = append(ids, info.ID)
 	}
+	return "HIP_VISIBLE_DEVICES", strings.Join(ids, ",")
+}
 
 
-	val := strings.Join(devices, ",")
-	err := os.Setenv("HIP_VISIBLE_DEVICES", val)
-	if err != nil {
-		slog.Warn(fmt.Sprintf("failed to set env: %s", err))
-	} else {
-		slog.Info("Setting HIP_VISIBLE_DEVICES=" + val)
+func commonAMDValidateLibDir() (string, error) {
+	// We try to favor system paths first, so that we can wire up the subprocess to use
+	// the system version.  Only use our bundled version if the system version doesn't work
+	// This gives users more recovery options if versions have subtle problems at runtime
+
+	// Prefer explicit HIP env var
+	hipPath := os.Getenv("HIP_PATH")
+	if hipPath != "" {
+		hipLibDir := filepath.Join(hipPath, "bin")
+		if rocmLibUsable(hipLibDir) {
+			slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
+			return hipLibDir, nil
+		}
+	}
+
+	// Scan the LD_LIBRARY_PATH or PATH
+	pathEnv := "LD_LIBRARY_PATH"
+	if runtime.GOOS == "windows" {
+		pathEnv = "PATH"
+	}
+
+	paths := os.Getenv(pathEnv)
+	for _, path := range filepath.SplitList(paths) {
+		d, err := filepath.Abs(path)
+		if err != nil {
+			continue
+		}
+		if rocmLibUsable(d) {
+			return d, nil
+		}
+	}
+
+	// Well known location(s)
+	if rocmLibUsable(RocmStandardLocation) {
+		return RocmStandardLocation, nil
+	}
+
+	// Installer payload location if we're running the installed binary
+	exe, err := os.Executable()
+	if err == nil {
+		rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
+		if rocmLibUsable(rocmTargetDir) {
+			slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
+			return rocmTargetDir, nil
+		}
 	}
+	return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
 }

+ 2 - 2
gpu/amd_hip_windows.go

@@ -69,7 +69,7 @@ func NewHipLib() (*HipLib, error) {
 func (hl *HipLib) Release() {
 	err := windows.FreeLibrary(hl.dll)
 	if err != nil {
-		slog.Warn(fmt.Sprintf("failed to unload amdhip64.dll: %s", err))
+		slog.Warn("failed to unload amdhip64.dll", "error", err)
 	}
 	hl.dll = 0
 }
@@ -98,7 +98,7 @@ func (hl *HipLib) HipGetDeviceCount() int {
 		return 0
 	}
 	if status != hipSuccess {
-		slog.Warn(fmt.Sprintf("failed call to hipGetDeviceCount: %d %s", status, err))
+		slog.Warn("failed call to hipGetDeviceCount", "status", status, "error", err)
 	}
 	return count
 }

+ 174 - 285
gpu/amd_linux.go

@@ -11,6 +11,8 @@ import (
 	"slices"
 	"strconv"
 	"strings"
+
+	"github.com/ollama/ollama/format"
 )

 // Discovery logic for AMD/ROCm GPUs
@@ -24,9 +26,6 @@ const (
 	GPUTotalMemoryFileGlob = "mem_banks/*/properties" // size_in_bytes line
 	GPUUsedMemoryFileGlob  = "mem_banks/*/used_memory"
 	RocmStandardLocation   = "/opt/rocm/lib"
-
-	// TODO find a better way to detect iGPU instead of minimum memory
-	IGPUMemLimit = 1024 * 1024 * 1024 // 512G is what they typically report, so anything less than 1G must be iGPU
 )

 var (
@@ -35,14 +34,11 @@ var (
 )

 // Gather GPU information from the amdgpu driver if any supported GPUs are detected
-// HIP_VISIBLE_DEVICES will be set if we detect a mix of unsupported and supported devices
-// and the user hasn't already set this variable
-func AMDGetGPUInfo(resp *GpuInfo) {
-	// TODO - DRY this out with windows
+func AMDGetGPUInfo() []GpuInfo {
+	resp := []GpuInfo{}
 	if !AMDDetected() {
-		return
+		return resp
 	}
-	skip := map[int]interface{}{}
 
 
 	// Opportunistic logging of driver version to aid in troubleshooting
 	ver, err := AMDDriverVersion()
@@ -50,160 +46,117 @@ func AMDGetGPUInfo(resp *GpuInfo) {
 		slog.Info("AMD Driver: " + ver)
 	} else {
 		// TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU
-		slog.Warn(fmt.Sprintf("ollama recommends running the https://www.amd.com/en/support/linux-drivers: %s", err))
+		slog.Warn("ollama recommends running the https://www.amd.com/en/support/linux-drivers", "error", err)
 	}

-	// If the user has specified exactly which GPUs to use, look up their memory
-	visibleDevices := os.Getenv("HIP_VISIBLE_DEVICES")
-	if visibleDevices != "" {
-		ids := []int{}
-		for _, idStr := range strings.Split(visibleDevices, ",") {
-			id, err := strconv.Atoi(idStr)
-			if err != nil {
-				slog.Warn(fmt.Sprintf("malformed HIP_VISIBLE_DEVICES=%s %s", visibleDevices, err))
-			} else {
-				ids = append(ids, id)
-			}
-		}
-		amdProcMemLookup(resp, nil, ids)
-		return
+	// Determine if the user has already pre-selected which GPUs to look at, then ignore the others
+	var visibleDevices []string
+	hipVD := os.Getenv("HIP_VISIBLE_DEVICES")   // zero based index only
+	rocrVD := os.Getenv("ROCR_VISIBLE_DEVICES") // zero based index or UUID, but consumer cards seem to not support UUID
+	gpuDO := os.Getenv("GPU_DEVICE_ORDINAL")    // zero based index
+	switch {
+	// TODO is this priorty order right?
+	case hipVD != "":
+		visibleDevices = strings.Split(hipVD, ",")
+	case rocrVD != "":
+		visibleDevices = strings.Split(rocrVD, ",")
+		// TODO - since we don't yet support UUIDs, consider detecting and reporting here
+		// all our test systems show GPU-XX indicating UUID is not supported
+	case gpuDO != "":
+		visibleDevices = strings.Split(gpuDO, ",")
 	}

-	// Gather GFX version information from all detected cards
-	gfx := AMDGFXVersions()
-	verStrings := []string{}
-	for i, v := range gfx {
-		verStrings = append(verStrings, v.ToGFXString())
-		if v.Major == 0 {
-			// Silently skip CPUs
-			skip[i] = struct{}{}
+	gfxOverride := os.Getenv("HSA_OVERRIDE_GFX_VERSION")
+	var supported []string
+	libDir := ""
+
+	// The amdgpu driver always exposes the host CPU(s) first, but we have to skip them and subtract
+	// from the other IDs to get alignment with the HIP libraries expectations (zero is the first GPU, not the CPU)
+	matches, _ := filepath.Glob(GPUPropertiesFileGlob)
+	cpuCount := 0
+	for _, match := range matches {
+		slog.Debug("evaluating amdgpu node " + match)
+		fp, err := os.Open(match)
+		if err != nil {
+			slog.Debug("failed to open sysfs node", "file", match, "error", err)
 			continue
 		}
-		if v.Major < 9 {
-			// TODO consider this a build-time setting if we can support 8xx family GPUs
-			slog.Warn(fmt.Sprintf("amdgpu [%d] too old %s", i, v.ToGFXString()))
-			skip[i] = struct{}{}
+		defer fp.Close()
+		nodeID, err := strconv.Atoi(filepath.Base(filepath.Dir(match)))
+		if err != nil {
+			slog.Debug("failed to parse node ID", "error", err)
+			continue
 		}
-	}
-	slog.Info(fmt.Sprintf("detected amdgpu versions %v", verStrings))
 
 
-	// Abort if all GPUs are skipped
-	if len(skip) >= len(gfx) {
-		slog.Info("all detected amdgpus are skipped, falling back to CPU")
-		return
-	}
+		scanner := bufio.NewScanner(fp)
+		isCPU := false
+		var major, minor, patch uint64
+		for scanner.Scan() {
+			line := strings.TrimSpace(scanner.Text())
+			// Note: we could also use "cpu_cores_count X" where X is greater than zero to detect CPUs
+			if strings.HasPrefix(line, "gfx_target_version") {
+				ver := strings.Fields(line)
 
 
-	// If we got this far, then we have at least 1 GPU that's a ROCm candidate, so make sure we have a lib
-	libDir, err := AMDValidateLibDir()
-	if err != nil {
-		slog.Warn(fmt.Sprintf("unable to verify rocm library, will use cpu: %s", err))
-		return
-	}
+				// Detect CPUs
+				if len(ver) == 2 && ver[1] == "0" {
+					slog.Debug("detected CPU " + match)
+					isCPU = true
+					break
+				}
 
 
-	updateLibPath(libDir)
+				if len(ver) != 2 || len(ver[1]) < 5 {
+					slog.Warn("malformed "+match, "gfx_target_version", line)
+					// If this winds up being a CPU, our offsets may be wrong
+					continue
+				}
+				l := len(ver[1])
+				var err1, err2, err3 error
+				patch, err1 = strconv.ParseUint(ver[1][l-2:l], 10, 32)
+				minor, err2 = strconv.ParseUint(ver[1][l-4:l-2], 10, 32)
+				major, err3 = strconv.ParseUint(ver[1][:l-4], 10, 32)
+				if err1 != nil || err2 != nil || err3 != nil {
+					slog.Debug("malformed int " + line)
+					continue
+				}
+			}
 
 
-	gfxOverride := os.Getenv("HSA_OVERRIDE_GFX_VERSION")
-	if gfxOverride == "" {
-		supported, err := GetSupportedGFX(libDir)
-		if err != nil {
-			slog.Warn(fmt.Sprintf("failed to lookup supported GFX types, falling back to CPU mode: %s", err))
-			return
+			// TODO - any other properties we want to extract and record?
+			// vendor_id + device_id -> pci lookup for "Name"
+			// Other metrics that may help us understand relative performance between multiple GPUs
 		}
-		slog.Debug(fmt.Sprintf("rocm supported GPU types %v", supported))
 
 
-		for i, v := range gfx {
-			if !slices.Contains[[]string, string](supported, v.ToGFXString()) {
-				slog.Warn(fmt.Sprintf("amdgpu [%d] %s is not supported by %s %v", i, v.ToGFXString(), libDir, supported))
-				// TODO - consider discrete markdown just for ROCM troubleshooting?
-				slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/gpu.md#overrides for HSA_OVERRIDE_GFX_VERSION usage")
-				skip[i] = struct{}{}
-			} else {
-				slog.Info(fmt.Sprintf("amdgpu [%d] %s is supported", i, v.ToGFXString()))
-			}
+		if isCPU {
+			cpuCount++
+			continue
 		}
-	} else {
-		slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
-	}
 
 
-	if len(skip) >= len(gfx) {
-		slog.Info("all detected amdgpus are skipped, falling back to CPU")
-		return
-	}
+		// CPUs are always first in the list
+		gpuID := nodeID - cpuCount
 
 
-	ids := make([]int, len(gfx))
-	i := 0
-	for k := range gfx {
-		ids[i] = k
-		i++
-	}
-	amdProcMemLookup(resp, skip, ids)
-	if resp.memInfo.DeviceCount == 0 {
-		return
-	}
-	if len(skip) > 0 {
-		amdSetVisibleDevices(ids, skip)
-	}
-}
-
-func updateLibPath(libDir string) {
-	ldPaths := []string{}
-	if val, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
-		ldPaths = strings.Split(val, ":")
-	}
-	for _, d := range ldPaths {
-		if d == libDir {
-			return
-		}
-	}
-	val := strings.Join(append(ldPaths, libDir), ":")
-	slog.Debug("updated lib path", "LD_LIBRARY_PATH", val)
-	os.Setenv("LD_LIBRARY_PATH", val)
-}
-
-// Walk the sysfs nodes for the available GPUs and gather information from them
-// skipping over any devices in the skip map
-func amdProcMemLookup(resp *GpuInfo, skip map[int]interface{}, ids []int) {
-	resp.memInfo.DeviceCount = 0
-	resp.memInfo.TotalMemory = 0
-	resp.memInfo.FreeMemory = 0
-	slog.Debug("discovering VRAM for amdgpu devices")
-	if len(ids) == 0 {
-		entries, err := os.ReadDir(AMDNodesSysfsDir)
-		if err != nil {
-			slog.Warn(fmt.Sprintf("failed to read amdgpu sysfs %s - %s", AMDNodesSysfsDir, err))
-			return
-		}
-		for _, node := range entries {
-			if !node.IsDir() {
-				continue
-			}
-			id, err := strconv.Atoi(node.Name())
-			if err != nil {
-				slog.Warn("malformed amdgpu sysfs node id " + node.Name())
-				continue
-			}
-			ids = append(ids, id)
+		// Shouldn't happen, but just in case...
+		if gpuID < 0 {
+			slog.Error("unexpected amdgpu sysfs data resulted in negative GPU ID, please set OLLAMA_DEBUG=1 and report an issue")
+			return []GpuInfo{}
 		}
-	}
-	slog.Debug(fmt.Sprintf("amdgpu devices %v", ids))
 
 
-	for _, id := range ids {
-		if _, skipped := skip[id]; skipped {
+		if int(major) < RocmComputeMin {
+			slog.Warn(fmt.Sprintf("amdgpu too old gfx%d%d%d", major, minor, patch), "gpu", gpuID)
 			continue
 		}
+
+		// Look up the memory for the current node
 		totalMemory := uint64(0)
 		usedMemory := uint64(0)
-		// Adjust for sysfs vs HIP ids
-		propGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id+1), GPUTotalMemoryFileGlob)
+		propGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(nodeID), GPUTotalMemoryFileGlob)
 		propFiles, err := filepath.Glob(propGlob)
 		if err != nil {
-			slog.Warn(fmt.Sprintf("error looking up total GPU memory: %s %s", propGlob, err))
+			slog.Warn("error looking up total GPU memory", "glob", propGlob, "error", err)
 		}
 		// 1 or more memory banks - sum the values of all of them
 		for _, propFile := range propFiles {
 			fp, err := os.Open(propFile)
 			if err != nil {
-				slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", propFile, err))
+				slog.Warn("failed to open sysfs node", "file", propFile, "error", err)
 				continue
 			}
 			defer fp.Close()
@@ -226,49 +179,113 @@ func amdProcMemLookup(resp *GpuInfo, skip map[int]interface{}, ids []int) {
 			}
 		}
 		if totalMemory == 0 {
-			slog.Warn(fmt.Sprintf("amdgpu [%d] reports zero total memory, skipping", id))
-			skip[id] = struct{}{}
+			slog.Warn("amdgpu reports zero total memory", "gpu", gpuID)
 			continue
 		}
-		if totalMemory < IGPUMemLimit {
-			slog.Info(fmt.Sprintf("amdgpu [%d] appears to be an iGPU with %dM reported total memory, skipping", id, totalMemory/1024/1024))
-			skip[id] = struct{}{}
-			continue
-		}
-		usedGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id), GPUUsedMemoryFileGlob)
+		usedGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(nodeID), GPUUsedMemoryFileGlob)
 		usedFiles, err := filepath.Glob(usedGlob)
 		if err != nil {
-			slog.Warn(fmt.Sprintf("error looking up used GPU memory: %s %s", usedGlob, err))
+			slog.Warn("error looking up used GPU memory", "glob", usedGlob, "error", err)
 			continue
 		}
 		for _, usedFile := range usedFiles {
 			fp, err := os.Open(usedFile)
 			if err != nil {
-				slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", usedFile, err))
+				slog.Warn("failed to open sysfs node", "file", usedFile, "error", err)
 				continue
 			}
 			defer fp.Close()
 			data, err := io.ReadAll(fp)
 			if err != nil {
-				slog.Warn(fmt.Sprintf("failed to read sysfs node file %s: %s", usedFile, err))
+				slog.Warn("failed to read sysfs node", "file", usedFile, "error", err)
 				continue
 			}
 			used, err := strconv.ParseUint(strings.TrimSpace(string(data)), 10, 64)
 			if err != nil {
-				slog.Warn(fmt.Sprintf("malformed used memory %s: %s", string(data), err))
+				slog.Warn("malformed used memory", "data", string(data), "error", err)
 				continue
 			}
 			usedMemory += used
 		}
-		slog.Info(fmt.Sprintf("[%d] amdgpu totalMemory %dM", id, totalMemory/1024/1024))
-		slog.Info(fmt.Sprintf("[%d] amdgpu freeMemory  %dM", id, (totalMemory-usedMemory)/1024/1024))
-		resp.memInfo.DeviceCount++
-		resp.memInfo.TotalMemory += totalMemory
-		resp.memInfo.FreeMemory += (totalMemory - usedMemory)
+
+		// iGPU detection, remove this check once we can support an iGPU variant of the rocm library
+		if totalMemory < IGPUMemLimit {
+			slog.Info("amdgpu appears to be an iGPU, skipping", "gpu", gpuID, "total", format.HumanBytes2(totalMemory))
+			continue
+		}
+
+		slog.Info("amdgpu memory", "gpu", gpuID, "total", format.HumanBytes2(totalMemory))
+		slog.Info("amdgpu memory", "gpu", gpuID, "available", format.HumanBytes2(totalMemory-usedMemory))
+		gpuInfo := GpuInfo{
+			Library: "rocm",
+			memInfo: memInfo{
+				TotalMemory: totalMemory,
+				FreeMemory:  (totalMemory - usedMemory),
+			},
+			ID: fmt.Sprintf("%d", gpuID),
+			// Name: not exposed in sysfs directly, would require pci device id lookup
+			Major:         int(major),
+			Minor:         int(minor),
+			Patch:         int(patch),
+			MinimumMemory: rocmMinimumMemory,
+		}
+
+		// If the user wants to filter to a subset of devices, filter out if we aren't a match
+		if len(visibleDevices) > 0 {
+			include := false
+			for _, visible := range visibleDevices {
+				if visible == gpuInfo.ID {
+					include = true
+					break
+				}
+			}
+			if !include {
+				slog.Info("filtering out device per user request", "id", gpuInfo.ID, "visible_devices", visibleDevices)
+				continue
+			}
+		}
+
+		// Final validation is gfx compatibility - load the library if we haven't already loaded it
+		// even if the user overrides, we still need to validate the library
+		if libDir == "" {
+			libDir, err = AMDValidateLibDir()
+			if err != nil {
+				slog.Warn("unable to verify rocm library, will use cpu", "error", err)
+				return []GpuInfo{}
+			}
+		}
+		gpuInfo.DependencyPath = libDir
+
+		if gfxOverride == "" {
+			// Only load supported list once
+			if len(supported) == 0 {
+				supported, err = GetSupportedGFX(libDir)
+				if err != nil {
+					slog.Warn("failed to lookup supported GFX types, falling back to CPU mode", "error", err)
+					return []GpuInfo{}
+				}
+				slog.Debug("rocm supported GPUs", "types", supported)
+			}
+			gfx := fmt.Sprintf("gfx%d%d%d", gpuInfo.Major, gpuInfo.Minor, gpuInfo.Patch)
+			if !slices.Contains[[]string, string](supported, gfx) {
+				slog.Warn("amdgpu is not supported", "gpu", gpuInfo.ID, "gpu_type", gfx, "library", libDir, "supported_types", supported)
+				// TODO - consider discrete markdown just for ROCM troubleshooting?
+				slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/gpu.md#overrides for HSA_OVERRIDE_GFX_VERSION usage")
+				continue
+			} else {
+				slog.Info("amdgpu is supported", "gpu", gpuInfo.ID, "gpu_type", gfx)
+			}
+		} else {
+			slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
+		}
+
+		// The GPU has passed all the verification steps and is supported
+		resp = append(resp, gpuInfo)
 	}
-	if resp.memInfo.DeviceCount > 0 {
-		resp.Library = "rocm"
+	if len(resp) == 0 {
+		slog.Info("no compatible amdgpu devices detected")
 	}
+	return resp
 }

 // Quick check for AMD driver so we can skip amdgpu discovery if not present
@@ -280,87 +297,24 @@ func AMDDetected() bool {
 		slog.Debug("amdgpu driver not detected " + sysfsDir)
 		return false
 	} else if err != nil {
-		slog.Debug(fmt.Sprintf("error looking up amd driver %s %s", sysfsDir, err))
+		slog.Debug("error looking up amd driver", "path", sysfsDir, "error", err)
 		return false
 	}
 	return true
 }

-func setupLink(source, target string) error {
-	if err := os.RemoveAll(target); err != nil {
-		return fmt.Errorf("failed to remove old rocm directory %s %w", target, err)
-	}
-	if err := os.Symlink(source, target); err != nil {
-		return fmt.Errorf("failed to create link %s => %s %w", source, target, err)
-	}
-	slog.Debug(fmt.Sprintf("host rocm linked %s => %s", source, target))
-	return nil
-}
-
-// Ensure the AMD rocm lib dir is wired up
 // Prefer to use host installed ROCm, as long as it meets our minimum requirements
 // failing that, tell the user how to download it on their own
 func AMDValidateLibDir() (string, error) {
-	// We rely on the rpath compiled into our library to find rocm
-	// so we establish a symlink to wherever we find it on the system
-	// to <payloads>/rocm
-	payloadsDir, err := PayloadsDir()
-	if err != nil {
-		return "", err
-	}
-
-	// If we already have a rocm dependency wired, nothing more to do
-	rocmTargetDir := filepath.Clean(filepath.Join(payloadsDir, "..", "rocm"))
-	if rocmLibUsable(rocmTargetDir) {
-		return rocmTargetDir, nil
-	}
-
-	// next to the running binary
-	exe, err := os.Executable()
+	libDir, err := commonAMDValidateLibDir()
 	if err == nil {
-		peerDir := filepath.Dir(exe)
-		if rocmLibUsable(peerDir) {
-			slog.Debug("detected ROCM next to ollama executable " + peerDir)
-			return rocmTargetDir, setupLink(peerDir, rocmTargetDir)
-		}
-		peerDir = filepath.Join(filepath.Dir(exe), "rocm")
-		if rocmLibUsable(peerDir) {
-			slog.Debug("detected ROCM next to ollama executable " + peerDir)
-			return rocmTargetDir, setupLink(peerDir, rocmTargetDir)
-		}
+		return libDir, nil
 	}

 	// Well known ollama installer path
 	installedRocmDir := "/usr/share/ollama/lib/rocm"
 	if rocmLibUsable(installedRocmDir) {
-		return rocmTargetDir, setupLink(installedRocmDir, rocmTargetDir)
-	}
-
-	// Prefer explicit HIP env var
-	hipPath := os.Getenv("HIP_PATH")
-	if hipPath != "" {
-		hipLibDir := filepath.Join(hipPath, "lib")
-		if rocmLibUsable(hipLibDir) {
-			slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
-			return rocmTargetDir, setupLink(hipLibDir, rocmTargetDir)
-		}
-	}
-
-	// Scan the library path for potential matches
-	ldPaths := strings.Split(os.Getenv("LD_LIBRARY_PATH"), ":")
-	for _, ldPath := range ldPaths {
-		d, err := filepath.Abs(ldPath)
-		if err != nil {
-			continue
-		}
-		if rocmLibUsable(d) {
-			return rocmTargetDir, setupLink(d, rocmTargetDir)
-		}
-	}
-
-	// Well known location(s)
-	if rocmLibUsable("/opt/rocm/lib") {
-		return rocmTargetDir, setupLink("/opt/rocm/lib", rocmTargetDir)
+		return installedRocmDir, nil
 	}

 	// If we still haven't found a usable rocm, the user will have to install it on their own
@@ -384,68 +338,3 @@ func AMDDriverVersion() (string, error) {
 	}
 	return strings.TrimSpace(string(verString)), nil
 }
-
-func AMDGFXVersions() map[int]Version {
-	// The amdgpu driver always exposes the host CPU as node 0, but we have to skip that and subtract one
-	// from the other IDs to get alignment with the HIP libraries expectations (zero is the first GPU, not the CPU)
-	res := map[int]Version{}
-	matches, _ := filepath.Glob(GPUPropertiesFileGlob)
-	for _, match := range matches {
-		fp, err := os.Open(match)
-		if err != nil {
-			slog.Debug(fmt.Sprintf("failed to open sysfs node file %s: %s", match, err))
-			continue
-		}
-		defer fp.Close()
-		i, err := strconv.Atoi(filepath.Base(filepath.Dir(match)))
-		if err != nil {
-			slog.Debug(fmt.Sprintf("failed to parse node ID %s", err))
-			continue
-		}
-
-		if i == 0 {
-			// Skipping the CPU
-			continue
-		}
-		// Align with HIP IDs (zero is first GPU, not CPU)
-		i -= 1
-
-		scanner := bufio.NewScanner(fp)
-		for scanner.Scan() {
-			line := strings.TrimSpace(scanner.Text())
-			if strings.HasPrefix(line, "gfx_target_version") {
-				ver := strings.Fields(line)
-				if len(ver) != 2 || len(ver[1]) < 5 {
-					if ver[1] != "0" {
-						slog.Debug("malformed " + line)
-					}
-					res[i] = Version{
-						Major: 0,
-						Minor: 0,
-						Patch: 0,
-					}
-					continue
-				}
-				l := len(ver[1])
-				patch, err1 := strconv.ParseUint(ver[1][l-2:l], 10, 32)
-				minor, err2 := strconv.ParseUint(ver[1][l-4:l-2], 10, 32)
-				major, err3 := strconv.ParseUint(ver[1][:l-4], 10, 32)
-				if err1 != nil || err2 != nil || err3 != nil {
-					slog.Debug("malformed int " + line)
-					continue
-				}
-
-				res[i] = Version{
-					Major: uint(major),
-					Minor: uint(minor),
-					Patch: uint(patch),
-				}
-			}
-		}
-	}
-	return res
-}
-
-func (v Version) ToGFXString() string {
-	return fmt.Sprintf("gfx%d%d%d", v.Major, v.Minor, v.Patch)
-}

+ 80 - 74
gpu/amd_windows.go

@@ -7,7 +7,10 @@ import (
 	"os"
 	"path/filepath"
 	"slices"
+	"strconv"
 	"strings"
+
+	"github.com/ollama/ollama/format"
 )

 const (
@@ -22,36 +25,32 @@ var (
 	ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // TODO - probably include more coverage of files here...
 )

-func AMDGetGPUInfo(resp *GpuInfo) {
+func AMDGetGPUInfo() []GpuInfo {
+	resp := []GpuInfo{}
 	hl, err := NewHipLib()
 	if err != nil {
 		slog.Debug(err.Error())
-		return
+		return nil
 	}
 	defer hl.Release()
-	skip := map[int]interface{}{}
-	ids := []int{}
-	resp.memInfo.DeviceCount = 0
-	resp.memInfo.TotalMemory = 0
-	resp.memInfo.FreeMemory = 0
 
 
 	ver, err := hl.AMDDriverVersion()
 	if err == nil {
 		slog.Info("AMD Driver: " + ver)
 	} else {
 		// For now this is benign, but we may eventually need to fail compatibility checks
-		slog.Debug(fmt.Sprintf("error looking up amd driver version: %s", err))
+		slog.Debug("error looking up amd driver version", "error", err)
 	}

-	// Note: the HIP library automatically handles HIP_VISIBLE_DEVICES
+	// Note: the HIP library automatically handles subsetting to any HIP_VISIBLE_DEVICES the user specified
 	count := hl.HipGetDeviceCount()
 	if count == 0 {
-		return
+		return nil
 	}
 	libDir, err := AMDValidateLibDir()
 	if err != nil {
-		slog.Warn(fmt.Sprintf("unable to verify rocm library, will use cpu: %s", err))
-		return
+		slog.Warn("unable to verify rocm library, will use cpu", "error", err)
+		return nil
 	}

 	var supported []string
@@ -59,95 +58,117 @@ func AMDGetGPUInfo(resp *GpuInfo) {
 	if gfxOverride == "" {
 		supported, err = GetSupportedGFX(libDir)
 		if err != nil {
-			slog.Warn(fmt.Sprintf("failed to lookup supported GFX types, falling back to CPU mode: %s", err))
-			return
+			slog.Warn("failed to lookup supported GFX types, falling back to CPU mode", "error", err)
+			return nil
 		}
 	} else {
 		slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
 	}

-	slog.Info(fmt.Sprintf("detected %d hip devices", count))
+	slog.Info("detected hip devices", "count", count)
+	// TODO how to determine the underlying device ID when visible devices is causing this to subset?
 	for i := 0; i < count; i++ {
-		ids = append(ids, i)
 		err = hl.HipSetDevice(i)
 		if err != nil {
-			slog.Warn(fmt.Sprintf("[%d] %s", i, err))
-			skip[i] = struct{}{}
+			slog.Warn("set device", "id", i, "error", err)
 			continue
 		}
 
 
 		props, err := hl.HipGetDeviceProperties(i)
 		if err != nil {
-			slog.Warn(fmt.Sprintf("[%d] %s", i, err))
-			skip[i] = struct{}{}
+			slog.Warn("get properties", "id", i, "error", err)
 			continue
 		}
 		n := bytes.IndexByte(props.Name[:], 0)
 		name := string(props.Name[:n])
-		slog.Info(fmt.Sprintf("[%d] Name: %s", i, name))
+		// TODO is UUID actually populated on windows?
+		// Can luid be used on windows for setting visible devices (and is it actually set?)
 		n = bytes.IndexByte(props.GcnArchName[:], 0)
 		gfx := string(props.GcnArchName[:n])
-		slog.Info(fmt.Sprintf("[%d] GcnArchName: %s", i, gfx))
+		slog.Info("hip device", "id", i, "name", name, "gfx", gfx)
+		var major, minor, patch string
+		switch len(gfx) {
+		case 6:
+			major, minor, patch = gfx[3:4], gfx[4:5], gfx[5:]
+		case 7:
+			major, minor, patch = gfx[3:5], gfx[5:6], gfx[6:]
+		}
 		//slog.Info(fmt.Sprintf("[%d] Integrated: %d", i, props.iGPU)) // DOESN'T REPORT CORRECTLY!  Always 0
 		// TODO  Why isn't props.iGPU accurate!?
 		if strings.EqualFold(name, iGPUName) {
-			slog.Info(fmt.Sprintf("iGPU detected [%d] skipping", i))
-			skip[i] = struct{}{}
+			slog.Info("iGPU detected skipping", "id", i)
 			continue
 		}
 		if gfxOverride == "" {
 			if !slices.Contains[[]string, string](supported, gfx) {
-				slog.Warn(fmt.Sprintf("amdgpu [%d] %s is not supported by %s %v", i, gfx, libDir, supported))
+				slog.Warn("amdgpu is not supported", "gpu", i, "gpu_type", gfx, "library", libDir, "supported_types", supported)
 				// TODO - consider discrete markdown just for ROCM troubleshooting?
 				slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage")
-				skip[i] = struct{}{}
 				continue
 			} else {
-				slog.Info(fmt.Sprintf("amdgpu [%d] %s is supported", i, gfx))
+				slog.Info("amdgpu is supported", "gpu", i, "gpu_type", gfx)
 			}
 		}

-		totalMemory, freeMemory, err := hl.HipMemGetInfo()
+		freeMemory, totalMemory, err := hl.HipMemGetInfo()
 		if err != nil {
-			slog.Warn(fmt.Sprintf("[%d] %s", i, err))
+			slog.Warn("get mem info", "id", i, "error", err)
 			continue
 		}

-		// TODO according to docs, freeMem may lie on windows!
-		slog.Info(fmt.Sprintf("[%d] Total Mem: %d", i, totalMemory))
-		slog.Info(fmt.Sprintf("[%d] Free Mem:  %d", i, freeMemory))
-		resp.memInfo.DeviceCount++
-		resp.memInfo.TotalMemory += totalMemory
-		resp.memInfo.FreeMemory += freeMemory
-	}
-	if resp.memInfo.DeviceCount > 0 {
-		resp.Library = "rocm"
-	}
-	// Abort if all GPUs are skipped
-	if len(skip) >= count {
-		slog.Info("all detected amdgpus are skipped, falling back to CPU")
-		return
-	}
-	if len(skip) > 0 {
-		amdSetVisibleDevices(ids, skip)
+		// iGPU detection, remove this check once we can support an iGPU variant of the rocm library
+		if totalMemory < IGPUMemLimit {
+			slog.Info("amdgpu appears to be an iGPU, skipping", "gpu", i, "total", format.HumanBytes2(totalMemory))
+			continue
+		}
+
+		// TODO revisit this once ROCm v6 is available on windows.
+		// v5.7 only reports VRAM used by this process, so it's completely wrong and unusable
+		slog.Info("amdgpu memory", "gpu", i, "total", format.HumanBytes2(totalMemory))
+		slog.Info("amdgpu memory", "gpu", i, "available", format.HumanBytes2(freeMemory))
+		gpuInfo := GpuInfo{
+			Library: "rocm",
+			memInfo: memInfo{
+				TotalMemory: totalMemory,
+				FreeMemory:  freeMemory,
+			},
+			ID:             fmt.Sprintf("%d", i), // TODO this is probably wrong if we specify visible devices
+			DependencyPath: libDir,
+			MinimumMemory:  rocmMinimumMemory,
+		}
+		if major != "" {
+			gpuInfo.Major, err = strconv.Atoi(major)
+			if err != nil {
+				slog.Info("failed to parse version", "version", gfx, "error", err)
+			}
+		}
+		if minor != "" {
+			gpuInfo.Minor, err = strconv.Atoi(minor)
+			if err != nil {
+				slog.Info("failed to parse version", "version", gfx, "error", err)
+			}
+		}
+		if patch != "" {
+			gpuInfo.Patch, err = strconv.Atoi(patch)
+			if err != nil {
+				slog.Info("failed to parse version", "version", gfx, "error", err)
+			}
+		}
+		if gpuInfo.Major < RocmComputeMin {
+			slog.Warn(fmt.Sprintf("amdgpu [%s] too old gfx%d%d%d", gpuInfo.ID, gpuInfo.Major, gpuInfo.Minor, gpuInfo.Patch))
+			continue
+		}
+
+		resp = append(resp, gpuInfo)
 	}
-	UpdatePath(libDir)
+
+	return resp
 }

 func AMDValidateLibDir() (string, error) {
-	// On windows non-admins typically can't create links
-	// so instead of trying to rely on rpath and a link in
-	// $LibDir/rocm, we instead rely on setting PATH to point
-	// to the location of the ROCm library
-
-	// Installer payload location if we're running the installed binary
-	exe, err := os.Executable()
+	libDir, err := commonAMDValidateLibDir()
 	if err == nil {
-		rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
-		if rocmLibUsable(rocmTargetDir) {
-			slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
-			return rocmTargetDir, nil
-		}
+		return libDir, nil
 	}

 	// Installer payload (if we're running from some other location)
@@ -159,21 +180,6 @@ func AMDValidateLibDir() (string, error) {
 		return rocmTargetDir, nil
 	}

-	// Prefer explicit HIP env var
-	hipPath := os.Getenv("HIP_PATH")
-	if hipPath != "" {
-		hipLibDir := filepath.Join(hipPath, "bin")
-		if rocmLibUsable(hipLibDir) {
-			slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
-			return hipLibDir, nil
-		}
-	}
-
-	// Well known location(s)
-	if rocmLibUsable(RocmStandardLocation) {
-		return RocmStandardLocation, nil
-	}
-
 	// Should not happen on windows since we include it in the installer, but stand-alone binary might hit this
 	// Should not happen on windows since we include it in the installer, but stand-alone binary might hit this
 	slog.Warn("amdgpu detected, but no compatible rocm library found.  Please install ROCm")
 	return "", fmt.Errorf("no suitable rocm found, falling back to CPU")

+ 2 - 2
gpu/assets.go

@@ -80,7 +80,7 @@ func cleanupTmpDirs() {
 		}
 		err = os.RemoveAll(d)
 		if err != nil {
-			slog.Debug(fmt.Sprintf("unable to cleanup stale tmpdir %s: %s", d, err))
+			slog.Debug("unable to cleanup stale tmpdir", "path", d, "error", err)
 		}
 	}
 }
@@ -120,7 +120,7 @@ func UpdatePath(dir string) {
 			}
 		}
 		newPath := strings.Join(append([]string{dir}, pathComponents...), ";")
-		slog.Info(fmt.Sprintf("Updating PATH to %s", newPath))
+		slog.Info("updating", "PATH", newPath)
 		os.Setenv("PATH", newPath)
 	}
 	// linux and darwin rely on rpath

+ 22 - 0
gpu/cuda_common.go

@@ -0,0 +1,22 @@
+//go:build linux || windows
+
+package gpu
+
+import (
+	"log/slog"
+	"strings"
+)
+
+func cudaGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
+	ids := []string{}
+	for _, info := range gpuInfo {
+		if info.Library != "cuda" {
+			// TODO shouldn't happen if things are wired correctly...
+			slog.Debug("cudaGetVisibleDevicesEnv skipping over non-cuda device", "library", info.Library)
+			continue
+		}
+		ids = append(ids, info.ID)
+	}
+	return "CUDA_VISIBLE_DEVICES", strings.Join(ids, ",")
+}

+ 82 - 145
gpu/gpu.go

@@ -16,7 +16,6 @@ import (
 	"os"
 	"path/filepath"
 	"runtime"
-	"strconv"
 	"strings"
 	"sync"
 	"unsafe"
@@ -25,8 +24,8 @@ import (
 )

 type handles struct {
-	nvml   *C.nvml_handle_t
-	cudart *C.cudart_handle_t
+	deviceCount int
+	cudart      *C.cudart_handle_t
 }

 const (
@@ -39,26 +38,10 @@ var gpuMutex sync.Mutex
 // With our current CUDA compile flags, older than 5.0 will not work properly
 var CudaComputeMin = [2]C.int{5, 0}

-// Possible locations for the nvidia-ml library
-var NvmlLinuxGlobs = []string{
-	"/usr/local/cuda/lib64/libnvidia-ml.so*",
-	"/usr/lib/x86_64-linux-gnu/nvidia/current/libnvidia-ml.so*",
-	"/usr/lib/x86_64-linux-gnu/libnvidia-ml.so*",
-	"/usr/lib/wsl/lib/libnvidia-ml.so*",
-	"/usr/lib/wsl/drivers/*/libnvidia-ml.so*",
-	"/opt/cuda/lib64/libnvidia-ml.so*",
-	"/usr/lib*/libnvidia-ml.so*",
-	"/usr/lib/aarch64-linux-gnu/nvidia/current/libnvidia-ml.so*",
-	"/usr/lib/aarch64-linux-gnu/libnvidia-ml.so*",
-	"/usr/local/lib*/libnvidia-ml.so*",
-
-	// TODO: are these stubs ever valid?
-	"/opt/cuda/targets/x86_64-linux/lib/stubs/libnvidia-ml.so*",
-}
+var RocmComputeMin = 9
 
 
-var NvmlWindowsGlobs = []string{
-	"c:\\Windows\\System32\\nvml.dll",
-}
+// TODO find a better way to detect iGPU instead of minimum memory
+const IGPUMemLimit = 1 * format.GibiByte // 512G is what they typically report, so anything less than 1G must be iGPU
 
 
 var CudartLinuxGlobs = []string{
 	"/usr/local/cuda/lib64/libcudart.so*",
@@ -88,26 +71,18 @@ func initGPUHandles() *handles {
 
 
 	// TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing

-	gpuHandles := &handles{nil, nil}
-	var nvmlMgmtName string
-	var nvmlMgmtPatterns []string
+	gpuHandles := &handles{}
 	var cudartMgmtName string
 	var cudartMgmtPatterns []string

 	tmpDir, _ := PayloadsDir()
 	switch runtime.GOOS {
 	case "windows":
-		nvmlMgmtName = "nvml.dll"
-		nvmlMgmtPatterns = make([]string, len(NvmlWindowsGlobs))
-		copy(nvmlMgmtPatterns, NvmlWindowsGlobs)
 		cudartMgmtName = "cudart64_*.dll"
 		localAppData := os.Getenv("LOCALAPPDATA")
 		cudartMgmtPatterns = []string{filepath.Join(localAppData, "Programs", "Ollama", cudartMgmtName)}
 		cudartMgmtPatterns = append(cudartMgmtPatterns, CudartWindowsGlobs...)
 	case "linux":
-		nvmlMgmtName = "libnvidia-ml.so"
-		nvmlMgmtPatterns = make([]string, len(NvmlLinuxGlobs))
-		copy(nvmlMgmtPatterns, NvmlLinuxGlobs)
 		cudartMgmtName = "libcudart.so*"
 		if tmpDir != "" {
 			// TODO - add "payloads" for subprocess
@@ -118,31 +93,21 @@ func initGPUHandles() *handles {
 		return gpuHandles
 	}

-	slog.Info("Detecting GPU type")
+	slog.Info("Detecting GPUs")
 	cudartLibPaths := FindGPULibs(cudartMgmtName, cudartMgmtPatterns)
 	if len(cudartLibPaths) > 0 {
-		cudart := LoadCUDARTMgmt(cudartLibPaths)
+		deviceCount, cudart, libPath := LoadCUDARTMgmt(cudartLibPaths)
 		if cudart != nil {
-			slog.Info("Nvidia GPU detected via cudart")
+			slog.Info("detected GPUs", "library", libPath, "count", deviceCount)
 			gpuHandles.cudart = cudart
-			return gpuHandles
-		}
-	}
-
-	// TODO once we build confidence, remove this and the gpu_info_nvml.[ch] files
-	nvmlLibPaths := FindGPULibs(nvmlMgmtName, nvmlMgmtPatterns)
-	if len(nvmlLibPaths) > 0 {
-		nvml := LoadNVMLMgmt(nvmlLibPaths)
-		if nvml != nil {
-			slog.Info("Nvidia GPU detected via nvidia-ml")
-			gpuHandles.nvml = nvml
+			gpuHandles.deviceCount = deviceCount
 			return gpuHandles
 		}
 	}
 	return gpuHandles
 }

-func GetGPUInfo() GpuInfo {
+func GetGPUInfo() GpuInfoList {
 	// TODO - consider exploring lspci (and equivalent on windows) to check for
 	// GPUs so we can report warnings if we see Nvidia/AMD but fail to load the libraries
 	gpuMutex.Lock()
@@ -150,9 +115,6 @@ func GetGPUInfo() GpuInfo {
 
 
 	gpuHandles := initGPUHandles()
 	defer func() {
-		if gpuHandles.nvml != nil {
-			C.nvml_release(*gpuHandles.nvml)
-		}
 		if gpuHandles.cudart != nil {
 			C.cudart_release(*gpuHandles.cudart)
 		}
@@ -165,72 +127,63 @@ func GetGPUInfo() GpuInfo {
 	}

 	var memInfo C.mem_info_t
-	resp := GpuInfo{}
-	if gpuHandles.nvml != nil && (cpuVariant != "" || runtime.GOARCH != "amd64") {
-		C.nvml_check_vram(*gpuHandles.nvml, &memInfo)
-		if memInfo.err != nil {
-			slog.Info(fmt.Sprintf("[nvidia-ml] error looking up NVML GPU memory: %s", C.GoString(memInfo.err)))
-			C.free(unsafe.Pointer(memInfo.err))
-		} else if memInfo.count > 0 {
-			// Verify minimum compute capability
-			var cc C.nvml_compute_capability_t
-			C.nvml_compute_capability(*gpuHandles.nvml, &cc)
-			if cc.err != nil {
-				slog.Info(fmt.Sprintf("[nvidia-ml] error looking up NVML GPU compute capability: %s", C.GoString(cc.err)))
-				C.free(unsafe.Pointer(cc.err))
-			} else if cc.major > CudaComputeMin[0] || (cc.major == CudaComputeMin[0] && cc.minor >= CudaComputeMin[1]) {
-				slog.Info(fmt.Sprintf("[nvidia-ml] NVML CUDA Compute Capability detected: %d.%d", cc.major, cc.minor))
-				resp.Library = "cuda"
-				resp.MinimumMemory = cudaMinimumMemory
-			} else {
-				slog.Info(fmt.Sprintf("[nvidia-ml] CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor))
-			}
+	resp := []GpuInfo{}
+
+	// NVIDIA first
+	for i := 0; i < gpuHandles.deviceCount; i++ {
+		// TODO once we support CPU compilation variants of GPU libraries refine this...
+		if cpuVariant == "" && runtime.GOARCH == "amd64" {
+			continue
 		}
 		}
-		C.cudart_check_vram(*gpuHandles.cudart, &memInfo)
+		gpuInfo := GpuInfo{
+			Library: "cuda",
+		}
+		C.cudart_check_vram(*gpuHandles.cudart, C.int(i), &memInfo)
 		if memInfo.err != nil {
 		if memInfo.err != nil {
+			slog.Info("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
 			C.free(unsafe.Pointer(memInfo.err))
 			C.free(unsafe.Pointer(memInfo.err))
-			// Verify minimum compute capability
-			var cc C.cudart_compute_capability_t
-			C.cudart_compute_capability(*gpuHandles.cudart, &cc)
-			if cc.err != nil {
-				slog.Info(fmt.Sprintf("[cudart] error looking up CUDA compute capability: %s", C.GoString(cc.err)))
-				C.free(unsafe.Pointer(cc.err))
-			} else if cc.major > CudaComputeMin[0] || (cc.major == CudaComputeMin[0] && cc.minor >= CudaComputeMin[1]) {
-				slog.Info(fmt.Sprintf("[cudart] CUDART CUDA Compute Capability detected: %d.%d", cc.major, cc.minor))
-				resp.Library = "cuda"
-				resp.MinimumMemory = cudaMinimumMemory
-			} else {
-				slog.Info(fmt.Sprintf("[cudart] CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor))
-			}
+			continue
 		}
-	} else {
-		AMDGetGPUInfo(&resp)
-		if resp.Library != "" {
-			resp.MinimumMemory = rocmMinimumMemory
-			return resp
+		if memInfo.major < CudaComputeMin[0] || (memInfo.major == CudaComputeMin[0] && memInfo.minor < CudaComputeMin[1]) {
+			slog.Info(fmt.Sprintf("[%d] CUDA GPU is too old. Compute Capability detected: %d.%d", i, memInfo.major, memInfo.minor))
+			continue
 		}
+		gpuInfo.TotalMemory = uint64(memInfo.total)
+		gpuInfo.FreeMemory = uint64(memInfo.free)
+		gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
+		gpuInfo.Major = int(memInfo.major)
+		gpuInfo.Minor = int(memInfo.minor)
+		gpuInfo.MinimumMemory = cudaMinimumMemory
+
+		// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
+		resp = append(resp, gpuInfo)
 	}
-	if resp.Library == "" {
+
+	// Then AMD
+	resp = append(resp, AMDGetGPUInfo()...)
+
+	if len(resp) == 0 {
 		C.cpu_check_ram(&memInfo)
-		resp.Library = "cpu"
-		resp.Variant = cpuVariant
-	}
-	if memInfo.err != nil {
-		slog.Info(fmt.Sprintf("error looking up CPU memory: %s", C.GoString(memInfo.err)))
-		C.free(unsafe.Pointer(memInfo.err))
-		return resp
+		if memInfo.err != nil {
+			slog.Info("error looking up CPU memory", "error", C.GoString(memInfo.err))
+			C.free(unsafe.Pointer(memInfo.err))
+			return resp
+		}
+		gpuInfo := GpuInfo{
+			Library: "cpu",
+			Variant: cpuVariant,
+		}
+		gpuInfo.TotalMemory = uint64(memInfo.total)
+		gpuInfo.FreeMemory = uint64(memInfo.free)
+		gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
+
+		resp = append(resp, gpuInfo)
 	}

-	resp.DeviceCount = uint32(memInfo.count)
-	resp.FreeMemory = uint64(memInfo.free)
-	resp.TotalMemory = uint64(memInfo.total)
 	return resp
 }

-func getCPUMem() (memInfo, error) {
+func GetCPUMem() (memInfo, error) {
 	var ret memInfo
 	var info C.mem_info_t
 	C.cpu_check_ram(&info)
@@ -243,29 +196,11 @@ func getCPUMem() (memInfo, error) {
 	return ret, nil
 }

-func CheckVRAM() (uint64, error) {
-	userLimit := os.Getenv("OLLAMA_MAX_VRAM")
-	if userLimit != "" {
-		avail, err := strconv.ParseInt(userLimit, 10, 64)
-		if err != nil {
-			return 0, fmt.Errorf("Invalid OLLAMA_MAX_VRAM setting %s: %s", userLimit, err)
-		}
-		slog.Info(fmt.Sprintf("user override OLLAMA_MAX_VRAM=%d", avail))
-		return uint64(avail), nil
-	}
-	gpuInfo := GetGPUInfo()
-	if gpuInfo.FreeMemory > 0 && (gpuInfo.Library == "cuda" || gpuInfo.Library == "rocm") {
-		return gpuInfo.FreeMemory, nil
-	}
-
-	return 0, fmt.Errorf("no GPU detected") // TODO - better handling of CPU based memory determiniation
-}
-
 func FindGPULibs(baseLibName string, patterns []string) []string {
 	// Multiple GPU libraries may exist, and some may not work, so keep trying until we exhaust them
 	var ldPaths []string
 	gpuLibPaths := []string{}
-	slog.Info(fmt.Sprintf("Searching for GPU management library %s", baseLibName))
+	slog.Debug("Searching for GPU library", "name", baseLibName)
 
 
 	switch runtime.GOOS {
 	case "windows":
@@ -283,7 +218,7 @@ func FindGPULibs(baseLibName string, patterns []string) []string {
 		}
 		patterns = append(patterns, filepath.Join(d, baseLibName+"*"))
 	}
-	slog.Debug(fmt.Sprintf("gpu management search paths: %v", patterns))
+	slog.Debug("gpu library search", "globs", patterns)
 	for _, pattern := range patterns {
 		// Ignore glob discovery errors
 		matches, _ := filepath.Glob(pattern)
@@ -311,28 +246,11 @@ func FindGPULibs(baseLibName string, patterns []string) []string {
 			}
 		}
 	}
-	slog.Info(fmt.Sprintf("Discovered GPU libraries: %v", gpuLibPaths))
+	slog.Debug("discovered GPU libraries", "paths", gpuLibPaths)
 	return gpuLibPaths
 }

-func LoadNVMLMgmt(nvmlLibPaths []string) *C.nvml_handle_t {
-	var resp C.nvml_init_resp_t
-	resp.ch.verbose = getVerboseState()
-	for _, libPath := range nvmlLibPaths {
-		lib := C.CString(libPath)
-		defer C.free(unsafe.Pointer(lib))
-		C.nvml_init(lib, &resp)
-		if resp.err != nil {
-			slog.Info(fmt.Sprintf("Unable to load NVML management library %s: %s", libPath, C.GoString(resp.err)))
-			C.free(unsafe.Pointer(resp.err))
-		} else {
-			return &resp.ch
-		}
-	}
-	return nil
-}
-
-func LoadCUDARTMgmt(cudartLibPaths []string) *C.cudart_handle_t {
+func LoadCUDARTMgmt(cudartLibPaths []string) (int, *C.cudart_handle_t, string) {
 	var resp C.cudart_init_resp_t
 	var resp C.cudart_init_resp_t
 	resp.ch.verbose = getVerboseState()
 	for _, libPath := range cudartLibPaths {
 		defer C.free(unsafe.Pointer(lib))
 		defer C.free(unsafe.Pointer(lib))
 		C.cudart_init(lib, &resp)
 		if resp.err != nil {
+			slog.Debug("Unable to load cudart", "library", libPath, "error", C.GoString(resp.err))
 			C.free(unsafe.Pointer(resp.err))
 			C.free(unsafe.Pointer(resp.err))
 		} else {
+			return int(resp.num_devices), &resp.ch, libPath
 		}
 		}
 	}
+	return 0, nil, ""
 }
 }

 func getVerboseState() C.uint16_t {
 	}
 	}
 	return C.uint16_t(0)
 }
+// Given the list of GPUs this instantiation is targeted for,
+// figure out the visible devices environment variable
+//
+// If different libraries are detected, the first one is what we use
+func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
+	if len(l) == 0 {
+		return "", ""
+	}
+	switch l[0].Library {
+	case "cuda":
+		return cudaGetVisibleDevicesEnv(l)
+	case "rocm":
+		return rocmGetVisibleDevicesEnv(l)
+	default:
+		slog.Debug("no filter required for library " + l[0].Library)
+		return "", ""
+	}
+}
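
A hypothetical caller-side sketch (not part of this diff) of how the returned env var name/value pair might be applied when spawning a runner subprocess; the runner path is a placeholder and process startup is omitted:

```go
package main

import (
	"os"
	"os/exec"

	"github.com/ollama/ollama/gpu"
)

func main() {
	gpus := gpu.GetGPUInfo() // GpuInfoList of the detected devices

	// Restrict a hypothetical runner subprocess to the selected GPUs.
	cmd := exec.Command("/path/to/runner") // placeholder path, not from this change
	cmd.Env = os.Environ()
	if key, val := gpus.GetVisibleDevicesEnv(); key != "" {
		cmd.Env = append(cmd.Env, key+"="+val)
	}
	_ = cmd // starting the process is left out of this sketch
}
```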

+ 23 - 34
gpu/gpu_darwin.go

@@ -9,52 +9,41 @@ package gpu
 */
 import "C"
 import (
-	"fmt"
-	"log/slog"
-	"os"
 	"runtime"
-	"strconv"
 )

-// CheckVRAM returns the free VRAM in bytes on Linux machines with NVIDIA GPUs
-func CheckVRAM() (uint64, error) {
-	userLimit := os.Getenv("OLLAMA_MAX_VRAM")
-	if userLimit != "" {
-		avail, err := strconv.ParseInt(userLimit, 10, 64)
-		if err != nil {
-			return 0, fmt.Errorf("Invalid OLLAMA_MAX_VRAM setting %s: %s", userLimit, err)
-		}
-		slog.Info(fmt.Sprintf("user override OLLAMA_MAX_VRAM=%d", avail))
-		return uint64(avail), nil
-	}
-
-	if runtime.GOARCH == "amd64" {
-		// gpu not supported, this may not be metal
-		return 0, nil
-	}
-
-	return uint64(C.getRecommendedMaxVRAM()), nil
-}
-
-func GetGPUInfo() GpuInfo {
-	mem, _ := getCPUMem()
+func GetGPUInfo() GpuInfoList {
+	mem, _ := GetCPUMem()
 	if runtime.GOARCH == "amd64" {
-		return GpuInfo{
-			Library: "cpu",
-			Variant: GetCPUVariant(),
-			memInfo: mem,
+		return []GpuInfo{
+			{
+				Library: "cpu",
+				Variant: GetCPUVariant(),
+				memInfo: mem,
+			},
 		}
 	}
-	return GpuInfo{
+	info := GpuInfo{
 		Library: "metal",
-		memInfo: mem,
+		ID:      "0",
 	}
+	info.TotalMemory = uint64(C.getRecommendedMaxVRAM())
+
+	// TODO is there a way to gather actual allocated video memory? (currentAllocatedSize doesn't work)
+	info.FreeMemory = info.TotalMemory
+
+	info.MinimumMemory = 0
+	return []GpuInfo{info}
 }

-func getCPUMem() (memInfo, error) {
+func GetCPUMem() (memInfo, error) {
 	return memInfo{
 		TotalMemory: uint64(C.getPhysicalMemory()),
 		FreeMemory:  0,
-		DeviceCount: 1,
 	}, nil
 }
+
+func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
+	// No-op on darwin
+	return "", ""
+}

+ 8 - 4
gpu/gpu_info.h

@@ -38,12 +38,17 @@
 extern "C" {
 #endif

+#define GPU_ID_LEN 64
+
 typedef struct mem_info {
+  char *err;  // If non-null, caller responsible for freeing
+  char gpu_id[GPU_ID_LEN];
   uint64_t total;
   uint64_t free;
-  unsigned int count;
-  int igpu_index; // If >= 0, we detected an integrated GPU to ignore
-  char *err;  // If non-nill, caller responsible for freeing
+
+  // Compute Capability
+  int major; 
+  int minor;
 } mem_info_t;

 void cpu_check_ram(mem_info_t *resp);
@@ -52,7 +57,6 @@ void cpu_check_ram(mem_info_t *resp);
 }
 #endif

-#include "gpu_info_nvml.h"
 #include "gpu_info_cudart.h"

 #endif  // __GPU_INFO_H__

+ 6 - 2
gpu/gpu_info_cpu.c

@@ -8,9 +8,11 @@ void cpu_check_ram(mem_info_t *resp) {
   MEMORYSTATUSEX info;
   info.dwLength = sizeof(info);
   if (GlobalMemoryStatusEx(&info) != 0) {
-    resp->count = 1;
     resp->total = info.ullTotalPhys;
     resp->total = info.ullTotalPhys;
     resp->free = info.ullAvailPhys;
     resp->free = info.ullAvailPhys;
+    resp->major = 0;
+    resp->minor = 0;
+    snprintf(&resp->gpu_id[0], GPU_ID_LEN, "0");
   } else {
   } else {
     resp->err = LOAD_ERR();
     resp->err = LOAD_ERR();
   }
   }
@@ -27,9 +29,11 @@ void cpu_check_ram(mem_info_t *resp) {
   if (sysinfo(&info) != 0) {
   if (sysinfo(&info) != 0) {
     resp->err = strdup(strerror(errno));
     resp->err = strdup(strerror(errno));
   } else {
   } else {
-    resp->count = 1;
     resp->total = info.totalram * info.mem_unit;
     resp->total = info.totalram * info.mem_unit;
     resp->free = info.freeram * info.mem_unit;
     resp->free = info.freeram * info.mem_unit;
+    resp->major = 0;
+    resp->minor = 0;
+    snprintf(&resp->gpu_id[0], GPU_ID_LEN, "0");
   }
   }
   return;
   return;
 }
 }

+ 63 - 82
gpu/gpu_info_cudart.c

@@ -6,6 +6,7 @@
 void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
 void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
   cudartReturn_t ret;
   cudartReturn_t ret;
   resp->err = NULL;
   resp->err = NULL;
+  resp->num_devices = 0;
   const int buflen = 256;
   const int buflen = 256;
   char buf[buflen + 1];
   char buf[buflen + 1];
   int i;
   int i;
@@ -21,6 +22,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
       {"cudaGetDeviceCount", (void *)&resp->ch.cudaGetDeviceCount},
       {"cudaGetDeviceCount", (void *)&resp->ch.cudaGetDeviceCount},
       {"cudaDeviceGetAttribute", (void *)&resp->ch.cudaDeviceGetAttribute},
       {"cudaDeviceGetAttribute", (void *)&resp->ch.cudaDeviceGetAttribute},
       {"cudaDriverGetVersion", (void *)&resp->ch.cudaDriverGetVersion},
       {"cudaDriverGetVersion", (void *)&resp->ch.cudaDriverGetVersion},
+      {"cudaGetDeviceProperties", (void *)&resp->ch.cudaGetDeviceProperties},
       {NULL, NULL},
       {NULL, NULL},
   };
   };
 
 
@@ -36,13 +38,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
     return;
     return;
   }
   }
 
 
-  // TODO once we've squashed the remaining corner cases remove this log
-  LOG(resp->ch.verbose, "wiring cudart library functions in %s\n", cudart_lib_path);
-  
   for (i = 0; l[i].s != NULL; i++) {
   for (i = 0; l[i].s != NULL; i++) {
-    // TODO once we've squashed the remaining corner cases remove this log
-    LOG(resp->ch.verbose, "dlsym: %s\n", l[i].s);
-
     *l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
     *l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
     if (!l[i].p) {
     if (!l[i].p) {
       char *msg = LOAD_ERR();
       char *msg = LOAD_ERR();
@@ -63,7 +59,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
     UNLOAD_LIBRARY(resp->ch.handle);
     UNLOAD_LIBRARY(resp->ch.handle);
     resp->ch.handle = NULL;
     resp->ch.handle = NULL;
     if (ret == CUDA_ERROR_INSUFFICIENT_DRIVER) {
     if (ret == CUDA_ERROR_INSUFFICIENT_DRIVER) {
-      resp->err = strdup("your nvidia driver is too old or missing, please upgrade to run ollama");
+      resp->err = strdup("your nvidia driver is too old or missing.  If you have a CUDA GPU please upgrade to run ollama");
       return;
       return;
     }
     }
     snprintf(buf, buflen, "cudart init failure: %d", ret);
     snprintf(buf, buflen, "cudart init failure: %d", ret);
@@ -85,110 +81,95 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
     driverVersion.minor = (version - (driverVersion.major * 1000)) / 10;
     driverVersion.minor = (version - (driverVersion.major * 1000)) / 10;
     LOG(resp->ch.verbose, "CUDA driver version: %d-%d\n", driverVersion.major, driverVersion.minor);
     LOG(resp->ch.verbose, "CUDA driver version: %d-%d\n", driverVersion.major, driverVersion.minor);
   }
   }
+
+  ret = (*resp->ch.cudaGetDeviceCount)(&resp->num_devices);
+  if (ret != CUDART_SUCCESS) {
+    LOG(resp->ch.verbose, "cudaGetDeviceCount err: %d\n", ret);
+    UNLOAD_LIBRARY(resp->ch.handle);
+    resp->ch.handle = NULL;
+    snprintf(buf, buflen, "unable to get device count: %d", ret);
+    resp->err = strdup(buf);
+    return;
+  }
 }
 }
 
 
 
 
-void cudart_check_vram(cudart_handle_t h, mem_info_t *resp) {
+void cudart_check_vram(cudart_handle_t h, int i, mem_info_t *resp) {
   resp->err = NULL;
   resp->err = NULL;
   cudartMemory_t memInfo = {0,0,0};
   cudartMemory_t memInfo = {0,0,0};
   cudartReturn_t ret;
   cudartReturn_t ret;
   const int buflen = 256;
   const int buflen = 256;
   char buf[buflen + 1];
   char buf[buflen + 1];
-  int i;
 
 
   if (h.handle == NULL) {
   if (h.handle == NULL) {
     resp->err = strdup("cudart handle isn't initialized");
     resp->err = strdup("cudart handle isn't initialized");
     return;
     return;
   }
   }
 
 
-  // cudaGetDeviceCount takes int type, resp-> count is uint
-  int deviceCount;
-  ret = (*h.cudaGetDeviceCount)(&deviceCount);
+  ret = (*h.cudaSetDevice)(i);
   if (ret != CUDART_SUCCESS) {
   if (ret != CUDART_SUCCESS) {
-    snprintf(buf, buflen, "unable to get device count: %d", ret);
+    snprintf(buf, buflen, "cudart device failed to initialize");
     resp->err = strdup(buf);
     resp->err = strdup(buf);
     return;
     return;
-  } else {
-    resp->count = (unsigned int)deviceCount;
   }
   }
 
 
-  resp->total = 0;
-  resp->free = 0;
-  for (i = 0; i < resp-> count; i++) {  
-    ret = (*h.cudaSetDevice)(i);
-    if (ret != CUDART_SUCCESS) {
-      snprintf(buf, buflen, "cudart device failed to initialize");
-      resp->err = strdup(buf);
-      return;
+  cudaDeviceProp_t props;
+  ret = (*h.cudaGetDeviceProperties)(&props, i);
+  if (ret != CUDART_SUCCESS) {
+    LOG(h.verbose, "[%d] device properties lookup failure: %d\n", i, ret);
+    snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i);
+    resp->major = 0;
+    resp->minor = 0;
+  } else {
+    int allNull = 1;
+    for (int j = 0; j < 16; j++) {
+      if (props.uuid.bytes[j] != 0) {
+        allNull = 0;
+        break;
+      }
     }
     }
-    ret = (*h.cudaMemGetInfo)(&memInfo.free, &memInfo.total);
-    if (ret != CUDART_SUCCESS) {
-      snprintf(buf, buflen, "cudart device memory info lookup failure %d", ret);
-      resp->err = strdup(buf);
-      return;
+    if (allNull != 0) {
+      snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i);
+    } else {
+      // GPU-d110a105-ac29-1d54-7b49-9c90440f215b
+      snprintf(&resp->gpu_id[0], GPU_ID_LEN,
+          "GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
+          props.uuid.bytes[0],
+          props.uuid.bytes[1],
+          props.uuid.bytes[2],
+          props.uuid.bytes[3],
+          props.uuid.bytes[4],
+          props.uuid.bytes[5],
+          props.uuid.bytes[6],
+          props.uuid.bytes[7],
+          props.uuid.bytes[8],
+          props.uuid.bytes[9],
+          props.uuid.bytes[10],
+          props.uuid.bytes[11],
+          props.uuid.bytes[12],
+          props.uuid.bytes[13],
+          props.uuid.bytes[14],
+          props.uuid.bytes[15]
+        );
     }
     }
+    resp->major = props.major;
+    resp->minor = props.minor;
 
 
-    LOG(h.verbose, "[%d] CUDA totalMem %lu\n", i, memInfo.total);
-    LOG(h.verbose, "[%d] CUDA freeMem %lu\n", i, memInfo.free);
-
-    resp->total += memInfo.total;
-    resp->free += memInfo.free;
-  }
-}
-
-void cudart_compute_capability(cudart_handle_t h, cudart_compute_capability_t *resp) {
-  resp->err = NULL;
-  resp->major = 0;
-  resp->minor = 0;
-  int major = 0;
-  int minor = 0;
-  cudartReturn_t ret;
-  const int buflen = 256;
-  char buf[buflen + 1];
-  int i;
-
-  if (h.handle == NULL) {
-    resp->err = strdup("cudart handle not initialized");
-    return;
+    // TODO add other useful properties from props
   }
   }
-
-  int devices;
-  ret = (*h.cudaGetDeviceCount)(&devices);
+  ret = (*h.cudaMemGetInfo)(&memInfo.free, &memInfo.total);
   if (ret != CUDART_SUCCESS) {
   if (ret != CUDART_SUCCESS) {
-    snprintf(buf, buflen, "unable to get cudart device count: %d", ret);
+    snprintf(buf, buflen, "cudart device memory info lookup failure %d", ret);
     resp->err = strdup(buf);
     resp->err = strdup(buf);
     return;
     return;
   }
   }
 
 
-  for (i = 0; i < devices; i++) {
-    ret = (*h.cudaSetDevice)(i);
-    if (ret != CUDART_SUCCESS) {
-      snprintf(buf, buflen, "cudart device failed to initialize");
-      resp->err = strdup(buf);
-      return;
-    }
+  resp->total = memInfo.total;
+  resp->free = memInfo.free;
 
 
-    ret = (*h.cudaDeviceGetAttribute)(&major, cudartDevAttrComputeCapabilityMajor, i);
-    if (ret != CUDART_SUCCESS) {
-      snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret);
-      resp->err = strdup(buf);
-      return;
-    }
-    ret = (*h.cudaDeviceGetAttribute)(&minor, cudartDevAttrComputeCapabilityMinor, i);
-    if (ret != CUDART_SUCCESS) {
-      snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret);
-      resp->err = strdup(buf);
-      return;
-    }
-      
-    // Report the lowest major.minor we detect as that limits our compatibility
-    if (resp->major == 0 || resp->major > major ) {
-      resp->major = major;
-      resp->minor = minor;
-    } else if ( resp->major == major && resp->minor > minor ) {
-      resp->minor = minor;
-    }
-  }
+  LOG(h.verbose, "[%s] CUDA totalMem %lu\n", resp->gpu_id, resp->total);
+  LOG(h.verbose, "[%s] CUDA freeMem %lu\n", resp->gpu_id, resp->free);
+  LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor);
 }
 }
 
 
 void cudart_release(cudart_handle_t h) {
 void cudart_release(cudart_handle_t h) {
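The gpu_id built above mirrors the UUID string NVIDIA tooling reports (the GPU-d110a105-... form shown in the comment). The same 16-byte-to-string conversion expressed as a small self-contained Go program, for reference; the byte values are just the example from the comment:

package main

import "fmt"

func main() {
	uuid := [16]byte{0xd1, 0x10, 0xa1, 0x05, 0xac, 0x29, 0x1d, 0x54, 0x7b, 0x49, 0x9c, 0x90, 0x44, 0x0f, 0x21, 0x5b}
	// 4-2-2-2-6 byte grouping, lowercase hex, "GPU-" prefix
	id := fmt.Sprintf("GPU-%x-%x-%x-%x-%x", uuid[0:4], uuid[4:6], uuid[6:8], uuid[8:10], uuid[10:16])
	fmt.Println(id) // GPU-d110a105-ac29-1d54-7b49-9c90440f215b
}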

+ 96 - 10
gpu/gpu_info_cudart.h

@@ -6,7 +6,8 @@
 // Just enough typedef's to dlopen/dlsym for memory information
 // Just enough typedef's to dlopen/dlsym for memory information
 typedef enum cudartReturn_enum {
 typedef enum cudartReturn_enum {
   CUDART_SUCCESS = 0,
   CUDART_SUCCESS = 0,
-  CUDART_UNSUPPORTED = 1,
+  CUDA_ERROR_INVALID_VALUE = 1,
+  CUDA_ERROR_MEMORY_ALLOCATION = 2,
   CUDA_ERROR_INSUFFICIENT_DRIVER = 35,
   CUDA_ERROR_INSUFFICIENT_DRIVER = 35,
   // Other values omitted for now...
   // Other values omitted for now...
 } cudartReturn_t;
 } cudartReturn_t;
@@ -14,6 +15,11 @@ typedef enum cudartReturn_enum {
 typedef enum cudartDeviceAttr_enum {
 typedef enum cudartDeviceAttr_enum {
   cudartDevAttrComputeCapabilityMajor = 75,
   cudartDevAttrComputeCapabilityMajor = 75,
   cudartDevAttrComputeCapabilityMinor = 76,
   cudartDevAttrComputeCapabilityMinor = 76,
+
+  // TODO - not yet wired up but may be useful for Jetson or other
+  // integrated GPU scenarios with shared memory
+  cudaDevAttrIntegrated = 18
+
 } cudartDeviceAttr_t;
 } cudartDeviceAttr_t;
 
 
 typedef void *cudartDevice_t;  // Opaque is sufficient
 typedef void *cudartDevice_t;  // Opaque is sufficient
@@ -28,6 +34,92 @@ typedef struct cudartDriverVersion {
   int minor;
   int minor;
 } cudartDriverVersion_t;
 } cudartDriverVersion_t;
 
 
+typedef struct cudaUUID {
+    unsigned char bytes[16];
+} cudaUUID_t;
+typedef struct cudaDeviceProp {
+    char         name[256];                  /**< ASCII string identifying device */
+    cudaUUID_t   uuid;                       /**< 16-byte unique identifier */
+    char         luid[8];                    /**< 8-byte locally unique identifier. Value is undefined on TCC and non-Windows platforms */
+    unsigned int luidDeviceNodeMask;         /**< LUID device node mask. Value is undefined on TCC and non-Windows platforms */
+    size_t       totalGlobalMem;             /**< Global memory available on device in bytes */
+    size_t       sharedMemPerBlock;          /**< Shared memory available per block in bytes */
+    int          regsPerBlock;               /**< 32-bit registers available per block */
+    int          warpSize;                   /**< Warp size in threads */
+    size_t       memPitch;                   /**< Maximum pitch in bytes allowed by memory copies */
+    int          maxThreadsPerBlock;         /**< Maximum number of threads per block */
+    int          maxThreadsDim[3];           /**< Maximum size of each dimension of a block */
+    int          maxGridSize[3];             /**< Maximum size of each dimension of a grid */
+    int          clockRate;                  /**< Clock frequency in kilohertz */
+    size_t       totalConstMem;              /**< Constant memory available on device in bytes */
+    int          major;                      /**< Major compute capability */
+    int          minor;                      /**< Minor compute capability */
+    size_t       textureAlignment;           /**< Alignment requirement for textures */
+    size_t       texturePitchAlignment;      /**< Pitch alignment requirement for texture references bound to pitched memory */
+    int          deviceOverlap;              /**< Device can concurrently copy memory and execute a kernel. Deprecated. Use instead asyncEngineCount. */
+    int          multiProcessorCount;        /**< Number of multiprocessors on device */
+    int          kernelExecTimeoutEnabled;   /**< Specified whether there is a run time limit on kernels */
+    int          integrated;                 /**< Device is integrated as opposed to discrete */
+    int          canMapHostMemory;           /**< Device can map host memory with cudaHostAlloc/cudaHostGetDevicePointer */
+    int          computeMode;                /**< Compute mode (See ::cudaComputeMode) */
+    int          maxTexture1D;               /**< Maximum 1D texture size */
+    int          maxTexture1DMipmap;         /**< Maximum 1D mipmapped texture size */
+    int          maxTexture1DLinear;         /**< Deprecated, do not use. Use cudaDeviceGetTexture1DLinearMaxWidth() or cuDeviceGetTexture1DLinearMaxWidth() instead. */
+    int          maxTexture2D[2];            /**< Maximum 2D texture dimensions */
+    int          maxTexture2DMipmap[2];      /**< Maximum 2D mipmapped texture dimensions */
+    int          maxTexture2DLinear[3];      /**< Maximum dimensions (width, height, pitch) for 2D textures bound to pitched memory */
+    int          maxTexture2DGather[2];      /**< Maximum 2D texture dimensions if texture gather operations have to be performed */
+    int          maxTexture3D[3];            /**< Maximum 3D texture dimensions */
+    int          maxTexture3DAlt[3];         /**< Maximum alternate 3D texture dimensions */
+    int          maxTextureCubemap;          /**< Maximum Cubemap texture dimensions */
+    int          maxTexture1DLayered[2];     /**< Maximum 1D layered texture dimensions */
+    int          maxTexture2DLayered[3];     /**< Maximum 2D layered texture dimensions */
+    int          maxTextureCubemapLayered[2];/**< Maximum Cubemap layered texture dimensions */
+    int          maxSurface1D;               /**< Maximum 1D surface size */
+    int          maxSurface2D[2];            /**< Maximum 2D surface dimensions */
+    int          maxSurface3D[3];            /**< Maximum 3D surface dimensions */
+    int          maxSurface1DLayered[2];     /**< Maximum 1D layered surface dimensions */
+    int          maxSurface2DLayered[3];     /**< Maximum 2D layered surface dimensions */
+    int          maxSurfaceCubemap;          /**< Maximum Cubemap surface dimensions */
+    int          maxSurfaceCubemapLayered[2];/**< Maximum Cubemap layered surface dimensions */
+    size_t       surfaceAlignment;           /**< Alignment requirements for surfaces */
+    int          concurrentKernels;          /**< Device can possibly execute multiple kernels concurrently */
+    int          ECCEnabled;                 /**< Device has ECC support enabled */
+    int          pciBusID;                   /**< PCI bus ID of the device */
+    int          pciDeviceID;                /**< PCI device ID of the device */
+    int          pciDomainID;                /**< PCI domain ID of the device */
+    int          tccDriver;                  /**< 1 if device is a Tesla device using TCC driver, 0 otherwise */
+    int          asyncEngineCount;           /**< Number of asynchronous engines */
+    int          unifiedAddressing;          /**< Device shares a unified address space with the host */
+    int          memoryClockRate;            /**< Peak memory clock frequency in kilohertz */
+    int          memoryBusWidth;             /**< Global memory bus width in bits */
+    int          l2CacheSize;                /**< Size of L2 cache in bytes */
+    int          persistingL2CacheMaxSize;   /**< Device's maximum l2 persisting lines capacity setting in bytes */
+    int          maxThreadsPerMultiProcessor;/**< Maximum resident threads per multiprocessor */
+    int          streamPrioritiesSupported;  /**< Device supports stream priorities */
+    int          globalL1CacheSupported;     /**< Device supports caching globals in L1 */
+    int          localL1CacheSupported;      /**< Device supports caching locals in L1 */
+    size_t       sharedMemPerMultiprocessor; /**< Shared memory available per multiprocessor in bytes */
+    int          regsPerMultiprocessor;      /**< 32-bit registers available per multiprocessor */
+    int          managedMemory;              /**< Device supports allocating managed memory on this system */
+    int          isMultiGpuBoard;            /**< Device is on a multi-GPU board */
+    int          multiGpuBoardGroupID;       /**< Unique identifier for a group of devices on the same multi-GPU board */
+    int          hostNativeAtomicSupported;  /**< Link between the device and the host supports native atomic operations */
+    int          singleToDoublePrecisionPerfRatio; /**< Ratio of single precision performance (in floating-point operations per second) to double precision performance */
+    int          pageableMemoryAccess;       /**< Device supports coherently accessing pageable memory without calling cudaHostRegister on it */
+    int          concurrentManagedAccess;    /**< Device can coherently access managed memory concurrently with the CPU */
+    int          computePreemptionSupported; /**< Device supports Compute Preemption */
+    int          canUseHostPointerForRegisteredMem; /**< Device can access host registered memory at the same virtual address as the CPU */
+    int          cooperativeLaunch;          /**< Device supports launching cooperative kernels via ::cudaLaunchCooperativeKernel */
+    int          cooperativeMultiDeviceLaunch; /**< Deprecated, cudaLaunchCooperativeKernelMultiDevice is deprecated. */
+    size_t       sharedMemPerBlockOptin;     /**< Per device maximum shared memory per block usable by special opt in */
+    int          pageableMemoryAccessUsesHostPageTables; /**< Device accesses pageable memory via the host's page tables */
+    int          directManagedMemAccessFromHost; /**< Host can directly access managed memory on the device without migration. */
+    int          maxBlocksPerMultiProcessor; /**< Maximum number of resident blocks per multiprocessor */
+    int          accessPolicyMaxWindowSize;  /**< The maximum value of ::cudaAccessPolicyWindow::num_bytes. */
+    size_t       reservedSharedMemPerBlock;  /**< Shared memory reserved by CUDA driver per block in bytes */
+  } cudaDeviceProp_t;
+
 typedef struct cudart_handle {
 typedef struct cudart_handle {
   void *handle;
   void *handle;
   uint16_t verbose;
   uint16_t verbose;
@@ -38,23 +130,17 @@ typedef struct cudart_handle {
   cudartReturn_t (*cudaGetDeviceCount)(int *);
   cudartReturn_t (*cudaGetDeviceCount)(int *);
   cudartReturn_t (*cudaDeviceGetAttribute)(int* value, cudartDeviceAttr_t attr, int device);
   cudartReturn_t (*cudaDeviceGetAttribute)(int* value, cudartDeviceAttr_t attr, int device);
   cudartReturn_t (*cudaDriverGetVersion) (int *driverVersion);
   cudartReturn_t (*cudaDriverGetVersion) (int *driverVersion);
+  cudartReturn_t (*cudaGetDeviceProperties) (cudaDeviceProp_t* prop, int device);
 } cudart_handle_t;
 } cudart_handle_t;
 
 
 typedef struct cudart_init_resp {
 typedef struct cudart_init_resp {
   char *err;  // If err is non-null handle is invalid
   char *err;  // If err is non-null handle is invalid
   cudart_handle_t ch;
   cudart_handle_t ch;
+  int num_devices;
 } cudart_init_resp_t;
 } cudart_init_resp_t;
 
 
-typedef struct cudart_compute_capability {
-  char *err;
-  int major;
-  int minor;
-} cudart_compute_capability_t;
-
-
 void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp);
 void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp);
-void cudart_check_vram(cudart_handle_t ch, mem_info_t *resp);
-void cudart_compute_capability(cudart_handle_t th, cudart_compute_capability_t *cc);
+void cudart_check_vram(cudart_handle_t ch, int device_id, mem_info_t *resp);
 void cudart_release(cudart_handle_t ch);
 void cudart_release(cudart_handle_t ch);
 
 
 #endif  // __GPU_INFO_CUDART_H__
 #endif  // __GPU_INFO_CUDART_H__

+ 0 - 221
gpu/gpu_info_nvml.c

@@ -1,221 +0,0 @@
-#ifndef __APPLE__  // TODO - maybe consider nvidia support on intel macs?
-
-#include <string.h>
-
-#include "gpu_info_nvml.h"
-
-void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp) {
-  nvmlReturn_t ret;
-  resp->err = NULL;
-  const int buflen = 256;
-  char buf[buflen + 1];
-  int i;
-
-  struct lookup {
-    char *s;
-    void **p;
-  } l[] = {
-      {"nvmlInit_v2", (void *)&resp->ch.nvmlInit_v2},
-      {"nvmlShutdown", (void *)&resp->ch.nvmlShutdown},
-      {"nvmlDeviceGetHandleByIndex", (void *)&resp->ch.nvmlDeviceGetHandleByIndex},
-      {"nvmlDeviceGetMemoryInfo", (void *)&resp->ch.nvmlDeviceGetMemoryInfo},
-      {"nvmlDeviceGetCount_v2", (void *)&resp->ch.nvmlDeviceGetCount_v2},
-      {"nvmlDeviceGetCudaComputeCapability", (void *)&resp->ch.nvmlDeviceGetCudaComputeCapability},
-      {"nvmlSystemGetDriverVersion", (void *)&resp->ch.nvmlSystemGetDriverVersion},
-      {"nvmlDeviceGetName", (void *)&resp->ch.nvmlDeviceGetName},
-      {"nvmlDeviceGetSerial", (void *)&resp->ch.nvmlDeviceGetSerial},
-      {"nvmlDeviceGetVbiosVersion", (void *)&resp->ch.nvmlDeviceGetVbiosVersion},
-      {"nvmlDeviceGetBoardPartNumber", (void *)&resp->ch.nvmlDeviceGetBoardPartNumber},
-      {"nvmlDeviceGetBrand", (void *)&resp->ch.nvmlDeviceGetBrand},
-      {NULL, NULL},
-  };
-
-  resp->ch.handle = LOAD_LIBRARY(nvml_lib_path, RTLD_LAZY);
-  if (!resp->ch.handle) {
-    char *msg = LOAD_ERR();
-    LOG(resp->ch.verbose, "library %s load err: %s\n", nvml_lib_path, msg);
-    snprintf(buf, buflen,
-             "Unable to load %s library to query for Nvidia GPUs: %s",
-             nvml_lib_path, msg);
-    free(msg);
-    resp->err = strdup(buf);
-    return;
-  }
-
-  // TODO once we've squashed the remaining corner cases remove this log
-  LOG(resp->ch.verbose, "wiring nvidia management library functions in %s\n", nvml_lib_path);
-  
-  for (i = 0; l[i].s != NULL; i++) {
-    // TODO once we've squashed the remaining corner cases remove this log
-    LOG(resp->ch.verbose, "dlsym: %s\n", l[i].s);
-
-    *l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
-    if (!l[i].p) {
-      resp->ch.handle = NULL;
-      char *msg = LOAD_ERR();
-      LOG(resp->ch.verbose, "dlerr: %s\n", msg);
-      UNLOAD_LIBRARY(resp->ch.handle);
-      snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s,
-               msg);
-      free(msg);
-      resp->err = strdup(buf);
-      return;
-    }
-  }
-
-  ret = (*resp->ch.nvmlInit_v2)();
-  if (ret != NVML_SUCCESS) {
-    LOG(resp->ch.verbose, "nvmlInit_v2 err: %d\n", ret);
-    UNLOAD_LIBRARY(resp->ch.handle);
-    resp->ch.handle = NULL;
-    snprintf(buf, buflen, "nvml vram init failure: %d", ret);
-    resp->err = strdup(buf);
-    return;
-  }
-
-  // Report driver version if we're in verbose mode, ignore errors
-  ret = (*resp->ch.nvmlSystemGetDriverVersion)(buf, buflen);
-  if (ret != NVML_SUCCESS) {
-    LOG(resp->ch.verbose, "nvmlSystemGetDriverVersion failed: %d\n", ret);
-  } else {
-    LOG(resp->ch.verbose, "CUDA driver version: %s\n", buf);
-  }
-}
-
-void nvml_check_vram(nvml_handle_t h, mem_info_t *resp) {
-  resp->err = NULL;
-  nvmlDevice_t device;
-  nvmlMemory_t memInfo = {0};
-  nvmlReturn_t ret;
-  const int buflen = 256;
-  char buf[buflen + 1];
-  int i;
-
-  if (h.handle == NULL) {
-    resp->err = strdup("nvml handle isn't initialized");
-    return;
-  }
-
-  ret = (*h.nvmlDeviceGetCount_v2)(&resp->count);
-  if (ret != NVML_SUCCESS) {
-    snprintf(buf, buflen, "unable to get device count: %d", ret);
-    resp->err = strdup(buf);
-    return;
-  }
-
-  resp->total = 0;
-  resp->free = 0;
-  for (i = 0; i < resp->count; i++) {
-    ret = (*h.nvmlDeviceGetHandleByIndex)(i, &device);
-    if (ret != NVML_SUCCESS) {
-      snprintf(buf, buflen, "unable to get device handle %d: %d", i, ret);
-      resp->err = strdup(buf);
-      return;
-    }
-
-    ret = (*h.nvmlDeviceGetMemoryInfo)(device, &memInfo);
-    if (ret != NVML_SUCCESS) {
-      snprintf(buf, buflen, "device memory info lookup failure %d: %d", i, ret);
-      resp->err = strdup(buf);
-      return;
-    }
-    if (h.verbose) {
-      nvmlBrandType_t brand = 0;
-      // When in verbose mode, report more information about
-      // the card we discover, but don't fail on error
-      ret = (*h.nvmlDeviceGetName)(device, buf, buflen);
-      if (ret != NVML_SUCCESS) {
-        LOG(h.verbose, "nvmlDeviceGetName failed: %d\n", ret);
-      } else {
-        LOG(h.verbose, "[%d] CUDA device name: %s\n", i, buf);
-      }
-      ret = (*h.nvmlDeviceGetBoardPartNumber)(device, buf, buflen);
-      if (ret != NVML_SUCCESS) {
-        LOG(h.verbose, "nvmlDeviceGetBoardPartNumber failed: %d\n", ret);
-      } else {
-        LOG(h.verbose, "[%d] CUDA part number: %s\n", i, buf);
-      }
-      ret = (*h.nvmlDeviceGetSerial)(device, buf, buflen);
-      if (ret != NVML_SUCCESS) {
-        LOG(h.verbose, "nvmlDeviceGetSerial failed: %d\n", ret);
-      } else {
-        LOG(h.verbose, "[%d] CUDA S/N: %s\n", i, buf);
-      }
-      ret = (*h.nvmlDeviceGetVbiosVersion)(device, buf, buflen);
-      if (ret != NVML_SUCCESS) {
-        LOG(h.verbose, "nvmlDeviceGetVbiosVersion failed: %d\n", ret);
-      } else {
-        LOG(h.verbose, "[%d] CUDA vbios version: %s\n", i, buf);
-      }
-      ret = (*h.nvmlDeviceGetBrand)(device, &brand);
-      if (ret != NVML_SUCCESS) {
-        LOG(h.verbose, "nvmlDeviceGetBrand failed: %d\n", ret);
-      } else {
-        LOG(h.verbose, "[%d] CUDA brand: %d\n", i, brand);
-      }
-    }
-
-    LOG(h.verbose, "[%d] CUDA totalMem %ld\n", i, memInfo.total);
-    LOG(h.verbose, "[%d] CUDA freeMem %ld\n", i, memInfo.free);
-
-    resp->total += memInfo.total;
-    resp->free += memInfo.free;
-  }
-}
-
-void nvml_compute_capability(nvml_handle_t h, nvml_compute_capability_t *resp) {
-  resp->err = NULL;
-  resp->major = 0;
-  resp->minor = 0;
-  nvmlDevice_t device;
-  int major = 0;
-  int minor = 0;
-  nvmlReturn_t ret;
-  const int buflen = 256;
-  char buf[buflen + 1];
-  int i;
-
-  if (h.handle == NULL) {
-    resp->err = strdup("nvml handle not initialized");
-    return;
-  }
-
-  unsigned int devices;
-  ret = (*h.nvmlDeviceGetCount_v2)(&devices);
-  if (ret != NVML_SUCCESS) {
-    snprintf(buf, buflen, "unable to get device count: %d", ret);
-    resp->err = strdup(buf);
-    return;
-  }
-
-  for (i = 0; i < devices; i++) {
-    ret = (*h.nvmlDeviceGetHandleByIndex)(i, &device);
-    if (ret != NVML_SUCCESS) {
-      snprintf(buf, buflen, "unable to get device handle %d: %d", i, ret);
-      resp->err = strdup(buf);
-      return;
-    }
-
-    ret = (*h.nvmlDeviceGetCudaComputeCapability)(device, &major, &minor);
-    if (ret != NVML_SUCCESS) {
-      snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret);
-      resp->err = strdup(buf);
-      return;
-    }
-    // Report the lowest major.minor we detect as that limits our compatibility
-    if (resp->major == 0 || resp->major > major ) {
-      resp->major = major;
-      resp->minor = minor;
-    } else if ( resp->major == major && resp->minor > minor ) {
-      resp->minor = minor;
-    }
-  }
-}
-
-void nvml_release(nvml_handle_t h) {
-  LOG(h.verbose, "releasing nvml library\n");
-  UNLOAD_LIBRARY(h.handle);
-  h.handle = NULL;
-}
-
-#endif  // __APPLE__

+ 0 - 57
gpu/gpu_info_nvml.h

@@ -1,57 +0,0 @@
-#ifndef __APPLE__
-#ifndef __GPU_INFO_NVML_H__
-#define __GPU_INFO_NVML_H__
-#include "gpu_info.h"
-
-// Just enough typedef's to dlopen/dlsym for memory information
-typedef enum nvmlReturn_enum {
-  NVML_SUCCESS = 0,
-  // Other values omitted for now...
-} nvmlReturn_t;
-typedef void *nvmlDevice_t;  // Opaque is sufficient
-typedef struct nvmlMemory_st {
-  unsigned long long total;
-  unsigned long long free;
-  unsigned long long used;
-} nvmlMemory_t;
-
-typedef enum nvmlBrandType_enum
-{
-    NVML_BRAND_UNKNOWN          = 0,
-} nvmlBrandType_t;
-
-typedef struct nvml_handle {
-  void *handle;
-  uint16_t verbose;
-  nvmlReturn_t (*nvmlInit_v2)(void);
-  nvmlReturn_t (*nvmlShutdown)(void);
-  nvmlReturn_t (*nvmlDeviceGetHandleByIndex)(unsigned int, nvmlDevice_t *);
-  nvmlReturn_t (*nvmlDeviceGetMemoryInfo)(nvmlDevice_t, nvmlMemory_t *);
-  nvmlReturn_t (*nvmlDeviceGetCount_v2)(unsigned int *);
-  nvmlReturn_t (*nvmlDeviceGetCudaComputeCapability)(nvmlDevice_t, int* major, int* minor);
-  nvmlReturn_t (*nvmlSystemGetDriverVersion) (char* version, unsigned int  length);
-  nvmlReturn_t (*nvmlDeviceGetName) (nvmlDevice_t device, char* name, unsigned int  length);
-  nvmlReturn_t (*nvmlDeviceGetSerial) (nvmlDevice_t device, char* serial, unsigned int  length);
-  nvmlReturn_t (*nvmlDeviceGetVbiosVersion) (nvmlDevice_t device, char* version, unsigned int  length);
-  nvmlReturn_t (*nvmlDeviceGetBoardPartNumber) (nvmlDevice_t device, char* partNumber, unsigned int  length);
-  nvmlReturn_t (*nvmlDeviceGetBrand) (nvmlDevice_t device, nvmlBrandType_t* type);
-} nvml_handle_t;
-
-typedef struct nvml_init_resp {
-  char *err;  // If err is non-null handle is invalid
-  nvml_handle_t ch;
-} nvml_init_resp_t;
-
-typedef struct nvml_compute_capability {
-  char *err;
-  int major;
-  int minor;
-} nvml_compute_capability_t;
-
-void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp);
-void nvml_check_vram(nvml_handle_t ch, mem_info_t *resp);
-void nvml_compute_capability(nvml_handle_t ch, nvml_compute_capability_t *cc);
-void nvml_release(nvml_handle_t ch);
-
-#endif  // __GPU_INFO_NVML_H__
-#endif  // __APPLE__

+ 6 - 13
gpu/gpu_test.go

@@ -9,23 +9,16 @@ import (
 
 
 func TestBasicGetGPUInfo(t *testing.T) {
 func TestBasicGetGPUInfo(t *testing.T) {
 	info := GetGPUInfo()
 	info := GetGPUInfo()
-	assert.Contains(t, "cuda rocm cpu metal", info.Library)
-
-	switch runtime.GOOS {
-	case "darwin":
-		// TODO - remove this once MacOS returns some size for CPU
-		return
-	case "linux", "windows":
-		assert.Greater(t, info.TotalMemory, uint64(0))
-		assert.Greater(t, info.FreeMemory, uint64(0))
-		assert.Greater(t, info.DeviceCount, uint32(0))
-	default:
-		return
+	assert.Greater(t, len(info), 0)
+	assert.Contains(t, "cuda rocm cpu metal", info[0].Library)
+	if info[0].Library != "cpu" {
+		assert.Greater(t, info[0].TotalMemory, uint64(0))
+		assert.Greater(t, info[0].FreeMemory, uint64(0))
 	}
 	}
 }
 }
 
 
 func TestCPUMemInfo(t *testing.T) {
 func TestCPUMemInfo(t *testing.T) {
-	info, err := getCPUMem()
+	info, err := GetCPUMem()
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 	switch runtime.GOOS {
 	switch runtime.GOOS {
 	case "darwin":
 	case "darwin":

+ 43 - 6
gpu/types.go

@@ -3,7 +3,6 @@ package gpu
 type memInfo struct {
 type memInfo struct {
 	TotalMemory uint64 `json:"total_memory,omitempty"`
 	TotalMemory uint64 `json:"total_memory,omitempty"`
 	FreeMemory  uint64 `json:"free_memory,omitempty"`
 	FreeMemory  uint64 `json:"free_memory,omitempty"`
-	DeviceCount uint32 `json:"device_count,omitempty"`
 }
 }
 
 
 // Beginning of an `ollama info` command
 // Beginning of an `ollama info` command
@@ -17,11 +16,49 @@ type GpuInfo struct {
 	// MinimumMemory represents the minimum memory required to use the GPU
 	// MinimumMemory represents the minimum memory required to use the GPU
 	MinimumMemory uint64 `json:"-"`
 	MinimumMemory uint64 `json:"-"`
 
 
-	// TODO add other useful attributes about the card here for discovery information
+	// Any extra PATH/LD_LIBRARY_PATH dependencies required for the Library to operate properly
+	DependencyPath string `json:"lib_path,omitempty"`
+
+	// GPU information
+	ID    string `json:"gpu_id"`          // string to use for selection of this specific GPU
+	Name  string `json:"name"`            // user friendly name if available
+	Major int    `json:"major,omitempty"` // Major compatibility version (CC or gfx)
+	Minor int    `json:"minor,omitempty"` // Minor compatibility version (CC or gfx)
+	Patch int    `json:"patch,omitempty"` // Patch compatibility only matters on AMD
+
+	// TODO other performance capability info to help in scheduling decisions
 }
 }
 
 
-type Version struct {
-	Major uint
-	Minor uint
-	Patch uint
+type GpuInfoList []GpuInfo
+
+// Split up the set of GPU infos by Library and variant
+func (l GpuInfoList) ByLibrary() []GpuInfoList {
+	resp := []GpuInfoList{}
+	libs := []string{}
+	for _, info := range l {
+		found := false
+		requested := info.Library
+		if info.Variant != "" {
+			requested += "_" + info.Variant
+		}
+		for i, lib := range libs {
+			if lib == requested {
+				resp[i] = append(resp[i], info)
+				found = true
+				break
+			}
+		}
+		if !found {
+			libs = append(libs, info.Library)
+			resp = append(resp, []GpuInfo{info})
+		}
+	}
+	return resp
 }
 }
+
+// Sort by Free Space
+type ByFreeMemory []GpuInfo
+
+func (a ByFreeMemory) Len() int           { return len(a) }
+func (a ByFreeMemory) Swap(i, j int)      { a[i], a[j] = a[j], a[i] }
+func (a ByFreeMemory) Less(i, j int) bool { return a[i].FreeMemory < a[j].FreeMemory }
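Taken together, ByLibrary and ByFreeMemory support a simple scheduling pass: group devices by backend, then prefer the device with the most free memory within each group. A minimal usage sketch (the selection policy shown is illustrative, not the server's actual scheduler):

package main

import (
	"fmt"
	"sort"

	"github.com/ollama/ollama/gpu"
)

func main() {
	all := gpu.GetGPUInfo()
	for _, sameLib := range all.ByLibrary() {
		// Most free memory first within each backend
		sort.Sort(sort.Reverse(gpu.ByFreeMemory(sameLib)))
		best := sameLib[0]
		fmt.Println(best.Library, best.ID, best.FreeMemory)
	}
}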

+ 1 - 2
integration/basic_test.go

@@ -4,7 +4,6 @@ package integration
 
 
 import (
 import (
 	"context"
 	"context"
-	"net/http"
 	"testing"
 	"testing"
 	"time"
 	"time"
 
 
@@ -24,5 +23,5 @@ func TestOrcaMiniBlueSky(t *testing.T) {
 			"seed":        123,
 			"seed":        123,
 		},
 		},
 	}
 	}
-	GenerateTestHelper(ctx, t, &http.Client{}, req, []string{"rayleigh", "scattering"})
+	GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"})
 }
 }

+ 225 - 0
integration/concurrency_test.go

@@ -0,0 +1,225 @@
+//go:build integration
+
+package integration
+
+import (
+	"context"
+	"log/slog"
+	"os"
+	"strconv"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/ollama/ollama/api"
+	"github.com/stretchr/testify/require"
+)
+
+func TestMultiModelConcurrency(t *testing.T) {
+	var (
+		req = [2]api.GenerateRequest{
+			{
+				Model:  "orca-mini",
+				Prompt: "why is the ocean blue?",
+				Stream: &stream,
+				Options: map[string]interface{}{
+					"seed":        42,
+					"temperature": 0.0,
+				},
+			}, {
+				Model:  "tinydolphin",
+				Prompt: "what is the origin of the us thanksgiving holiday?",
+				Stream: &stream,
+				Options: map[string]interface{}{
+					"seed":        42,
+					"temperature": 0.0,
+				},
+			},
+		}
+		resp = [2][]string{
+			[]string{"sunlight"},
+			[]string{"england", "english", "massachusetts", "pilgrims"},
+		}
+	)
+	var wg sync.WaitGroup
+	wg.Add(len(req))
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
+	defer cancel()
+	for i := 0; i < len(req); i++ {
+		go func(i int) {
+			defer wg.Done()
+			GenerateTestHelper(ctx, t, req[i], resp[i])
+		}(i)
+	}
+	wg.Wait()
+}
+
+func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) // GTX 750 2G card takes ~9 minutes
+	defer cancel()
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
+
+	req, resp := GenerateRequests()
+	// Get the server running (if applicable) and warm the model up with a single initial request
+	DoGenerate(ctx, t, client, req[0], resp[0], 60*time.Second, 5*time.Second)
+
+	var wg sync.WaitGroup
+	wg.Add(len(req))
+	for i := 0; i < len(req); i++ {
+		go func(i int) {
+			defer wg.Done()
+			for j := 0; j < 5; j++ {
+				slog.Info("Starting", "req", i, "iter", j)
+				// On slower GPUs it can take a while to process the 4 concurrent requests
+				// so we allow a much longer initial timeout
+				DoGenerate(ctx, t, client, req[i], resp[i], 90*time.Second, 5*time.Second)
+			}
+		}(i)
+	}
+	wg.Wait()
+}
+
+// Stress the system if we know how much VRAM it has, and attempt to load more models than will fit
+func TestMultiModelStress(t *testing.T) {
+	vram := os.Getenv("OLLAMA_MAX_VRAM")
+	if vram == "" {
+		t.Skip("OLLAMA_MAX_VRAM not specified, can't pick the right models for the stress test")
+	}
+	max, err := strconv.ParseUint(vram, 10, 64)
+	require.NoError(t, err)
+	const MB = uint64(1024 * 1024)
+	type model struct {
+		name string
+		size uint64 // Approximate amount of VRAM they typically use when fully loaded in VRAM
+	}
+
+	smallModels := []model{
+		{
+			name: "orca-mini",
+			size: 2992 * MB,
+		},
+		{
+			name: "phi",
+			size: 2616 * MB,
+		},
+		{
+			name: "gemma:2b",
+			size: 2364 * MB,
+		},
+		{
+			name: "stable-code:3b",
+			size: 2608 * MB,
+		},
+		{
+			name: "starcoder2:3b",
+			size: 2166 * MB,
+		},
+	}
+	mediumModels := []model{
+		{
+			name: "llama2",
+			size: 5118 * MB,
+		},
+		{
+			name: "mistral",
+			size: 4620 * MB,
+		},
+		{
+			name: "orca-mini:7b",
+			size: 5118 * MB,
+		},
+		{
+			name: "dolphin-mistral",
+			size: 4620 * MB,
+		},
+		{
+			name: "gemma:7b",
+			size: 5000 * MB,
+		},
+		// TODO - uncomment this once #3565 is merged and this is rebased on it
+		// {
+		// 	name: "codellama:7b",
+		// 	size: 5118 * MB,
+		// },
+	}
+
+	// These seem to be too slow to be useful...
+	// largeModels := []model{
+	// 	{
+	// 		name: "llama2:13b",
+	// 		size: 7400 * MB,
+	// 	},
+	// 	{
+	// 		name: "codellama:13b",
+	// 		size: 7400 * MB,
+	// 	},
+	// 	{
+	// 		name: "orca-mini:13b",
+	// 		size: 7400 * MB,
+	// 	},
+	// 	{
+	// 		name: "gemma:7b",
+	// 		size: 5000 * MB,
+	// 	},
+	// 	{
+	// 		name: "starcoder2:15b",
+	// 		size: 9100 * MB,
+	// 	},
+	// }
+
+	var chosenModels []model
+	switch {
+	case max < 10000*MB:
+		slog.Info("selecting small models")
+		chosenModels = smallModels
+	// case max < 30000*MB:
+	default:
+		slog.Info("selecting medium models")
+		chosenModels = mediumModels
+		// default:
+		// 	slog.Info("selecting large models")
+		// 	chosenModels = largModels
+	}
+
+	req, resp := GenerateRequests()
+
+	for i := range req {
+		if i > len(chosenModels) {
+			break
+		}
+		req[i].Model = chosenModels[i].name
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) // TODO baseline -- 10m too short
+	defer cancel()
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
+
+	// Make sure all the models are pulled before we get started
+	for _, r := range req {
+		require.NoError(t, PullIfMissing(ctx, client, r.Model))
+	}
+
+	var wg sync.WaitGroup
+	consumed := uint64(256 * MB) // Assume some baseline usage
+	for i := 0; i < len(req); i++ {
+		// Always get at least 2 models, but dont' overshoot VRAM too much or we'll take too long
+		if i > 1 && consumed > max {
+			slog.Info("achieved target vram exhaustion", "count", i, "vramMB", max/1024/1024, "modelsMB", consumed/1024/1024)
+			break
+		}
+		consumed += chosenModels[i].size
+		slog.Info("target vram", "count", i, "vramMB", max/1024/1024, "modelsMB", consumed/1024/1024)
+
+		wg.Add(1)
+		go func(i int) {
+			defer wg.Done()
+			for j := 0; j < 3; j++ {
+				slog.Info("Starting", "req", i, "iter", j, "model", req[i].Model)
+				DoGenerate(ctx, t, client, req[i], resp[i], 90*time.Second, 5*time.Second)
+			}
+		}(i)
+	}
+	wg.Wait()
+}

+ 1 - 2
integration/context_test.go

@@ -4,7 +4,6 @@ package integration
 
 
 import (
 import (
 	"context"
 	"context"
-	"net/http"
 	"testing"
 	"testing"
 	"time"
 	"time"
 
 
@@ -25,5 +24,5 @@ func TestContextExhaustion(t *testing.T) {
 			"num_ctx":     128,
 			"num_ctx":     128,
 		},
 		},
 	}
 	}
-	GenerateTestHelper(ctx, t, &http.Client{}, req, []string{"once", "upon", "lived"})
+	GenerateTestHelper(ctx, t, req, []string{"once", "upon", "lived"})
 }
 }

+ 3 - 3
integration/llm_image_test.go

@@ -5,7 +5,6 @@ package integration
 import (
 import (
 	"context"
 	"context"
 	"encoding/base64"
 	"encoding/base64"
-	"net/http"
 	"testing"
 	"testing"
 	"time"
 	"time"
 
 
@@ -29,10 +28,11 @@ func TestIntegrationMultimodal(t *testing.T) {
 		},
 		},
 	}
 	}
 
 
-	resp := "the ollamas"
+	// Note: sometimes it returns "the ollamas", sometimes "the ollams"
+	resp := "the ollam"
 	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
 	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
 	defer cancel()
 	defer cancel()
-	GenerateTestHelper(ctx, t, &http.Client{}, req, []string{resp})
+	GenerateTestHelper(ctx, t, req, []string{resp})
 }
 }
 
 
 const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb
 const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb

+ 1 - 23
integration/llm_test.go

@@ -4,8 +4,6 @@ package integration
 
 
 import (
 import (
 	"context"
 	"context"
-	"net/http"
-	"sync"
 	"testing"
 	"testing"
 	"time"
 	"time"
 
 
@@ -45,25 +43,5 @@ var (
 func TestIntegrationSimpleOrcaMini(t *testing.T) {
 func TestIntegrationSimpleOrcaMini(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
 	ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
 	defer cancel()
 	defer cancel()
-	GenerateTestHelper(ctx, t, &http.Client{}, req[0], resp[0])
+	GenerateTestHelper(ctx, t, req[0], resp[0])
 }
 }
-
-// TODO
-// The server always loads a new runner and closes the old one, which forces serial execution
-// At present this test case fails with concurrency problems.  Eventually we should try to
-// get true concurrency working with n_parallel support in the backend
-func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
-	var wg sync.WaitGroup
-	wg.Add(len(req))
-	ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
-	defer cancel()
-	for i := 0; i < len(req); i++ {
-		go func(i int) {
-			defer wg.Done()
-			GenerateTestHelper(ctx, t, &http.Client{}, req[i], resp[i])
-		}(i)
-	}
-	wg.Wait()
-}
-
-// TODO - create a parallel test with 2 different models once we support concurrency

+ 171 - 98
integration/utils_test.go

@@ -5,13 +5,14 @@ package integration
 import (
 import (
 	"bytes"
 	"bytes"
 	"context"
 	"context"
-	"encoding/json"
+	"errors"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
 	"log/slog"
 	"log/slog"
 	"math/rand"
 	"math/rand"
 	"net"
 	"net"
 	"net/http"
 	"net/http"
+	"net/url"
 	"os"
 	"os"
 	"path/filepath"
 	"path/filepath"
 	"runtime"
 	"runtime"
@@ -23,9 +24,13 @@ import (
 
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/app/lifecycle"
 	"github.com/ollama/ollama/app/lifecycle"
-	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 )
 
 
+func Init() {
+	lifecycle.InitLogging()
+}
+
 func FindPort() string {
 func FindPort() string {
 	port := 0
 	port := 0
 	if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
 	if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
@@ -41,7 +46,7 @@ func FindPort() string {
 	return strconv.Itoa(port)
 	return strconv.Itoa(port)
 }
 }
 
 
-func GetTestEndpoint() (string, string) {
+func GetTestEndpoint() (*api.Client, string) {
 	defaultPort := "11434"
 	defaultPort := "11434"
 	ollamaHost := os.Getenv("OLLAMA_HOST")
 	ollamaHost := os.Getenv("OLLAMA_HOST")
 
 
@@ -67,16 +72,20 @@ func GetTestEndpoint() (string, string) {
 		port = FindPort()
 		port = FindPort()
 	}
 	}
 
 
-	url := fmt.Sprintf("%s:%s", host, port)
-	slog.Info("server connection", "url", url)
-	return scheme, url
+	slog.Info("server connection", "host", host, "port", port)
+
+	return api.NewClient(
+		&url.URL{
+			Scheme: scheme,
+			Host:   net.JoinHostPort(host, port),
+		},
+		http.DefaultClient), fmt.Sprintf("%s:%s", host, port)
 }
 }
 
 
-// TODO make fanicier, grab logs, etc.
 var serverMutex sync.Mutex
 var serverMutex sync.Mutex
 var serverReady bool
 var serverReady bool
 
 
-func StartServer(ctx context.Context, ollamaHost string) error {
+func startServer(ctx context.Context, ollamaHost string) error {
 	// Make sure the server has been built
 	// Make sure the server has been built
 	CLIName, err := filepath.Abs("../ollama")
 	CLIName, err := filepath.Abs("../ollama")
 	if err != nil {
 	if err != nil {
@@ -125,67 +134,76 @@ func StartServer(ctx context.Context, ollamaHost string) error {
 	return nil
 	return nil
 }
 }
 
 
-func PullIfMissing(ctx context.Context, client *http.Client, scheme, testEndpoint, modelName string) error {
+func PullIfMissing(ctx context.Context, client *api.Client, modelName string) error {
 	slog.Info("checking status of model", "model", modelName)
 	slog.Info("checking status of model", "model", modelName)
 	showReq := &api.ShowRequest{Name: modelName}
 	showReq := &api.ShowRequest{Name: modelName}
-	requestJSON, err := json.Marshal(showReq)
-	if err != nil {
-		return err
-	}
-
-	req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/show", bytes.NewReader(requestJSON))
-	if err != nil {
-		return err
-	}
 
 
-	// Make the request with the HTTP client
-	response, err := client.Do(req.WithContext(ctx))
-	if err != nil {
+	showCtx, cancel := context.WithDeadlineCause(
+		ctx,
+		time.Now().Add(5*time.Second),
+		fmt.Errorf("show for existing model %s took too long", modelName),
+	)
+	defer cancel()
+	_, err := client.Show(showCtx, showReq)
+	var statusError api.StatusError
+	switch {
+	case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound:
+		break
+	case err != nil:
 		return err
 		return err
-	}
-	defer response.Body.Close()
-	if response.StatusCode == 200 {
+	default:
 		slog.Info("model already present", "model", modelName)
 		slog.Info("model already present", "model", modelName)
 		return nil
 		return nil
 	}
 	}
-	slog.Info("model missing", "status", response.StatusCode)
+	slog.Info("model missing", "model", modelName)
+
+	stallDuration := 30 * time.Second // This includes checksum verification, which can take a while on larger models
+	stallTimer := time.NewTimer(stallDuration)
+	fn := func(resp api.ProgressResponse) error {
+		// fmt.Print(".")
+		if !stallTimer.Reset(stallDuration) {
+			return fmt.Errorf("stall was detected, aborting status reporting")
+		}
+		return nil
+	}
 
 
+	stream := true
 	pullReq := &api.PullRequest{Name: modelName, Stream: &stream}
 	pullReq := &api.PullRequest{Name: modelName, Stream: &stream}
-	requestJSON, err = json.Marshal(pullReq)
-	if err != nil {
-		return err
-	}
 
 
-	req, err = http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/pull", bytes.NewReader(requestJSON))
-	if err != nil {
-		return err
-	}
-	slog.Info("pulling", "model", modelName)
+	var pullError error
 
 
-	response, err = client.Do(req.WithContext(ctx))
-	if err != nil {
-		return err
-	}
-	defer response.Body.Close()
-	if response.StatusCode != 200 {
-		return fmt.Errorf("failed to pull model") // TODO more details perhaps
+	done := make(chan int)
+	go func() {
+		pullError = client.Pull(ctx, pullReq, fn)
+		done <- 0
+	}()
+
+	select {
+	case <-stallTimer.C:
+		return fmt.Errorf("download stalled")
+	case <-done:
+		return pullError
 	}
 	}
-	slog.Info("model pulled", "model", modelName)
-	return nil
 }
 }
 
 
 var serverProcMutex sync.Mutex
 var serverProcMutex sync.Mutex
 
 
-func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client, genReq api.GenerateRequest, anyResp []string) {
-
-	// TODO maybe stuff in an init routine?
-	lifecycle.InitLogging()
-
-	requestJSON, err := json.Marshal(genReq)
-	if err != nil {
-		t.Fatalf("Error serializing request: %v", err)
+// Returns a Client, the testEndpoint, and a cleanup function; fails the test on errors
+// Starts the server if needed
+func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, string, func()) {
+	client, testEndpoint := GetTestEndpoint()
+	if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
+		serverProcMutex.Lock()
+		fp, err := os.CreateTemp("", "ollama-server-*.log")
+		if err != nil {
+			t.Fatalf("failed to generate log file: %s", err)
+		}
+		lifecycle.ServerLogFile = fp.Name()
+		fp.Close()
+		require.NoError(t, startServer(ctx, testEndpoint))
 	}
 	}
-	defer func() {
+
+	return client, testEndpoint, func() {
 		if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
 		if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
 			defer serverProcMutex.Unlock()
 			defer serverProcMutex.Unlock()
 			if t.Failed() {
 			if t.Failed() {
@@ -203,63 +221,118 @@ func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client,
 				os.Stderr.Write(data)
 				os.Stderr.Write(data)
 				slog.Warn("END OF SERVER")
 				slog.Warn("END OF SERVER")
 			}
 			}
-			err = os.Remove(lifecycle.ServerLogFile)
+			err := os.Remove(lifecycle.ServerLogFile)
 			if err != nil && !os.IsNotExist(err) {
 			if err != nil && !os.IsNotExist(err) {
 				slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err)
 				slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err)
 			}
 			}
 		}
 		}
-	}()
-	scheme, testEndpoint := GetTestEndpoint()
-
-	if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
-		serverProcMutex.Lock()
-		fp, err := os.CreateTemp("", "ollama-server-*.log")
-		if err != nil {
-			t.Fatalf("failed to generate log file: %s", err)
-		}
-		lifecycle.ServerLogFile = fp.Name()
-		fp.Close()
-		assert.NoError(t, StartServer(ctx, testEndpoint))
 	}
 	}
+}
 
 
-	err = PullIfMissing(ctx, client, scheme, testEndpoint, genReq.Model)
-	if err != nil {
-		t.Fatalf("Error pulling model: %v", err)
-	}
+func GenerateTestHelper(ctx context.Context, t *testing.T, genReq api.GenerateRequest, anyResp []string) {
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
+	require.NoError(t, PullIfMissing(ctx, client, genReq.Model))
+	DoGenerate(ctx, t, client, genReq, anyResp, 30*time.Second, 10*time.Second)
+}
 
 
-	// Make the request and get the response
-	req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/generate", bytes.NewReader(requestJSON))
-	if err != nil {
-		t.Fatalf("Error creating request: %v", err)
+func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq api.GenerateRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) {
+	stallTimer := time.NewTimer(initialTimeout)
+	var buf bytes.Buffer
+	fn := func(response api.GenerateResponse) error {
+		// fmt.Print(".")
+		buf.Write([]byte(response.Response))
+		if !stallTimer.Reset(streamTimeout) {
+			return fmt.Errorf("stall was detected while streaming response, aborting")
+		}
+		return nil
 	}
 	}
 
 
-	// Set the content type for the request
-	req.Header.Set("Content-Type", "application/json")
+	stream := true
+	genReq.Stream = &stream
+	done := make(chan int)
+	var genErr error
+	go func() {
+		genErr = client.Generate(ctx, &genReq, fn)
+		done <- 0
+	}()
 
 
-	// Make the request with the HTTP client
-	response, err := client.Do(req.WithContext(ctx))
-	if err != nil {
-		t.Fatalf("Error making request: %v", err)
-	}
-	defer response.Body.Close()
-	body, err := io.ReadAll(response.Body)
-	assert.NoError(t, err)
-	assert.Equal(t, response.StatusCode, 200, string(body))
-
-	// Verify the response is valid JSON
-	var payload api.GenerateResponse
-	err = json.Unmarshal(body, &payload)
-	if err != nil {
-		assert.NoError(t, err, body)
+	select {
+	case <-stallTimer.C:
+		if buf.Len() == 0 {
+			t.Errorf("generate never started.  Timed out after :%s", initialTimeout.String())
+		} else {
+			t.Errorf("generate stalled.  Response so far:%s", buf.String())
+		}
+	case <-done:
+		require.NoError(t, genErr, "failed with %s request prompt %s ", genReq.Model, genReq.Prompt)
+		// Verify the response contains the expected data
+		response := buf.String()
+		atLeastOne := false
+		for _, resp := range anyResp {
+			if strings.Contains(strings.ToLower(response), resp) {
+				atLeastOne = true
+				break
+			}
+		}
+		require.True(t, atLeastOne, "none of %v found in %s", anyResp, response)
+		slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response)
+	case <-ctx.Done():
+		t.Error("outer test context done while waiting for generate")
 	}
 	}
+}
 
 
-	// Verify the response contains the expected data
-	atLeastOne := false
-	for _, resp := range anyResp {
-		if strings.Contains(strings.ToLower(payload.Response), resp) {
-			atLeastOne = true
-			break
+// Generate a set of requests
+// By default each request uses orca-mini as the model
+func GenerateRequests() ([]api.GenerateRequest, [][]string) {
+	return []api.GenerateRequest{
+			{
+				Model:  "orca-mini",
+				Prompt: "why is the ocean blue?",
+				Stream: &stream,
+				Options: map[string]interface{}{
+					"seed":        42,
+					"temperature": 0.0,
+				},
+			}, {
+				Model:  "orca-mini",
+				Prompt: "why is the color of dirt brown?",
+				Stream: &stream,
+				Options: map[string]interface{}{
+					"seed":        42,
+					"temperature": 0.0,
+				},
+			}, {
+				Model:  "orca-mini",
+				Prompt: "what is the origin of the us thanksgiving holiday?",
+				Stream: &stream,
+				Options: map[string]interface{}{
+					"seed":        42,
+					"temperature": 0.0,
+				},
+			}, {
+				Model:  "orca-mini",
+				Prompt: "what is the origin of independence day?",
+				Stream: &stream,
+				Options: map[string]interface{}{
+					"seed":        42,
+					"temperature": 0.0,
+				},
+			}, {
+				Model:  "orca-mini",
+				Prompt: "what is the composition of air?",
+				Stream: &stream,
+				Options: map[string]interface{}{
+					"seed":        42,
+					"temperature": 0.0,
+				},
+			},
+		},
+		[][]string{
+			[]string{"sunlight"},
+			[]string{"soil", "organic", "earth", "black", "tan"},
+			[]string{"england", "english", "massachusetts", "pilgrims"},
+			[]string{"fourth", "july", "declaration", "independence"},
+			[]string{"nitrogen", "oxygen", "carbon", "dioxide"},
 		}
 		}
-	}
-	assert.True(t, atLeastOne, "none of %v found in %s", anyResp, payload.Response)
 }
 }
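Both PullIfMissing and DoGenerate lean on the same watchdog idea: reset a timer from the streaming callback and treat the timer firing as a stall. Reduced to its essentials as a standalone sketch (the function and variable names here are illustrative):

package main

import (
	"fmt"
	"time"
)

// watchStream runs work, which must call onProgress for every streamed chunk,
// and reports an error if no chunk arrives within stall.
func watchStream(stall time.Duration, work func(onProgress func()) error) error {
	stallTimer := time.NewTimer(stall)
	done := make(chan error, 1)
	go func() {
		done <- work(func() { stallTimer.Reset(stall) })
	}()
	select {
	case <-stallTimer.C:
		return fmt.Errorf("stream stalled after %s", stall)
	case err := <-done:
		return err
	}
}

func main() {
	err := watchStream(time.Second, func(onProgress func()) error {
		for i := 0; i < 3; i++ {
			time.Sleep(200 * time.Millisecond) // simulated chunk arrival
			onProgress()
		}
		return nil
	})
	fmt.Println(err) // <nil>
}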

+ 162 - 0
llm/memory.go

@@ -0,0 +1,162 @@
+package llm
+
+import (
+	"fmt"
+	"log/slog"
+	"strings"
+
+	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/format"
+	"github.com/ollama/ollama/gpu"
+)
+
+// This algorithm looks for a complete fit to determine if we need to unload other models
+func PredictServerFit(allGpus gpu.GpuInfoList, ggml *GGML, adapters, projectors []string, opts api.Options) (bool, uint64) {
+	var estimatedVRAM uint64
+	if opts.NumCtx > int(ggml.KV().ContextLength()) {
+		slog.Warn("requested context length is greater than model max context length", "requested", opts.NumCtx, "model", ggml.KV().ContextLength())
+		opts.NumCtx = int(ggml.KV().ContextLength())
+	}
+
+	if opts.NumCtx < 4 {
+		opts.NumCtx = 4
+	}
+
+	// Split up the GPUs by type and try them
+	for _, gpus := range allGpus.ByLibrary() {
+		var layerCount int
+		layerCount, estimatedVRAM = EstimateGPULayers(gpus, ggml, projectors, opts)
+		if opts.NumGPU < 0 {
+			if layerCount > 0 && layerCount >= int(ggml.KV().BlockCount()+1) {
+				return true, estimatedVRAM
+			}
+		} else {
+			if layerCount > 0 && layerCount >= opts.NumGPU {
+				return true, estimatedVRAM
+			}
+		}
+	}
+	return false, estimatedVRAM
+}
+
+// Given a model and one or more GPU targets, predict how many layers and bytes we can load
+// The GPUs provided must all be the same Library
+func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts api.Options) (int, uint64) {
+	if gpus[0].Library == "cpu" {
+		return 0, 0
+	}
+	var memoryAvailable uint64
+	for _, info := range gpus {
+		memoryAvailable += info.FreeMemory
+	}
+	slog.Debug("evaluating", "library", gpus[0].Library, "gpu_count", len(gpus), "available", format.HumanBytes2(memoryAvailable))
+
+	// TODO - this is probably wrong, first GPU vs secondaries will have different overheads
+	memoryMinimum := gpus[0].MinimumMemory
+
+	for _, projector := range projectors {
+		memoryMinimum += projectorMemoryRequirements(projector)
+
+		// multimodal models require at least 2048 context
+		opts.NumCtx = max(opts.NumCtx, 2048)
+	}
+
+	// fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv
+	var kv uint64 = 2 * 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * ggml.KV().EmbeddingLength() / ggml.KV().HeadCount() * ggml.KV().HeadCountKV()
+
+	graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)))
+	if graphPartialOffload == 0 {
+		graphPartialOffload = ggml.KV().GQA() * kv / 6
+	}
+
+	if graphFullOffload == 0 {
+		graphFullOffload = graphPartialOffload
+	}
+
+	graphFullOffload *= uint64(len(gpus))
+	graphPartialOffload *= uint64(len(gpus))
+
+	// memoryRequiredTotal represents the memory required for full GPU offloading (all layers)
+	memoryRequiredTotal := memoryMinimum + graphFullOffload
+
+	// memoryRequiredPartial represents the memory required for partial GPU offloading (n > 0, n < layers)
+	memoryRequiredPartial := memoryMinimum + graphPartialOffload
+
+	if memoryRequiredPartial > memoryAvailable {
+		slog.Debug("insufficient VRAM to load any model layers")
+		return 0, 0
+	}
+
+	var layerCount int
+	layers := ggml.Tensors().Layers()
+	for i := 0; i < int(ggml.KV().BlockCount()); i++ {
+		memoryLayer := layers[fmt.Sprintf("blk.%d", i)].size()
+
+		// KV is proportional to the number of layers
+		memoryLayer += kv / ggml.KV().BlockCount()
+
+		memoryRequiredTotal += memoryLayer
+		if memoryAvailable > memoryRequiredPartial+memoryLayer {
+			memoryRequiredPartial += memoryLayer
+			layerCount++
+		}
+	}
+
+	var memoryLayerOutput uint64
+	for k, v := range layers {
+		if !strings.HasPrefix(k, "blk.") {
+			memoryLayerOutput += v.size()
+		}
+	}
+
+	memoryRequiredTotal += memoryLayerOutput
+
+	if memoryAvailable > memoryRequiredTotal {
+		layerCount = int(ggml.KV().BlockCount()) + 1
+		memoryRequiredPartial = memoryRequiredTotal
+	}
+
+	memoryWeights := memoryRequiredTotal - memoryMinimum - graphFullOffload - kv
+
+	slog.Info(
+		"offload to gpu",
+		slog.Group(
+			"layers",
+			// actual number of layers offloaded
+			"real", opts.NumGPU,
+			// estimated number of layers that can be offloaded
+			"estimate", layerCount,
+		),
+		slog.Group(
+			"memory",
+			// memory available for offloading
+			"available", format.HumanBytes2(memoryAvailable),
+			slog.Group(
+				"required",
+				// memory required for full offloading
+				"full", format.HumanBytes2(memoryRequiredTotal),
+				// memory required to offload layers.estimate layers
+				"partial", format.HumanBytes2(memoryRequiredPartial),
+				// memory of KV cache
+				"kv", format.HumanBytes2(kv),
+			),
+			slog.Group(
+				"weights",
+				// memory of the weights
+				"total", format.HumanBytes2(memoryWeights),
+				// memory of repeating layers
+				"repeating", format.HumanBytes2(memoryWeights-memoryLayerOutput),
+				// memory of non-repeating layers
+				"nonrepeating", format.HumanBytes2(memoryLayerOutput),
+			),
+			slog.Group(
+				"graph",
+				// memory of graph when fully offloaded
+				"full", format.HumanBytes2(graphFullOffload),
+				// memory of graph when not fully offloaded
+				"partial", format.HumanBytes2(graphPartialOffload),
+			),
+		),
+	)
+	return layerCount, uint64(memoryRequiredPartial)
+}
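
These estimates are what the scheduler consults before spawning a runner. A rough sketch of calling them directly, using llm.LoadModel, gpu.GetGPUInfo, and GpuInfoList.ByLibrary as they appear elsewhere in this change (the standalone main and the model path are illustrative only):

package main

import (
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/format"
	"github.com/ollama/ollama/gpu"
	"github.com/ollama/ollama/llm"
)

func main() {
	// Decode enough of the GGUF to read metadata and per-tensor sizes
	ggml, err := llm.LoadModel("/path/to/model.gguf")
	if err != nil {
		log.Fatal(err)
	}

	opts := api.DefaultOptions()
	gpus := gpu.GetGPUInfo()

	// Would the model fit entirely on the detected GPUs?
	fits, vram := llm.PredictServerFit(gpus, ggml, nil, nil, opts)
	fmt.Printf("full fit: %v, estimated VRAM: %s\n", fits, format.HumanBytes2(vram))

	// Per-library estimate of how many layers could be offloaded
	for _, byLib := range gpus.ByLibrary() {
		layers, vram := llm.EstimateGPULayers(byLib, ggml, nil, opts)
		fmt.Printf("%s: %d layers, %s\n", byLib[0].Library, layers, format.HumanBytes2(vram))
	}
}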

+ 18 - 0
llm/payload.go

@@ -9,6 +9,7 @@ import (
 	"log/slog"
 	"log/slog"
 	"os"
 	"os"
 	"path/filepath"
 	"path/filepath"
+	"runtime"
 	"strings"
 	"strings"
 
 
 	"golang.org/x/exp/slices"
 	"golang.org/x/exp/slices"
@@ -138,6 +139,23 @@ func serversForGpu(info gpu.GpuInfo) []string {
 	return servers
 	return servers
 }
 }
 
 
+// Return the optimal server for this CPU architecture
+func serverForCpu() string {
+	if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
+		return "metal"
+	}
+	variant := gpu.GetCPUVariant()
+	availableServers := availableServers()
+	if variant != "" {
+		for cmp := range availableServers {
+			if cmp == "cpu_"+variant {
+				return cmp
+			}
+		}
+	}
+	return "cpu"
+}
+
 // extract extracts the embedded files to the target directory
 // extract extracts the embedded files to the target directory
 func extractFiles(targetDir string, glob string) error {
 func extractFiles(targetDir string, glob string) error {
 	files, err := fs.Glob(libEmbed, glob)
 	files, err := fs.Glob(libEmbed, glob)
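
serverForCpu complements the existing serversForGpu: when there is no usable GPU, or NumGPU is forced to 0, the best matching CPU payload is chosen by runtime variant. A small, hypothetical sketch of how a caller inside the llm package might log that decision (the helper name is illustrative):

// logCpuRunnerChoice reports which CPU runner would be used on this host.
// On darwin/arm64 serverForCpu always returns "metal"; otherwise it returns
// the variant-specific payload (e.g. "cpu_avx2") when it was extracted,
// falling back to plain "cpu".
func logCpuRunnerChoice() {
	name := serverForCpu()
	dir := availableServers()[name]
	slog.Info("cpu runner selected", "runner", name, "dir", dir)
}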

+ 156 - 143
llm/server.go

@@ -21,21 +21,43 @@ import (
 	"strings"
 	"strings"
 	"time"
 	"time"
 
 
+	"golang.org/x/sync/semaphore"
+
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/gpu"
 	"github.com/ollama/ollama/gpu"
 )
 )
 
 
-// LlamaServer is an instance of the llama.cpp server
-type LlamaServer struct {
+type LlamaServer interface {
+	Ping(ctx context.Context) error
+	WaitUntilRunning(ctx context.Context) error
+	Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
+	Embedding(ctx context.Context, prompt string) ([]float64, error)
+	Tokenize(ctx context.Context, content string) ([]int, error)
+	Detokenize(ctx context.Context, tokens []int) (string, error)
+	Close() error
+	EstimatedVRAM() uint64
+}
+
+// llmServer is an instance of the llama.cpp server
+type llmServer struct {
 	port    int
 	port    int
 	cmd     *exec.Cmd
 	cmd     *exec.Cmd
 	done    chan error // Channel to signal when the process exits
 	done    chan error // Channel to signal when the process exits
 	status  *StatusWriter
 	status  *StatusWriter
 	options api.Options
 	options api.Options
+
+	// TODO - this should be broken down by GPU
+	estimatedVRAM uint64 // Estimated usage of VRAM by the loaded model
+
+	sem *semaphore.Weighted
 }
 }
 
 
-func NewLlamaServer(model string, adapters, projectors []string, opts api.Options) (*LlamaServer, error) {
+func LoadModel(model string) (*GGML, error) {
+	if _, err := os.Stat(model); err != nil {
+		return nil, err
+	}
+
 	f, err := os.Open(model)
 	f, err := os.Open(model)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
@@ -43,10 +65,13 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
 	defer f.Close()
 	defer f.Close()
 
 
 	ggml, _, err := DecodeGGML(f)
 	ggml, _, err := DecodeGGML(f)
-	if err != nil {
-		return nil, err
-	}
+	return ggml, err
+}
 
 
+// NewLlamaServer will run a server for the given GPUs
+// The gpu list must be a single family.
+func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, projectors []string, opts api.Options) (LlamaServer, error) {
+	var err error
 	if opts.NumCtx > int(ggml.KV().ContextLength()) {
 	if opts.NumCtx > int(ggml.KV().ContextLength()) {
 		slog.Warn("requested context length is greater than model max context length", "requested", opts.NumCtx, "model", ggml.KV().ContextLength())
 		slog.Warn("requested context length is greater than model max context length", "requested", opts.NumCtx, "model", ggml.KV().ContextLength())
 		opts.NumCtx = int(ggml.KV().ContextLength())
 		opts.NumCtx = int(ggml.KV().ContextLength())
@@ -56,130 +81,50 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
 		opts.NumCtx = 4
 		opts.NumCtx = 4
 	}
 	}
 
 
-	memoryAvailable, _ := gpu.CheckVRAM()
-	info := gpu.GetGPUInfo()
-
-	memoryMinimum := info.MinimumMemory
-	for _, projector := range projectors {
-		memoryMinimum += projectorMemoryRequirements(projector)
-
-		// multimodal models require at least 2048 context
-		opts.NumCtx = max(opts.NumCtx, 2048)
-	}
+	cpuRunner := ""
+	var estimatedVRAM uint64
+	var systemMemory uint64
+	if (len(gpus) == 1 && gpus[0].Library == "cpu") || opts.NumGPU == 0 {
 
 
-	// fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv
-	var kv uint64 = 2 * 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * ggml.KV().EmbeddingLength() / ggml.KV().HeadCount() * ggml.KV().HeadCountKV()
+		// TODO evaluate system memory to see if we should block the load, or force an unload of another CPU runner
 
 
-	graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)))
-	if graphPartialOffload == 0 {
-		graphPartialOffload = ggml.KV().GQA() * kv / 6
-	}
-
-	if graphFullOffload == 0 {
-		graphFullOffload = graphPartialOffload
-	}
-
-	graphFullOffload *= uint64(info.DeviceCount)
-	graphPartialOffload *= uint64(info.DeviceCount)
-
-	// memoryRequiredTotal represents the memory required for full GPU offloading (all layers)
-	memoryRequiredTotal := memoryMinimum + graphFullOffload
-
-	// memoryRequiredPartial represents the memory required for partial GPU offloading (n > 0, n < layers)
-	memoryRequiredPartial := memoryMinimum + graphPartialOffload
-
-	if info.Library != "metal" {
-		if memoryRequiredPartial > memoryAvailable {
-			info.Library = "cpu"
-		}
-	}
-
-	var layerCount int
-	layers := ggml.Tensors().Layers()
-	for i := 0; i < int(ggml.KV().BlockCount()); i++ {
-		memoryLayer := layers[fmt.Sprintf("blk.%d", i)].size()
-
-		// KV is proportional to the number of layers
-		memoryLayer += kv / ggml.KV().BlockCount()
-
-		memoryRequiredTotal += memoryLayer
-		if memoryAvailable > memoryRequiredPartial+memoryLayer {
-			memoryRequiredPartial += memoryLayer
-			layerCount++
+		cpuRunner = serverForCpu()
+	} else {
+		if gpus[0].Library == "metal" {
+			memInfo, err := gpu.GetCPUMem()
+			if err != nil {
+				slog.Error("failed to lookup system memory", "error", err)
+			} else {
+				systemMemory = memInfo.TotalMemory
+				slog.Debug("system memory", "total", format.HumanBytes2(systemMemory))
+			}
 		}
 		}
-	}
-
-	var memoryLayerOutput uint64
-	for k, v := range layers {
-		if !strings.HasPrefix(k, "blk.") {
-			memoryLayerOutput += v.size()
+		var layers int
+		layers, estimatedVRAM = EstimateGPULayers(gpus, ggml, projectors, opts)
+
+		if gpus[0].Library == "metal" && estimatedVRAM > systemMemory {
+			// disable partial offloading when model is greater than total system memory as this
+			// can lead to locking up the system
+			opts.NumGPU = 0
+		} else if opts.NumGPU < 0 && layers > 0 && gpus[0].Library != "cpu" {
+			opts.NumGPU = layers
 		}
 		}
 	}
 	}
 
 
-	memoryRequiredTotal += memoryLayerOutput
-
-	if info.Library == "metal" && memoryRequiredTotal > info.TotalMemory {
-		// disable partial offloading when model is greater than total system memory
-		opts.NumGPU = 0
-	} else if memoryAvailable > memoryRequiredTotal {
-		layerCount = int(ggml.KV().BlockCount()) + 1
-		memoryRequiredPartial = memoryRequiredTotal
-	}
-
-	if opts.NumGPU < 0 {
-		opts.NumGPU = layerCount
-	}
-
-	memoryWeights := memoryRequiredTotal - memoryMinimum - graphFullOffload - kv
-
-	slog.Info(
-		"offload to gpu",
-		slog.Group(
-			"layers",
-			// actual number of layers offloaded
-			"real", opts.NumGPU,
-			// estimated number of layers that can be offloaded
-			"estimate", layerCount,
-		),
-		slog.Group(
-			"memory",
-			// memory available for offloading
-			"available", format.HumanBytes2(memoryAvailable),
-			slog.Group(
-				"required",
-				// memory required for full offloading
-				"full", format.HumanBytes2(memoryRequiredTotal),
-				// memory required to offload layers.estimate layers
-				"partial", format.HumanBytes2(memoryRequiredPartial),
-				// memory of KV cache
-				"kv", format.HumanBytes2(kv),
-			),
-			slog.Group(
-				"weights",
-				// memory of the weights
-				"total", format.HumanBytes2(memoryWeights),
-				// memory of repeating layers
-				"repeating", format.HumanBytes2(memoryWeights-memoryLayerOutput),
-				// memory of non-repeating layers
-				"nonrepeating", format.HumanBytes2(memoryLayerOutput),
-			),
-			slog.Group(
-				"graph",
-				// memory of graph when fully offloaded
-				"full", format.HumanBytes2(graphFullOffload),
-				// memory of graph when not fully offloaded
-				"partial", format.HumanBytes2(graphPartialOffload),
-			),
-		),
-	)
+	// Loop through potential servers
+	finalErr := fmt.Errorf("no suitable llama servers found")
 
 
 	if len(adapters) > 1 {
 	if len(adapters) > 1 {
 		return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
 		return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
 	}
 	}
 
 
 	availableServers := availableServers()
 	availableServers := availableServers()
-	servers := serversForGpu(info)
-
+	var servers []string
+	if cpuRunner != "" {
+		servers = []string{cpuRunner}
+	} else {
+		servers = serversForGpu(gpus[0]) // All GPUs in the list are matching Library and Variant
+	}
 	demandLib := strings.Trim(os.Getenv("OLLAMA_LLM_LIBRARY"), "\"' ")
 	demandLib := strings.Trim(os.Getenv("OLLAMA_LLM_LIBRARY"), "\"' ")
 	if demandLib != "" {
 	if demandLib != "" {
 		serverPath := availableServers[demandLib]
 		serverPath := availableServers[demandLib]
@@ -192,7 +137,7 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
 	}
 	}
 
 
 	if len(servers) == 0 {
 	if len(servers) == 0 {
-		return nil, fmt.Errorf("no servers found for %v", info)
+		return nil, fmt.Errorf("no servers found for %v", gpus)
 	}
 	}
 
 
 	params := []string{
 	params := []string{
@@ -249,8 +194,18 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
 		params = append(params, "--numa")
 		params = append(params, "--numa")
 	}
 	}
 
 
-	// Loop through potential servers
-	var finalErr error
+	// "--cont-batching", // TODO - doesn't seem to have any noticeable perf change for multiple requests
+	numParallel := 1
+	if onp := os.Getenv("OLLAMA_NUM_PARALLEL"); onp != "" {
+		numParallel, err = strconv.Atoi(onp)
+		if err != nil || numParallel <= 0 {
+			err = fmt.Errorf("invalid OLLAMA_NUM_PARALLEL=%s must be greater than zero - %w", onp, err)
+			slog.Error("misconfiguration", "error", err)
+			return nil, err
+		}
+	}
+	params = append(params, "--parallel", fmt.Sprintf("%d", numParallel))
+
 	for i := 0; i < len(servers); i++ {
 	for i := 0; i < len(servers); i++ {
 		dir := availableServers[servers[i]]
 		dir := availableServers[servers[i]]
 
 
@@ -275,30 +230,49 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
 		}
 		}
 		// append the server directory to LD_LIBRARY_PATH/PATH
 		// append the server directory to LD_LIBRARY_PATH/PATH
 		libraryPaths := []string{dir}
 		libraryPaths := []string{dir}
+
 		if libraryPath, ok := os.LookupEnv(pathEnv); ok {
 		if libraryPath, ok := os.LookupEnv(pathEnv); ok {
 			// Append our runner directory to the path
 			// Append our runner directory to the path
 			// This will favor system libraries over our bundled library dependencies
 			// This will favor system libraries over our bundled library dependencies
 			libraryPaths = append(filepath.SplitList(libraryPath), libraryPaths...)
 			libraryPaths = append(filepath.SplitList(libraryPath), libraryPaths...)
 		}
 		}
 
 
+		// Note: we always put the dependency path first
+		// since this was the exact version we verified for AMD GPUs
+		// and we favor what the user had in their path
+		if gpus[0].DependencyPath != "" {
+			// TODO refine for multi-gpu support
+			libraryPaths = append([]string{gpus[0].DependencyPath}, libraryPaths...)
+		}
+
 		server := filepath.Join(dir, "ollama_llama_server")
 		server := filepath.Join(dir, "ollama_llama_server")
 		if runtime.GOOS == "windows" {
 		if runtime.GOOS == "windows" {
 			server = server + ".exe"
 			server = server + ".exe"
 		}
 		}
 
 
-		s := &LlamaServer{
-			port:    port,
-			cmd:     exec.Command(server, finalParams...),
-			status:  NewStatusWriter(os.Stderr),
-			options: opts,
+		s := &llmServer{
+			port:          port,
+			cmd:           exec.Command(server, finalParams...),
+			status:        NewStatusWriter(os.Stderr),
+			options:       opts,
+			estimatedVRAM: estimatedVRAM,
+			sem:           semaphore.NewWeighted(int64(numParallel)),
 		}
 		}
+
 		libEnv := fmt.Sprintf("%s=%s", pathEnv, strings.Join(libraryPaths, string(filepath.ListSeparator)))
 		libEnv := fmt.Sprintf("%s=%s", pathEnv, strings.Join(libraryPaths, string(filepath.ListSeparator)))
-		slog.Debug(libEnv)
 		s.cmd.Env = append(os.Environ(), libEnv)
 		s.cmd.Env = append(os.Environ(), libEnv)
 		s.cmd.Stdout = os.Stdout
 		s.cmd.Stdout = os.Stdout
 		s.cmd.Stderr = s.status
 		s.cmd.Stderr = s.status
 
 
+		// TODO - multiple GPU selection logic...
+		key, val := gpu.GpuInfoList(gpus).GetVisibleDevicesEnv()
+		if key != "" {
+			s.cmd.Env = append(s.cmd.Env, key+"="+val)
+		}
+
 		slog.Info("starting llama server", "cmd", s.cmd.String())
 		slog.Info("starting llama server", "cmd", s.cmd.String())
+		// Log at debug as the environment is inherited and might contain sensitive information
+		slog.Debug("subprocess", "environment", s.cmd.Env)
 
 
 		if err = s.cmd.Start(); err != nil {
 		if err = s.cmd.Start(); err != nil {
 			msg := ""
 			msg := ""
@@ -316,6 +290,13 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
 			_ = s.cmd.Wait()
 			_ = s.cmd.Wait()
 		}()
 		}()
 
 
+		// TODO - make sure this is all wired up correctly
+		// if err = s.WaitUntilRunning(); err != nil {
+		// 	slog.Error("error starting llama server", "server", servers[i], "error", err)
+		// 	s.Close()
+		// 	finalErr = err
+		// 	continue
+		// }
 		return s, nil
 		return s, nil
 	}
 	}
 
 
@@ -353,6 +334,21 @@ const ( // iota is reset to 0
 	ServerStatusError
 	ServerStatusError
 )
 )
 
 
+func (s ServerStatus) ToString() string {
+	switch s {
+	case ServerStatusReady:
+		return "llm server ready"
+	case ServerStatusNoSlotsAvaialble:
+		return "llm busy - no slots available"
+	case ServerStatusLoadingModel:
+		return "llm server loading model"
+	case ServerStatusNotResponding:
+		return "llm server not responding"
+	default:
+		return "llm server error"
+	}
+}
+
 type ServerStatusResp struct {
 type ServerStatusResp struct {
 	Status          string `json:"status"`
 	Status          string `json:"status"`
 	SlotsIdle       int    `json:"slots_idle"`
 	SlotsIdle       int    `json:"slots_idle"`
@@ -360,7 +356,7 @@ type ServerStatusResp struct {
 	Error           string `json:"error"`
 	Error           string `json:"error"`
 }
 }
 
 
-func (s *LlamaServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
+func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
 	// Fail fast if its exited
 	// Fail fast if its exited
 	if s.cmd.ProcessState != nil {
 	if s.cmd.ProcessState != nil {
 		msg := ""
 		msg := ""
@@ -407,7 +403,7 @@ func (s *LlamaServer) getServerStatus(ctx context.Context) (ServerStatus, error)
 	}
 	}
 }
 }
 
 
-func (s *LlamaServer) Ping(ctx context.Context) error {
+func (s *llmServer) Ping(ctx context.Context) error {
 	_, err := s.getServerStatus(ctx)
 	_, err := s.getServerStatus(ctx)
 	if err != nil {
 	if err != nil {
 		slog.Debug("server unhealthy", "error", err)
 		slog.Debug("server unhealthy", "error", err)
@@ -416,7 +412,7 @@ func (s *LlamaServer) Ping(ctx context.Context) error {
 	return nil
 	return nil
 }
 }
 
 
-func (s *LlamaServer) WaitUntilRunning() error {
+func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 	start := time.Now()
 	start := time.Now()
 	// TODO we need to wire up a better way to detect hangs during model load and startup of the server
 	// TODO we need to wire up a better way to detect hangs during model load and startup of the server
 	expiresAt := time.Now().Add(10 * time.Minute) // be generous with timeout, large models can take a while to load
 	expiresAt := time.Now().Add(10 * time.Minute) // be generous with timeout, large models can take a while to load
@@ -427,6 +423,9 @@ func (s *LlamaServer) WaitUntilRunning() error {
 	var lastStatus ServerStatus = -1
 	var lastStatus ServerStatus = -1
 	for {
 	for {
 		select {
 		select {
+		case <-ctx.Done():
+			slog.Info("context expired before server started")
+			return fmt.Errorf("timed out waiting for llama runner to start")
 		case err := <-s.done:
 		case err := <-s.done:
 			msg := ""
 			msg := ""
 			if s.status != nil && s.status.LastErrMsg != "" {
 			if s.status != nil && s.status.LastErrMsg != "" {
@@ -450,9 +449,9 @@ func (s *LlamaServer) WaitUntilRunning() error {
 				return fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg)
 				return fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg)
 			}
 			}
 
 
-			ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
+			c, cancel := context.WithTimeout(ctx, 200*time.Millisecond)
 			defer cancel()
 			defer cancel()
-			status, err := s.getServerStatus(ctx)
+			status, err := s.getServerStatus(c)
 			if err != nil && lastStatus != status {
 			if err != nil && lastStatus != status {
 				slog.Debug("server not yet available", "error", err)
 				slog.Debug("server not yet available", "error", err)
 				lastStatus = status
 				lastStatus = status
@@ -538,7 +537,12 @@ type CompletionResponse struct {
 	EvalDuration       time.Duration
 	EvalDuration       time.Duration
 }
 }
 
 
-func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
+func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
+	if err := s.sem.Acquire(ctx, 1); err != nil {
+		slog.Error("Failed to acquire semaphore", "error", err)
+		return err
+	}
+	defer s.sem.Release(1)
 	request := map[string]any{
 	request := map[string]any{
 		"prompt":            req.Prompt,
 		"prompt":            req.Prompt,
 		"stream":            true,
 		"stream":            true,
@@ -569,7 +573,7 @@ func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	} else if status != ServerStatusReady {
 	} else if status != ServerStatusReady {
-		return fmt.Errorf("unexpected server status: %d", status)
+		return fmt.Errorf("unexpected server status: %s", status.ToString())
 	}
 	}
 
 
 	if req.Format == "json" {
 	if req.Format == "json" {
@@ -716,13 +720,18 @@ type EmbeddingResponse struct {
 	Embedding []float64 `json:"embedding"`
 	Embedding []float64 `json:"embedding"`
 }
 }
 
 
-func (s *LlamaServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
+func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
+	if err := s.sem.Acquire(ctx, 1); err != nil {
+		slog.Error("Failed to acquire semaphore", "error", err)
+		return nil, err
+	}
+	defer s.sem.Release(1)
 	// Make sure the server is ready
 	// Make sure the server is ready
 	status, err := s.getServerStatus(ctx)
 	status, err := s.getServerStatus(ctx)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	} else if status != ServerStatusReady {
 	} else if status != ServerStatusReady {
-		return nil, fmt.Errorf("unexpected server status: %d", status)
+		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
 	}
 	}
 
 
 	data, err := json.Marshal(TokenizeRequest{Content: prompt})
 	data, err := json.Marshal(TokenizeRequest{Content: prompt})
@@ -768,13 +777,13 @@ type TokenizeResponse struct {
 	Tokens []int `json:"tokens"`
 	Tokens []int `json:"tokens"`
 }
 }
 
 
-func (s *LlamaServer) Tokenize(ctx context.Context, content string) ([]int, error) {
+func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error) {
 	// Make sure the server is ready
 	// Make sure the server is ready
 	status, err := s.getServerStatus(ctx)
 	status, err := s.getServerStatus(ctx)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
-	} else if status != ServerStatusReady {
-		return nil, fmt.Errorf("unexpected server status: %d", status)
+	} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvaialble {
+		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
 	}
 	}
 
 
 	data, err := json.Marshal(TokenizeRequest{Content: content})
 	data, err := json.Marshal(TokenizeRequest{Content: content})
@@ -820,13 +829,13 @@ type DetokenizeResponse struct {
 	Content string `json:"content"`
 	Content string `json:"content"`
 }
 }
 
 
-func (s *LlamaServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
+func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
 	// Make sure the server is ready
 	// Make sure the server is ready
 	status, err := s.getServerStatus(ctx)
 	status, err := s.getServerStatus(ctx)
 	if err != nil {
 	if err != nil {
 		return "", err
 		return "", err
-	} else if status != ServerStatusReady {
-		return "", fmt.Errorf("unexpected server status: %d", status)
+	} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvaialble {
+		return "", fmt.Errorf("unexpected server status: %s", status.ToString())
 	}
 	}
 
 
 	data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
 	data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
@@ -864,7 +873,7 @@ func (s *LlamaServer) Detokenize(ctx context.Context, tokens []int) (string, err
 	return decoded.Content, nil
 	return decoded.Content, nil
 }
 }
 
 
-func (s *LlamaServer) Close() error {
+func (s *llmServer) Close() error {
 	if s.cmd != nil {
 	if s.cmd != nil {
 		slog.Debug("stopping llama server")
 		slog.Debug("stopping llama server")
 		return s.cmd.Process.Kill()
 		return s.cmd.Process.Kill()
@@ -873,6 +882,10 @@ func (s *LlamaServer) Close() error {
 	return nil
 	return nil
 }
 }
 
 
+func (s *llmServer) EstimatedVRAM() uint64 {
+	return s.estimatedVRAM
+}
+
 func parseDurationMs(ms float64) time.Duration {
 func parseDurationMs(ms float64) time.Duration {
 	dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
 	dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
 	if err != nil {
 	if err != nil {
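
The parallelism cap added above is a weighted semaphore sized by OLLAMA_NUM_PARALLEL; each Completion and Embedding call holds one slot while it runs. A self-contained sketch of the same pattern (the names, counts, and sleep are placeholders, not the actual llmServer fields):

package main

import (
	"context"
	"fmt"
	"sync"
	"time"

	"golang.org/x/sync/semaphore"
)

func main() {
	numParallel := 2 // the role OLLAMA_NUM_PARALLEL plays above
	sem := semaphore.NewWeighted(int64(numParallel))

	var wg sync.WaitGroup
	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			// Blocks until a slot frees up or the context is canceled,
			// mirroring the Acquire at the top of Completion and Embedding.
			if err := sem.Acquire(context.Background(), 1); err != nil {
				fmt.Println("request", i, "canceled:", err)
				return
			}
			defer sem.Release(1)
			fmt.Println("request", i, "holding a slot")
			time.Sleep(100 * time.Millisecond) // stand-in for the completion round trip
		}(i)
	}
	wg.Wait()
}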

+ 61 - 141
server/routes.go

@@ -15,11 +15,8 @@ import (
 	"os"
 	"os"
 	"os/signal"
 	"os/signal"
 	"path/filepath"
 	"path/filepath"
-	"reflect"
-	"runtime"
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
-	"sync"
 	"syscall"
 	"syscall"
 	"time"
 	"time"
 
 
@@ -38,7 +35,8 @@ import (
 var mode string = gin.DebugMode
 var mode string = gin.DebugMode
 
 
 type Server struct {
 type Server struct {
-	addr net.Addr
+	addr  net.Addr
+	sched *Scheduler
 }
 }
 
 
 func init() {
 func init() {
@@ -53,88 +51,8 @@ func init() {
 	gin.SetMode(mode)
 	gin.SetMode(mode)
 }
 }
 
 
-var loaded struct {
-	mu sync.Mutex
-
-	llama *llm.LlamaServer
-
-	expireTimer *time.Timer
-
-	model      string
-	adapters   []string
-	projectors []string
-	*api.Options
-}
-
 var defaultSessionDuration = 5 * time.Minute
 var defaultSessionDuration = 5 * time.Minute
 
 
-func unload() {
-	if loaded.llama != nil {
-		loaded.llama.Close()
-	}
-
-	loaded.llama = nil
-	loaded.model = ""
-	loaded.adapters = nil
-	loaded.projectors = nil
-	loaded.Options = nil
-}
-
-// load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
-func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.Duration) error {
-	ctx, cancel := context.WithTimeout(c, 10*time.Second)
-	defer cancel()
-
-	needLoad := loaded.llama == nil || // is there a model loaded?
-		loaded.model != model.ModelPath || // has the base model changed?
-		!reflect.DeepEqual(loaded.adapters, model.AdapterPaths) || // have the adapters changed?
-		!reflect.DeepEqual(loaded.projectors, model.ProjectorPaths) || // have the adapters changed?
-		!reflect.DeepEqual(loaded.Options.Runner, opts.Runner) || // have the runner options changed?
-		loaded.llama.Ping(ctx) != nil
-
-	if needLoad {
-		if loaded.llama != nil {
-			slog.Info("changing loaded model")
-			unload()
-		}
-
-		llama, err := llm.NewLlamaServer(model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts)
-		if err != nil {
-			// some older models are not compatible with newer versions of llama.cpp
-			// show a generalized compatibility error until there is a better way to
-			// check for model compatibility
-			if errors.Is(llm.ErrUnsupportedFormat, err) || strings.Contains(err.Error(), "failed to load model") {
-				err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, model.ShortName)
-			}
-
-			return err
-		}
-
-		loaded.model = model.ModelPath
-		loaded.adapters = model.AdapterPaths
-		loaded.projectors = model.ProjectorPaths
-		loaded.llama = llama
-		loaded.Options = &opts
-
-		if err = llama.WaitUntilRunning(); err != nil {
-			slog.Error("error loading llama server", "error", err)
-			unload()
-			return err
-		}
-	}
-
-	if loaded.expireTimer == nil {
-		loaded.expireTimer = time.AfterFunc(sessionDuration, func() {
-			loaded.mu.Lock()
-			defer loaded.mu.Unlock()
-			unload()
-		})
-	}
-
-	loaded.expireTimer.Reset(sessionDuration)
-	return nil
-}
-
 func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
 func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
 	opts := api.DefaultOptions()
 	opts := api.DefaultOptions()
 	if err := opts.FromMap(model.Options); err != nil {
 	if err := opts.FromMap(model.Options); err != nil {
@@ -154,9 +72,7 @@ func isSupportedImageType(image []byte) bool {
 	return slices.Contains(allowedTypes, contentType)
 	return slices.Contains(allowedTypes, contentType)
 }
 }
 
 
-func GenerateHandler(c *gin.Context) {
-	loaded.mu.Lock()
-	defer loaded.mu.Unlock()
+func (s *Server) GenerateHandler(c *gin.Context) {
 
 
 	checkpointStart := time.Now()
 	checkpointStart := time.Now()
 	var req api.GenerateRequest
 	var req api.GenerateRequest
@@ -224,7 +140,11 @@ func GenerateHandler(c *gin.Context) {
 		sessionDuration = req.KeepAlive.Duration
 		sessionDuration = req.KeepAlive.Duration
 	}
 	}
 
 
-	if err := load(c, model, opts, sessionDuration); err != nil {
+	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
+	var runner *runnerRef
+	select {
+	case runner = <-rCh:
+	case err = <-eCh:
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 		return
 	}
 	}
@@ -275,7 +195,7 @@ func GenerateHandler(c *gin.Context) {
 
 
 		sb.Reset()
 		sb.Reset()
 		if req.Context != nil {
 		if req.Context != nil {
-			prev, err := loaded.llama.Detokenize(c.Request.Context(), req.Context)
+			prev, err := runner.llama.Detokenize(c.Request.Context(), req.Context)
 			if err != nil {
 			if err != nil {
 				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 				return
 				return
@@ -297,9 +217,6 @@ func GenerateHandler(c *gin.Context) {
 		defer close(ch)
 		defer close(ch)
 
 
 		fn := func(r llm.CompletionResponse) {
 		fn := func(r llm.CompletionResponse) {
-			// Update model expiration
-			loaded.expireTimer.Reset(sessionDuration)
-
 			// Build up the full response
 			// Build up the full response
 			if _, err := generated.WriteString(r.Content); err != nil {
 			if _, err := generated.WriteString(r.Content); err != nil {
 				ch <- gin.H{"error": err.Error()}
 				ch <- gin.H{"error": err.Error()}
@@ -331,7 +248,7 @@ func GenerateHandler(c *gin.Context) {
 					}
 					}
 
 
 					// TODO (jmorganca): encode() should not strip special tokens
 					// TODO (jmorganca): encode() should not strip special tokens
-					tokens, err := loaded.llama.Tokenize(c.Request.Context(), p)
+					tokens, err := runner.llama.Tokenize(c.Request.Context(), p)
 					if err != nil {
 					if err != nil {
 						ch <- gin.H{"error": err.Error()}
 						ch <- gin.H{"error": err.Error()}
 						return
 						return
@@ -359,7 +276,7 @@ func GenerateHandler(c *gin.Context) {
 			Images:  images,
 			Images:  images,
 			Options: opts,
 			Options: opts,
 		}
 		}
-		if err := loaded.llama.Completion(c.Request.Context(), req, fn); err != nil {
+		if err := runner.llama.Completion(c.Request.Context(), req, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
 			ch <- gin.H{"error": err.Error()}
 		}
 		}
 	}()
 	}()
@@ -421,10 +338,7 @@ func getDefaultSessionDuration() time.Duration {
 	return defaultSessionDuration
 	return defaultSessionDuration
 }
 }
 
 
-func EmbeddingsHandler(c *gin.Context) {
-	loaded.mu.Lock()
-	defer loaded.mu.Unlock()
-
+func (s *Server) EmbeddingsHandler(c *gin.Context) {
 	var req api.EmbeddingRequest
 	var req api.EmbeddingRequest
 	err := c.ShouldBindJSON(&req)
 	err := c.ShouldBindJSON(&req)
 	switch {
 	switch {
@@ -469,7 +383,11 @@ func EmbeddingsHandler(c *gin.Context) {
 		sessionDuration = req.KeepAlive.Duration
 		sessionDuration = req.KeepAlive.Duration
 	}
 	}
 
 
-	if err := load(c, model, opts, sessionDuration); err != nil {
+	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
+	var runner *runnerRef
+	select {
+	case runner = <-rCh:
+	case err = <-eCh:
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 		return
 	}
 	}
@@ -480,7 +398,7 @@ func EmbeddingsHandler(c *gin.Context) {
 		return
 		return
 	}
 	}
 
 
-	embedding, err := loaded.llama.Embedding(c.Request.Context(), req.Prompt)
+	embedding, err := runner.llama.Embedding(c.Request.Context(), req.Prompt)
 	if err != nil {
 	if err != nil {
 		slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
 		slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
 		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
 		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
@@ -493,7 +411,7 @@ func EmbeddingsHandler(c *gin.Context) {
 	c.JSON(http.StatusOK, resp)
 	c.JSON(http.StatusOK, resp)
 }
 }
 
 
-func PullModelHandler(c *gin.Context) {
+func (s *Server) PullModelHandler(c *gin.Context) {
 	var req api.PullRequest
 	var req api.PullRequest
 	err := c.ShouldBindJSON(&req)
 	err := c.ShouldBindJSON(&req)
 	switch {
 	switch {
@@ -542,7 +460,7 @@ func PullModelHandler(c *gin.Context) {
 	streamResponse(c, ch)
 	streamResponse(c, ch)
 }
 }
 
 
-func PushModelHandler(c *gin.Context) {
+func (s *Server) PushModelHandler(c *gin.Context) {
 	var req api.PushRequest
 	var req api.PushRequest
 	err := c.ShouldBindJSON(&req)
 	err := c.ShouldBindJSON(&req)
 	switch {
 	switch {
@@ -591,7 +509,7 @@ func PushModelHandler(c *gin.Context) {
 	streamResponse(c, ch)
 	streamResponse(c, ch)
 }
 }
 
 
-func CreateModelHandler(c *gin.Context) {
+func (s *Server) CreateModelHandler(c *gin.Context) {
 	var req api.CreateRequest
 	var req api.CreateRequest
 	err := c.ShouldBindJSON(&req)
 	err := c.ShouldBindJSON(&req)
 	switch {
 	switch {
@@ -664,7 +582,7 @@ func CreateModelHandler(c *gin.Context) {
 	streamResponse(c, ch)
 	streamResponse(c, ch)
 }
 }
 
 
-func DeleteModelHandler(c *gin.Context) {
+func (s *Server) DeleteModelHandler(c *gin.Context) {
 	var req api.DeleteRequest
 	var req api.DeleteRequest
 	err := c.ShouldBindJSON(&req)
 	err := c.ShouldBindJSON(&req)
 	switch {
 	switch {
@@ -709,7 +627,7 @@ func DeleteModelHandler(c *gin.Context) {
 	c.JSON(http.StatusOK, nil)
 	c.JSON(http.StatusOK, nil)
 }
 }
 
 
-func ShowModelHandler(c *gin.Context) {
+func (s *Server) ShowModelHandler(c *gin.Context) {
 	var req api.ShowRequest
 	var req api.ShowRequest
 	err := c.ShouldBindJSON(&req)
 	err := c.ShouldBindJSON(&req)
 	switch {
 	switch {
@@ -809,7 +727,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 	return resp, nil
 	return resp, nil
 }
 }
 
 
-func ListModelsHandler(c *gin.Context) {
+func (s *Server) ListModelsHandler(c *gin.Context) {
 	models := make([]api.ModelResponse, 0)
 	models := make([]api.ModelResponse, 0)
 	manifestsPath, err := GetManifestPath()
 	manifestsPath, err := GetManifestPath()
 	if err != nil {
 	if err != nil {
@@ -869,7 +787,7 @@ func ListModelsHandler(c *gin.Context) {
 	c.JSON(http.StatusOK, api.ListResponse{Models: models})
 	c.JSON(http.StatusOK, api.ListResponse{Models: models})
 }
 }
 
 
-func CopyModelHandler(c *gin.Context) {
+func (s *Server) CopyModelHandler(c *gin.Context) {
 	var req api.CopyRequest
 	var req api.CopyRequest
 	err := c.ShouldBindJSON(&req)
 	err := c.ShouldBindJSON(&req)
 	switch {
 	switch {
@@ -901,7 +819,7 @@ func CopyModelHandler(c *gin.Context) {
 	}
 	}
 }
 }
 
 
-func HeadBlobHandler(c *gin.Context) {
+func (s *Server) HeadBlobHandler(c *gin.Context) {
 	path, err := GetBlobsPath(c.Param("digest"))
 	path, err := GetBlobsPath(c.Param("digest"))
 	if err != nil {
 	if err != nil {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@@ -916,7 +834,7 @@ func HeadBlobHandler(c *gin.Context) {
 	c.Status(http.StatusOK)
 	c.Status(http.StatusOK)
 }
 }
 
 
-func CreateBlobHandler(c *gin.Context) {
+func (s *Server) CreateBlobHandler(c *gin.Context) {
 	path, err := GetBlobsPath(c.Param("digest"))
 	path, err := GetBlobsPath(c.Param("digest"))
 	if err != nil {
 	if err != nil {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@@ -1063,27 +981,27 @@ func (s *Server) GenerateRoutes() http.Handler {
 		allowedHostsMiddleware(s.addr),
 		allowedHostsMiddleware(s.addr),
 	)
 	)
 
 
-	r.POST("/api/pull", PullModelHandler)
-	r.POST("/api/generate", GenerateHandler)
-	r.POST("/api/chat", ChatHandler)
-	r.POST("/api/embeddings", EmbeddingsHandler)
-	r.POST("/api/create", CreateModelHandler)
-	r.POST("/api/push", PushModelHandler)
-	r.POST("/api/copy", CopyModelHandler)
-	r.DELETE("/api/delete", DeleteModelHandler)
-	r.POST("/api/show", ShowModelHandler)
-	r.POST("/api/blobs/:digest", CreateBlobHandler)
-	r.HEAD("/api/blobs/:digest", HeadBlobHandler)
+	r.POST("/api/pull", s.PullModelHandler)
+	r.POST("/api/generate", s.GenerateHandler)
+	r.POST("/api/chat", s.ChatHandler)
+	r.POST("/api/embeddings", s.EmbeddingsHandler)
+	r.POST("/api/create", s.CreateModelHandler)
+	r.POST("/api/push", s.PushModelHandler)
+	r.POST("/api/copy", s.CopyModelHandler)
+	r.DELETE("/api/delete", s.DeleteModelHandler)
+	r.POST("/api/show", s.ShowModelHandler)
+	r.POST("/api/blobs/:digest", s.CreateBlobHandler)
+	r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
 
 
 	// Compatibility endpoints
 	// Compatibility endpoints
-	r.POST("/v1/chat/completions", openai.Middleware(), ChatHandler)
+	r.POST("/v1/chat/completions", openai.Middleware(), s.ChatHandler)
 
 
 	for _, method := range []string{http.MethodGet, http.MethodHead} {
 	for _, method := range []string{http.MethodGet, http.MethodHead} {
 		r.Handle(method, "/", func(c *gin.Context) {
 		r.Handle(method, "/", func(c *gin.Context) {
 			c.String(http.StatusOK, "Ollama is running")
 			c.String(http.StatusOK, "Ollama is running")
 		})
 		})
 
 
-		r.Handle(method, "/api/tags", ListModelsHandler)
+		r.Handle(method, "/api/tags", s.ListModelsHandler)
 		r.Handle(method, "/api/version", func(c *gin.Context) {
 		r.Handle(method, "/api/version", func(c *gin.Context) {
 			c.JSON(http.StatusOK, gin.H{"version": version.Version})
 			c.JSON(http.StatusOK, gin.H{"version": version.Version})
 		})
 		})
@@ -1137,7 +1055,9 @@ func Serve(ln net.Listener) error {
 		}
 		}
 	}
 	}
 
 
-	s := &Server{addr: ln.Addr()}
+	ctx, done := context.WithCancel(context.Background())
+	sched := InitScheduler(ctx)
+	s := &Server{addr: ln.Addr(), sched: sched}
 	r := s.GenerateRoutes()
 	r := s.GenerateRoutes()
 
 
 	slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version))
 	slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version))
@@ -1150,7 +1070,8 @@ func Serve(ln net.Listener) error {
 	signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
 	signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
 	go func() {
 	go func() {
 		<-signals
 		<-signals
-		unload()
+		done()
+		sched.unloadAllRunners()
 		gpu.Cleanup()
 		gpu.Cleanup()
 		os.Exit(0)
 		os.Exit(0)
 	}()
 	}()
@@ -1158,12 +1079,12 @@ func Serve(ln net.Listener) error {
 	if err := llm.Init(); err != nil {
 	if err := llm.Init(); err != nil {
 		return fmt.Errorf("unable to initialize llm library %w", err)
 		return fmt.Errorf("unable to initialize llm library %w", err)
 	}
 	}
-	if runtime.GOOS == "linux" { // TODO - windows too
-		// check compatibility to log warnings
-		if _, err := gpu.CheckVRAM(); err != nil {
-			slog.Info(err.Error())
-		}
-	}
+
+	s.sched.Run(ctx)
+
+	// At startup we retrieve GPU information so we can get log messages before loading a model
+	// This will log warnings in case there are problems with the detected GPUs
+	_ = gpu.GetGPUInfo()
 
 
 	return srvr.Serve(ln)
 	return srvr.Serve(ln)
 }
 }
@@ -1219,9 +1140,9 @@ func streamResponse(c *gin.Context, ch chan any) {
 }
 }
 
 
 // ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
 // ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
-func chatPrompt(ctx context.Context, template string, messages []api.Message, numCtx int) (string, error) {
+func chatPrompt(ctx context.Context, runner *runnerRef, template string, messages []api.Message, numCtx int) (string, error) {
 	encode := func(s string) ([]int, error) {
 	encode := func(s string) ([]int, error) {
-		return loaded.llama.Tokenize(ctx, s)
+		return runner.llama.Tokenize(ctx, s)
 	}
 	}
 
 
 	prompt, err := ChatPrompt(template, messages, numCtx, encode)
 	prompt, err := ChatPrompt(template, messages, numCtx, encode)
@@ -1232,10 +1153,7 @@ func chatPrompt(ctx context.Context, template string, messages []api.Message, nu
 	return prompt, nil
 	return prompt, nil
 }
 }
 
 
-func ChatHandler(c *gin.Context) {
-	loaded.mu.Lock()
-	defer loaded.mu.Unlock()
-
+func (s *Server) ChatHandler(c *gin.Context) {
 	checkpointStart := time.Now()
 	checkpointStart := time.Now()
 
 
 	var req api.ChatRequest
 	var req api.ChatRequest
@@ -1292,7 +1210,11 @@ func ChatHandler(c *gin.Context) {
 		sessionDuration = req.KeepAlive.Duration
 		sessionDuration = req.KeepAlive.Duration
 	}
 	}
 
 
-	if err := load(c, model, opts, sessionDuration); err != nil {
+	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
+	var runner *runnerRef
+	select {
+	case runner = <-rCh:
+	case err = <-eCh:
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 		return
 	}
 	}
@@ -1309,7 +1231,7 @@ func ChatHandler(c *gin.Context) {
 		}, req.Messages...)
 		}, req.Messages...)
 	}
 	}
 
 
-	prompt, err := chatPrompt(c.Request.Context(), model.Template, req.Messages, opts.NumCtx)
+	prompt, err := chatPrompt(c.Request.Context(), runner, model.Template, req.Messages, opts.NumCtx)
 	if err != nil {
 	if err != nil {
 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 		return
@@ -1352,8 +1274,6 @@ func ChatHandler(c *gin.Context) {
 		defer close(ch)
 		defer close(ch)
 
 
 		fn := func(r llm.CompletionResponse) {
 		fn := func(r llm.CompletionResponse) {
-			// Update model expiration
-			loaded.expireTimer.Reset(sessionDuration)
 
 
 			resp := api.ChatResponse{
 			resp := api.ChatResponse{
 				Model:     req.Model,
 				Model:     req.Model,
@@ -1376,7 +1296,7 @@ func ChatHandler(c *gin.Context) {
 			ch <- resp
 			ch <- resp
 		}
 		}
 
 
-		if err := loaded.llama.Completion(c.Request.Context(), llm.CompletionRequest{
+		if err := runner.llama.Completion(c.Request.Context(), llm.CompletionRequest{
 			Prompt:  prompt,
 			Prompt:  prompt,
 			Format:  req.Format,
 			Format:  req.Format,
 			Images:  images,
 			Images:  images,
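
The handlers now share one shape: request a runner from the scheduler, select on the success and error channels, use runner.llama, and let the request context ending return the reference. A condensed sketch of that flow outside of gin, assuming GetModel and modelOptions are the existing server helpers and the usual context/strings/llm imports (error handling trimmed):

func generateOnce(ctx context.Context, s *Server, name, prompt string) (string, error) {
	model, err := GetModel(name)
	if err != nil {
		return "", err
	}
	opts, err := modelOptions(model, nil)
	if err != nil {
		return "", err
	}

	// Canceling ctx is what eventually returns the runner's reference to the scheduler.
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	rCh, eCh := s.sched.GetRunner(ctx, model, opts, defaultSessionDuration)
	var runner *runnerRef
	select {
	case runner = <-rCh:
	case err = <-eCh:
		return "", err
	}

	var sb strings.Builder
	err = runner.llama.Completion(ctx, llm.CompletionRequest{
		Prompt:  prompt,
		Options: opts,
	}, func(r llm.CompletionResponse) {
		sb.WriteString(r.Content)
	})
	return sb.String(), err
}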

+ 525 - 0
server/sched.go

@@ -0,0 +1,525 @@
+package server
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"log/slog"
+	"os"
+	"reflect"
+	"sort"
+	"strconv"
+	"strings"
+	"sync"
+	"time"
+
+	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/format"
+	"github.com/ollama/ollama/gpu"
+	"github.com/ollama/ollama/llm"
+	"golang.org/x/exp/slices"
+)
+
+type LlmRequest struct {
+	ctx             context.Context //nolint:containedctx
+	model           *Model
+	ggml            *llm.GGML // TODO - how large is this, and do we need to free it after we've finished loading?
+	opts            api.Options
+	sessionDuration time.Duration
+	successCh       chan *runnerRef
+	errCh           chan error
+}
+
+type Scheduler struct {
+	pendingReqCh  chan *LlmRequest
+	finishedReqCh chan *LlmRequest
+	expiredCh     chan *runnerRef
+	unloadedCh    chan interface{}
+
+	loaded   map[string]*runnerRef
+	loadedMu sync.Mutex
+
+	loadFn      func(req *LlmRequest, gpus gpu.GpuInfoList)
+	newServerFn func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error)
+	getGpuFn    func() gpu.GpuInfoList
+}
+
+// TODO set this to zero after a release or two, to enable multiple models by default
+var loadedMax = 1          // Maximum runners; < 1 maps to as many as will fit in VRAM (unlimited for CPU runners)
+var maxQueuedRequests = 10 // TODO configurable
+
+func InitScheduler(ctx context.Context) *Scheduler {
+	maxRunners := os.Getenv("OLLAMA_MAX_LOADED_MODELS")
+	if maxRunners != "" {
+		m, err := strconv.Atoi(maxRunners)
+		if err != nil {
+			slog.Error("invalid setting", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "error", err)
+		} else {
+			loadedMax = m
+		}
+	}
+
+	sched := &Scheduler{
+		pendingReqCh:  make(chan *LlmRequest, maxQueuedRequests),
+		finishedReqCh: make(chan *LlmRequest, maxQueuedRequests),
+		expiredCh:     make(chan *runnerRef, maxQueuedRequests),
+		unloadedCh:    make(chan interface{}, maxQueuedRequests),
+		loaded:        make(map[string]*runnerRef),
+		newServerFn:   llm.NewLlamaServer,
+		getGpuFn:      gpu.GetGPUInfo,
+	}
+	sched.loadFn = sched.load
+	return sched
+}
+
+// context must be canceled to decrement ref count and release the runner
+func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration time.Duration) (chan *runnerRef, chan error) {
+	ggml, err := llm.LoadModel(model.ModelPath)
+	req := &LlmRequest{
+		ctx:             c,
+		model:           model,
+		ggml:            ggml,
+		opts:            opts,
+		sessionDuration: sessionDuration,
+		successCh:       make(chan *runnerRef),
+		errCh:           make(chan error, 1),
+	}
+	if err != nil {
+		req.errCh <- err
+		return req.successCh, req.errCh
+	}
+	select {
+	case s.pendingReqCh <- req:
+	default:
+		req.errCh <- fmt.Errorf("server busy, please try again.  maximum pending requests exceeded")
+	}
+	return req.successCh, req.errCh
+}
+
+// Returns immediately; spawns goroutines for the scheduler that shut down when ctx is done
+func (s *Scheduler) Run(ctx context.Context) {
+	slog.Debug("starting llm scheduler")
+	go func() {
+		s.processPending(ctx)
+	}()
+
+	go func() {
+		s.processCompleted(ctx)
+	}()
+}
+
+func (s *Scheduler) processPending(ctx context.Context) {
+	for {
+		select {
+		case <-ctx.Done():
+			slog.Debug("shutting down scheduler pending loop")
+			return
+		case pending := <-s.pendingReqCh:
+			// Block other requests until we get this pending request running
+			for {
+				var runnerToExpire *runnerRef
+				s.loadedMu.Lock()
+				runner := s.loaded[pending.model.ModelPath]
+				loadedCount := len(s.loaded)
+				s.loadedMu.Unlock()
+				if runner != nil {
+					if runner.needsReload(ctx, pending) {
+						runnerToExpire = runner
+					} else {
+						// Runner is usable, return it
+						pending.useLoadedRunner(runner, s.finishedReqCh)
+						break
+					}
+				} else if loadedCount == 0 {
+					slog.Debug("loading first model", "model", pending.model.ModelPath)
+					gpus := s.getGpuFn()
+					g := pickBestFitGPUs(pending, gpus)
+					if g != nil {
+						gpus = g
+					}
+					s.loadFn(pending, gpus)
+					break
+				} else if loadedMax > 0 && loadedCount >= loadedMax {
+					slog.Debug("max runners achieved, unloading one to make room", "runner_count", loadedCount)
+					runnerToExpire = s.findRunnerToUnload(pending)
+				} else {
+					// More than one loaded model, so we have to see if the new one fits
+					// Get a refreshed GPU list
+					gpus := s.getGpuFn()
+					// Update free memory from currently loaded models
+					s.updateFreeSpace(gpus)
+					gpus = pickBestFitGPUs(pending, gpus)
+					if gpus != nil {
+						slog.Debug("new model fits with existing models, loading")
+						s.loadFn(pending, gpus)
+						break
+					}
+					runnerToExpire = s.findRunnerToUnload(pending)
+				}
+
+				if runnerToExpire == nil {
+					// Shouldn't happen
+					slog.Error("runner to expire was nil!")
+					continue
+				}
+				// Trigger an expiration to unload once it's done
+				runnerToExpire.refMu.Lock()
+				slog.Debug("resetting model to expire immediately to make room", "model", runnerToExpire.model, "refCount", runnerToExpire.refCount)
+				if runnerToExpire.expireTimer != nil {
+					runnerToExpire.expireTimer.Stop()
+					runnerToExpire.expireTimer = nil
+				}
+				runnerToExpire.sessionDuration = 0
+				if runnerToExpire.refCount <= 0 {
+					s.expiredCh <- runnerToExpire
+				}
+				runnerToExpire.refMu.Unlock()
+				// Wait for the unload to happen
+				// Note: at this point we're queueing up all incoming requests, even if they were for
+				// a different model that's loaded and not scheduled to be removed.
+				slog.Debug("waiting for pending requests to complete and unload to occur", "model", runnerToExpire.model)
+				select {
+				case <-ctx.Done():
+					slog.Debug("shutting down scheduler pending loop")
+					return
+				case <-s.unloadedCh:
+					slog.Debug("unload completed", "model", runnerToExpire.model)
+					continue
+				}
+			}
+		case <-s.unloadedCh:
+			// An unload request when there are no pending requests can be ignored
+			slog.Debug("ignoring unload event with no pending requests")
+		}
+	}
+}
+
+func (s *Scheduler) processCompleted(ctx context.Context) {
+	// Process completed requests, expired timers, and unloading models
+	for {
+		select {
+		case <-ctx.Done():
+			slog.Debug("shutting down scheduler completed loop")
+			return
+		case finished := <-s.finishedReqCh:
+			s.loadedMu.Lock()
+			runner := s.loaded[finished.model.ModelPath]
+			s.loadedMu.Unlock()
+			if runner == nil {
+				slog.Error("finished request signal received after model unloaded", "model", finished.model.ModelPath)
+				continue
+			}
+			runner.refMu.Lock()
+			runner.refCount--
+			if runner.refCount <= 0 {
+				if runner.sessionDuration <= 0 {
+					slog.Debug("runner with zero duration has gone idle, expiring to unload", "model", runner.model)
+					if runner.expireTimer != nil {
+						runner.expireTimer.Stop()
+						runner.expireTimer = nil
+					}
+					s.expiredCh <- runner
+				} else if runner.expireTimer == nil {
+					slog.Debug("runner with non-zero duration has gone idle, adding timer", "model", runner.model, "duration", runner.sessionDuration)
+					runner.expireTimer = time.AfterFunc(runner.sessionDuration, func() {
+						slog.Debug("timer expired, expiring to unload", "model", runner.model)
+						runner.refMu.Lock()
+						defer runner.refMu.Unlock()
+						if runner.expireTimer != nil {
+							runner.expireTimer.Stop()
+						}
+						s.expiredCh <- runner
+					})
+				} else {
+					slog.Debug("runner with non-zero duration has gone idle, resetting timer", "model", runner.model, "duration", runner.sessionDuration)
+					runner.expireTimer.Reset(runner.sessionDuration)
+				}
+			}
+			slog.Debug("after processing request finished event", "model", runner.model, "refCount", runner.refCount)
+			runner.refMu.Unlock()
+		case runner := <-s.expiredCh:
+			slog.Debug("runner expired event received", "model", runner.model)
+			runner.refMu.Lock()
+			if runner.refCount > 0 {
+				// Shouldn't happen, but safeguard to ensure no leaked runners
+				slog.Debug("expired event with positive ref count, retrying", "model", runner.model, "refCount", runner.refCount)
+				go func(runner *runnerRef) {
+					// We can't unload yet, but want to as soon as the current request completes
+					// So queue up another expired event
+					time.Sleep(10 * time.Millisecond)
+					s.expiredCh <- runner
+				}(runner)
+				runner.refMu.Unlock()
+				continue
+			}
+
+			slog.Debug("got lock to unload", "model", runner.model)
+			runner.unload()
+			s.loadedMu.Lock()
+			delete(s.loaded, runner.model)
+			s.loadedMu.Unlock()
+			slog.Debug("runner released", "model", runner.model)
+			runner.refMu.Unlock()
+			slog.Debug("sending an unloaded event", "model", runner.model)
+			s.unloadedCh <- struct{}{}
+		}
+	}
+}
+
+// Complete the pending request and send the runner back to the requester
+// Wires up a finished event after the request context is completed
+// Updates session duration, and resets expiration timer
+func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *LlmRequest) {
+	runner.refMu.Lock()
+	defer runner.refMu.Unlock()
+	runner.refCount++
+	runner.sessionDuration = pending.sessionDuration
+	pending.successCh <- runner
+	go func() {
+		<-pending.ctx.Done()
+		slog.Debug("context for request finished")
+		finished <- pending
+	}()
+}
+
+func (s *Scheduler) load(req *LlmRequest, gpus gpu.GpuInfoList) {
+	llama, err := s.newServerFn(gpus, req.model.ModelPath, req.ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts)
+	if err != nil {
+		// some older models are not compatible with newer versions of llama.cpp
+		// show a generalized compatibility error until there is a better way to
+		// check for model compatibility
+		if errors.Is(llm.ErrUnsupportedFormat, err) || strings.Contains(err.Error(), "failed to load model") {
+			err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, req.model.ShortName)
+		}
+		slog.Info("NewLlamaServer failed", "model", req.model.ModelPath, "error", err)
+		req.errCh <- err
+		return
+	}
+	runner := &runnerRef{}
+	runner.model = req.model.ModelPath
+	runner.adapters = req.model.AdapterPaths
+	runner.projectors = req.model.ProjectorPaths
+	runner.llama = llama
+	runner.Options = &req.opts
+	runner.sessionDuration = req.sessionDuration
+	runner.gpus = gpus
+	runner.estimatedVRAM = llama.EstimatedVRAM()
+	runner.loading = true
+	runner.refCount = 1
+	runner.refMu.Lock()
+	s.loadedMu.Lock()
+	s.loaded[req.model.ModelPath] = runner
+	slog.Info("loaded runners", "count", len(s.loaded))
+	s.loadedMu.Unlock()
+
+	go func() {
+		defer runner.refMu.Unlock()
+		if err = llama.WaitUntilRunning(req.ctx); err != nil {
+			slog.Error("error loading llama server", "error", err)
+			runner.refCount--
+			req.errCh <- err
+			slog.Debug("triggering expiration for failed load", "model", runner.model)
+			s.expiredCh <- runner
+			return
+		}
+		slog.Debug("finished setting up runner", "model", req.model.ModelPath)
+		runner.loading = false
+		go func() {
+			<-req.ctx.Done()
+			slog.Debug("context for request finished")
+			s.finishedReqCh <- req
+		}()
+		req.successCh <- runner
+	}()
+}
+
+func (s *Scheduler) updateFreeSpace(allGpus gpu.GpuInfoList) {
+	type predKey struct {
+		Library string
+		ID      string
+	}
+	predMap := map[predKey]uint64{} // Sum up the total predicted usage per GPU for all runners
+	s.loadedMu.Lock()
+	for _, r := range s.loaded {
+		r.refMu.Lock()
+		gpuIDs := make([]string, 0, len(r.gpus))
+		if r.llama != nil {
+
+			// TODO this should be broken down by GPU instead of assuming uniform spread
+			estimatedVRAMPerGPU := r.llama.EstimatedVRAM() / uint64(len(r.gpus))
+			for _, gpu := range r.gpus {
+				gpuIDs = append(gpuIDs, gpu.ID)
+			}
+			for _, gpu := range allGpus {
+				if slices.Contains(gpuIDs, gpu.ID) {
+					predMap[predKey{gpu.Library, gpu.ID}] += estimatedVRAMPerGPU
+				}
+			}
+		} else {
+			slog.Warn("unexpected nil runner reference, memory prediction may be incorrect")
+		}
+		r.refMu.Unlock()
+	}
+	s.loadedMu.Unlock()
+
+	// Now that we've summed up all the GPU usage predictions across all the loaded runners, update the gpu list
+	for i := range allGpus {
+		if p, ok := predMap[predKey{allGpus[i].Library, allGpus[i].ID}]; ok {
+			slog.Debug("gpu reported", "gpu", allGpus[i].ID, "library", allGpus[i].Library, "available", format.HumanBytes2(allGpus[i].FreeMemory))
+			if p > allGpus[i].TotalMemory {
+				// Shouldn't happen
+				slog.Warn("predicted usage exceeds VRAM", "gpu", allGpus[i].ID, "totalMemory", allGpus[i].TotalMemory, "predicted", p)
+				allGpus[i].FreeMemory = 0
+			} else if (allGpus[i].TotalMemory - p) < allGpus[i].FreeMemory { // predicted free is smaller than reported free, use it
+				// TODO maybe we should just always trust our own numbers, since CUDA's free memory reporting is laggy
+				// and we might unload models we didn't actually need to.  The risk is that if another GPU-intensive app is
+				// loaded after we start our first runner, we'll never account for it, so picking the smaller free value seems prudent.
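+				// e.g. 24 GiB total with 10 GiB predicted in use yields 14 GiB free, even if the driver still reports more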
+				allGpus[i].FreeMemory = allGpus[i].TotalMemory - p
+			}
+			slog.Info("updated VRAM", "gpu", allGpus[i].ID, "library", allGpus[i].Library, "total", format.HumanBytes2(allGpus[i].TotalMemory), "available", format.HumanBytes2(allGpus[i].FreeMemory))
+		}
+	}
+}
+
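+// runnerRef tracks one loaded llama server along with the reference count and session bookkeeping that keep it alive.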
+type runnerRef struct {
+	refMu sync.Mutex
+	// refCond   sync.Cond // Signaled on transition from 1 -> 0 refCount
+	refCount uint // prevent unloading if > 0
+	// unloading bool      // set to true when we are trying to unload the runner
+
+	llama         llm.LlamaServer
+	loading       bool            // True only during initial load, then false forever
+	gpus          gpu.GpuInfoList // Recorded at time of provisioning
+	estimatedVRAM uint64
+
+	sessionDuration time.Duration
+	expireTimer     *time.Timer
+
+	model      string
+	adapters   []string
+	projectors []string
+	*api.Options
+}
+
+// The refMu must already be held when calling unload
+func (runner *runnerRef) unload() {
+	if runner.llama != nil {
+		runner.llama.Close()
+	}
+	runner.llama = nil
+	runner.adapters = nil
+	runner.projectors = nil
+	runner.Options = nil
+	runner.gpus = nil
+}
+
+func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool {
+	slog.Debug("evaluating already loaded", "model", req.model.ModelPath)
+	runner.refMu.Lock()
+	defer runner.refMu.Unlock()
+	// Ignore the NumGPU settings for comparison
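+	// (NumGPU is excluded here; presumably it may vary per request without requiring a full reload)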
+	optsExisting := runner.Options.Runner
+	optsExisting.NumGPU = -1
+	optsNew := req.opts.Runner
+	optsNew.NumGPU = -1
+	timeout := 10 * time.Second
+	if runner.loading {
+		timeout = 2 * time.Minute // Initial load can take a long time for big models on slow systems...
+	}
+	ctx, cancel := context.WithTimeout(ctx, timeout) // BUG -
+	defer cancel()
+	if !reflect.DeepEqual(runner.adapters, req.model.AdapterPaths) || // have the adapters changed?
+		!reflect.DeepEqual(runner.projectors, req.model.ProjectorPaths) || // have the projectors changed?
+		!reflect.DeepEqual(optsExisting, optsNew) || // have the runner options changed?
+		runner.llama.Ping(ctx) != nil {
+		return true
+	}
+	return false
+}
+
+type ByDuration []*runnerRef
+
+func (a ByDuration) Len() int      { return len(a) }
+func (a ByDuration) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
+func (a ByDuration) Less(i, j int) bool {
+	// cast to uint64 so a negative duration (meaning never unload) sorts as the largest value
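+	// e.g. a sessionDuration of -1 (keep loaded indefinitely) becomes the maximum uint64 and sorts last, so it is the last candidate to unload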
+	return uint64(a[i].sessionDuration) < uint64(a[j].sessionDuration)
+}
+
+// TODO - future consideration to pick runners based on size
+// type BySize []*runnerRef
+// func (a BySize) Len() int           { return len(a) }
+// func (a BySize) Swap(i, j int)      { a[i], a[j] = a[j], a[i] }
+// func (a BySize) Less(i, j int) bool { return a[i].estimatedVRAM < a[j].estimatedVRAM }
+
+// pickBestFitGPUs tries to find the optimal placement of the model across the available GPUs where the model fully fits
+// If the model cannot fit fully within the available GPU(s), nil is returned
+func pickBestFitGPUs(req *LlmRequest, gpus gpu.GpuInfoList) gpu.GpuInfoList {
+	var estimatedVRAM uint64
+	for _, gl := range gpus.ByLibrary() {
+		var ok bool
+		sgl := append(make(gpu.GpuInfoList, 0, len(gl)), gl...)
+
+		// TODO - potentially sort by performance capability, existing models loaded, etc.
+		// Note: at present, this will favor more VRAM over faster GPU speed in mixed setups
+		sort.Sort(sort.Reverse(gpu.ByFreeMemory(sgl)))
+
+		// First attempt to fit the model into a single GPU
+		for _, g := range sgl {
+			if ok, estimatedVRAM = llm.PredictServerFit([]gpu.GpuInfo{g}, req.ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
+				slog.Debug("new model will fit in available VRAM in single GPU, loading", "model", req.model.ModelPath, "gpu", g.ID, "available", g.FreeMemory, "required", format.HumanBytes2(estimatedVRAM))
+				return []gpu.GpuInfo{g}
+			}
+		}
+
+		// TODO future refinements
+		// - if multiple Libraries, see if any single GPU in any Library will fit
+		// - try subsets of GPUs instead of just falling back to 1 or all in a family
+
+		// Now try all the GPUs
+		if ok, estimatedVRAM = llm.PredictServerFit(gl, req.ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
+			slog.Debug("new model will fit in available VRAM, loading", "model", req.model.ModelPath, "library", gl[0].Library, "required", format.HumanBytes2(estimatedVRAM))
+			return gl
+		}
+	}
+	return nil
+}
+
+// findRunnerToUnload finds a runner to unload to make room for a new model
+func (s *Scheduler) findRunnerToUnload(req *LlmRequest) *runnerRef {
+	s.loadedMu.Lock()
+	runnerList := make([]*runnerRef, 0, len(s.loaded))
+	for _, r := range s.loaded {
+		runnerList = append(runnerList, r)
+	}
+	s.loadedMu.Unlock()
+
+	// In the future we can enhance the algorithm to be smarter about picking the optimal runner to unload
+	// e.g., if we have multiple options, will one make room for the request?
+	sort.Sort(ByDuration(runnerList))
+
+	// First try to find a runner that's already idle
+	for _, runner := range runnerList {
+		runner.refMu.Lock()
+		rc := runner.refCount
+		runner.refMu.Unlock()
+		if rc == 0 {
+			slog.Debug("found an idle runner to unload")
+			return runner
+		}
+	}
+	// None appear idle, just wait for the one with the shortest duration
+	slog.Debug("no idle runners, picking the shortest duration", "count", len(runnerList))
+	return runnerList[0]
+}
+
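+// unloadAllRunners closes every loaded llama server.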
+func (s *Scheduler) unloadAllRunners() {
+	s.loadedMu.Lock()
+	defer s.loadedMu.Unlock()
+	for model, runner := range s.loaded {
+		if runner.llama != nil {
+			slog.Debug("shutting down runner", "model", model)
+			runner.llama.Close()
+		}
+	}
+}

+ 553 - 0
server/sched_test.go

@@ -0,0 +1,553 @@
+package server
+
+import (
+	"bytes"
+	"context"
+	"encoding/binary"
+	"fmt"
+	"log/slog"
+	"os"
+	"testing"
+	"time"
+
+	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/app/lifecycle"
+	"github.com/ollama/ollama/format"
+	"github.com/ollama/ollama/gpu"
+	"github.com/ollama/ollama/llm"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func init() {
+	os.Setenv("OLLAMA_DEBUG", "1")
+	lifecycle.InitLogging()
+}
+
+func TestInitScheduler(t *testing.T) {
+	ctx, done := context.WithCancel(context.Background())
+	defer done()
+	initialMax := loadedMax
+	s := InitScheduler(ctx)
+	require.Equal(t, initialMax, loadedMax)
+	require.NotNil(t, s.loaded)
+
+	os.Setenv("OLLAMA_MAX_LOADED_MODELS", "blue")
+	s = InitScheduler(ctx)
+	require.Equal(t, initialMax, loadedMax)
+	require.NotNil(t, s.loaded)
+
+	os.Setenv("OLLAMA_MAX_LOADED_MODELS", "0")
+	s = InitScheduler(ctx)
+	require.Equal(t, 0, loadedMax)
+	require.NotNil(t, s.loaded)
+}
+
+func TestLoad(t *testing.T) {
+	ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
+	defer done()
+	s := InitScheduler(ctx)
+	req := &LlmRequest{
+		ctx:             ctx,
+		model:           &Model{ModelPath: "foo"},
+		successCh:       make(chan *runnerRef, 1),
+		errCh:           make(chan error, 1),
+		sessionDuration: 2,
+	}
+	// Fail to load model first
+	s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) {
+		return nil, fmt.Errorf("something failed to load model blah")
+	}
+	gpus := gpu.GpuInfoList{}
+	s.load(req, gpus)
+	require.Len(t, req.successCh, 0)
+	require.Len(t, req.errCh, 1)
+	require.Len(t, s.loaded, 0)
+	err := <-req.errCh
+	require.Contains(t, err.Error(), "this model may be incompatible")
+
+	server := &mockLlm{estimatedVRAM: 10}
+	s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) {
+		return server, nil
+	}
+	s.load(req, gpus)
+	select {
+	case err := <-req.errCh:
+		require.NoError(t, err)
+	case resp := <-req.successCh:
+		require.Equal(t, uint64(10), resp.estimatedVRAM)
+		require.Equal(t, uint(1), resp.refCount)
+		require.Len(t, s.loaded, 1)
+	}
+
+	req.model.ModelPath = "dummy_model_path"
+	server.waitResp = fmt.Errorf("wait failure")
+	s.load(req, gpus)
+	select {
+	case err := <-req.errCh:
+		require.Contains(t, err.Error(), "wait failure")
+	case resp := <-req.successCh:
+		t.Errorf("unexpected success %v", resp)
+	}
+	runner := s.loaded["dummy_model_path"]
+	require.NotNil(t, runner)
+	require.Equal(t, uint(0), runner.refCount)
+	require.Len(t, s.expiredCh, 1)
+}
+
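+// bundle groups a mock llama server with a prepared request for a single test scenario.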
+type bundle struct {
+	ctx     context.Context //nolint:containedctx
+	ctxDone func()
+	srv     *mockLlm
+	req     *LlmRequest
+}
+
+func (scenario *bundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) {
+	return scenario.srv, nil
+}
+
+func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedVRAM uint64) *bundle {
+	scenario := &bundle{}
+	scenario.ctx, scenario.ctxDone = context.WithCancel(ctx)
+	t.Helper()
+
+	f, err := os.CreateTemp(t.TempDir(), modelName)
+	assert.Nil(t, err)
+	defer f.Close()
+
+	gguf := llm.NewGGUFV3(binary.LittleEndian)
+	err = gguf.Encode(f, llm.KV{
+		"general.architecture":          "llama",
+		"general.name":                  "name",
+		"llama.context_length":          uint32(32),
+		"llama.embedding_length":        uint32(4096),
+		"llama.block_count":             uint32(1),
+		"llama.attention.head_count":    uint32(32),
+		"llama.attention.head_count_kv": uint32(32),
+		"tokenizer.ggml.tokens":         []string{" "},
+		"tokenizer.ggml.scores":         []float32{0},
+		"tokenizer.ggml.token_type":     []int32{0},
+	}, []llm.Tensor{
+		{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
+	})
+	assert.Nil(t, err)
+	fname := f.Name()
+	model := &Model{Name: modelName, ModelPath: fname}
+	ggml, err := llm.LoadModel(model.ModelPath)
+	require.NoError(t, err)
+	scenario.req = &LlmRequest{
+		ctx:             scenario.ctx,
+		model:           model,
+		ggml:            ggml,
+		sessionDuration: 5 * time.Millisecond,
+		successCh:       make(chan *runnerRef, 1),
+		errCh:           make(chan error, 1),
+	}
+	scenario.srv = &mockLlm{estimatedVRAM: estimatedVRAM}
+	return scenario
+}
+
+func TestRequests(t *testing.T) {
+	ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
+	defer done()
+
+	// Same model, same request
+	scenario1a := newScenario(t, ctx, "ollama-model-1", 10)
+	scenario1a.req.sessionDuration = 0
+	scenario1b := newScenario(t, ctx, "ollama-model-1", 11)
+	scenario1b.req.model = scenario1a.req.model
+	scenario1b.req.ggml = scenario1a.req.ggml
+	scenario1b.req.sessionDuration = 0
+
+	// simple reload of same model
+	scenario2a := newScenario(t, ctx, "ollama-model-1", 20)
+	scenario2a.req.model = scenario1a.req.model
+	scenario2a.req.ggml = scenario1a.req.ggml
+
+	// Multiple loaded models
+	scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte)
+	scenario3b := newScenario(t, ctx, "ollama-model-3b", 24*format.GigaByte)
+	scenario3c := newScenario(t, ctx, "ollama-model-3c", 30) // Needs a prior model unloaded to make room
+
+	s := InitScheduler(ctx)
+	s.getGpuFn = func() gpu.GpuInfoList {
+		g := gpu.GpuInfo{Library: "metal"}
+		g.TotalMemory = 24 * format.GigaByte
+		g.FreeMemory = 12 * format.GigaByte
+		return []gpu.GpuInfo{g}
+	}
+	s.newServerFn = scenario1a.newServer
+	slog.Info("scenario1a")
+	s.pendingReqCh <- scenario1a.req
+	require.Len(t, s.pendingReqCh, 1)
+	s.Run(ctx)
+	select {
+	case resp := <-scenario1a.req.successCh:
+		require.Equal(t, resp.llama, scenario1a.srv)
+		require.Len(t, s.pendingReqCh, 0)
+		require.Len(t, scenario1a.req.errCh, 0)
+	case <-ctx.Done():
+		t.Errorf("timeout")
+	}
+
+	// Same runner as first request due to not needing a reload
+	s.newServerFn = scenario1b.newServer
+	slog.Info("scenario1b")
+	s.pendingReqCh <- scenario1b.req
+	select {
+	case resp := <-scenario1b.req.successCh:
+		require.Equal(t, resp.llama, scenario1a.srv)
+		require.Len(t, s.pendingReqCh, 0)
+		require.Len(t, scenario1b.req.errCh, 0)
+	case <-ctx.Done():
+		t.Errorf("timeout")
+	}
+
+	// Trigger a reload
+	s.newServerFn = scenario2a.newServer
+	scenario2a.req.model.AdapterPaths = []string{"new"}
+	slog.Info("scenario2a")
+	s.pendingReqCh <- scenario2a.req
+	// finish first two requests, so model can reload
+	time.Sleep(1 * time.Millisecond)
+	scenario1a.ctxDone()
+	scenario1b.ctxDone()
+	select {
+	case resp := <-scenario2a.req.successCh:
+		require.Equal(t, resp.llama, scenario2a.srv)
+		require.Len(t, s.pendingReqCh, 0)
+		require.Len(t, scenario2a.req.errCh, 0)
+	case <-ctx.Done():
+		t.Errorf("timeout")
+	}
+
+	loadedMax = 1
+	s.newServerFn = scenario3a.newServer
+	slog.Info("scenario3a")
+	s.pendingReqCh <- scenario3a.req
+	// finish prior request, so new model can load
+	time.Sleep(1 * time.Millisecond)
+	scenario2a.ctxDone()
+	select {
+	case resp := <-scenario3a.req.successCh:
+		require.Equal(t, resp.llama, scenario3a.srv)
+		require.Len(t, s.pendingReqCh, 0)
+		require.Len(t, scenario3a.req.errCh, 0)
+	case <-ctx.Done():
+		t.Errorf("timeout")
+	}
+	require.Len(t, s.loaded, 1)
+
+	loadedMax = 0
+	s.newServerFn = scenario3b.newServer
+	slog.Info("scenario3b")
+	s.pendingReqCh <- scenario3b.req
+	select {
+	case resp := <-scenario3b.req.successCh:
+		require.Equal(t, resp.llama, scenario3b.srv)
+		require.Len(t, s.pendingReqCh, 0)
+		require.Len(t, scenario3b.req.errCh, 0)
+	case <-ctx.Done():
+		t.Errorf("timeout")
+	}
+	require.Len(t, s.loaded, 2)
+
+	// Try to load a model that won't fit
+	s.newServerFn = scenario3c.newServer
+	slog.Info("scenario3c")
+	require.Len(t, s.loaded, 2)
+	scenario3a.ctxDone() // Won't help since this one isn't big enough to make room
+	time.Sleep(2 * time.Millisecond)
+	s.pendingReqCh <- scenario3c.req
+	// finish prior request, so new model can load
+	time.Sleep(6 * time.Millisecond)
+	require.Len(t, s.loaded, 1)
+	scenario3b.ctxDone()
+	select {
+	case resp := <-scenario3c.req.successCh:
+		require.Equal(t, resp.llama, scenario3c.srv)
+		require.Len(t, s.pendingReqCh, 0)
+		require.Len(t, scenario3c.req.errCh, 0)
+	case <-ctx.Done():
+		t.Errorf("timeout")
+	}
+	require.Len(t, s.loaded, 1)
+}
+
+func TestGetRunner(t *testing.T) {
+	ctx, done := context.WithTimeout(context.Background(), 20*time.Millisecond)
+	defer done()
+
+	// Same model, same request
+	scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
+	scenario1a.req.sessionDuration = 0
+	scenario1b := newScenario(t, ctx, "ollama-model-1b", 10)
+	scenario1b.req.sessionDuration = 0
+	scenario1c := newScenario(t, ctx, "ollama-model-1c", 10)
+	scenario1c.req.sessionDuration = 0
+	maxQueuedRequests = 1
+	s := InitScheduler(ctx)
+	s.getGpuFn = func() gpu.GpuInfoList {
+		g := gpu.GpuInfo{Library: "metal"}
+		g.TotalMemory = 24 * format.GigaByte
+		g.FreeMemory = 12 * format.GigaByte
+		return []gpu.GpuInfo{g}
+	}
+	s.newServerFn = scenario1a.newServer
+	slog.Info("scenario1a")
+	successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration)
+	require.Len(t, s.pendingReqCh, 1)
+	slog.Info("scenario1b")
+	successCh1b, errCh1b := s.GetRunner(scenario1b.ctx, scenario1b.req.model, scenario1b.req.opts, scenario1b.req.sessionDuration)
+	require.Len(t, s.pendingReqCh, 1)
+	require.Len(t, successCh1b, 0)
+	require.Len(t, errCh1b, 1)
+	err := <-errCh1b
+	require.Contains(t, err.Error(), "server busy")
+	s.Run(ctx)
+	select {
+	case resp := <-successCh1a:
+		require.Equal(t, resp.llama, scenario1a.srv)
+		require.Len(t, s.pendingReqCh, 0)
+		require.Len(t, errCh1a, 0)
+	case <-ctx.Done():
+		t.Errorf("timeout")
+	}
+	scenario1a.ctxDone()
+	require.Len(t, s.loaded, 1)
+
+	scenario1c.req.model.ModelPath = "bad path"
+	slog.Info("scenario1c")
+	successCh1c, errCh1c := s.GetRunner(scenario1c.ctx, scenario1c.req.model, scenario1c.req.opts, scenario1c.req.sessionDuration)
+	require.Len(t, s.pendingReqCh, 0)
+	require.Len(t, successCh1c, 0)
+	require.Len(t, errCh1c, 1)
+	err = <-errCh1c
+	require.Contains(t, err.Error(), "bad path")
+	scenario1b.ctxDone()
+
+	time.Sleep(5 * time.Millisecond)
+	require.Len(t, s.loaded, 0)
+}
+
+// TODO - add one scenario that triggers the bogus finished event with positive ref count
+func TestPrematureExpired(t *testing.T) {
+	ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
+	defer done()
+
+	// Same model, same request
+	scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
+	s := InitScheduler(ctx)
+	s.getGpuFn = func() gpu.GpuInfoList {
+		g := gpu.GpuInfo{Library: "metal"}
+		g.TotalMemory = 24 * format.GigaByte
+		g.FreeMemory = 12 * format.GigaByte
+		return []gpu.GpuInfo{g}
+	}
+	s.newServerFn = scenario1a.newServer
+	successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration)
+	require.Len(t, s.pendingReqCh, 1)
+	s.Run(ctx)
+	select {
+	case resp := <-successCh1a:
+		require.Equal(t, resp.llama, scenario1a.srv)
+		require.Len(t, s.pendingReqCh, 0)
+		require.Len(t, errCh1a, 0)
+		require.Len(t, s.loaded, 1)
+		slog.Info("sending premature expired event now")
+		s.expiredCh <- resp // Shouldn't happen in real life, but make sure it's safe
+	case <-ctx.Done():
+		t.Errorf("timeout")
+	}
+	time.Sleep(scenario1a.req.sessionDuration)
+	scenario1a.ctxDone()
+	time.Sleep(20 * time.Millisecond)
+	require.LessOrEqual(t, len(s.finishedReqCh), 1)
+	time.Sleep(10 * time.Millisecond)
+	require.Len(t, s.finishedReqCh, 0)
+	require.Len(t, s.loaded, 0)
+
+	// also shouldn't happen in real life
+	s.finishedReqCh <- scenario1a.req
+	time.Sleep(5 * time.Millisecond)
+}
+
+func TestUseLoadedRunner(t *testing.T) {
+	ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
+	req := &LlmRequest{
+		ctx:             ctx,
+		successCh:       make(chan *runnerRef, 1),
+		sessionDuration: 2,
+	}
+	finished := make(chan *LlmRequest)
+	llm1 := &mockLlm{}
+	r1 := &runnerRef{llama: llm1, sessionDuration: 1}
+	req.useLoadedRunner(r1, finished)
+	require.Equal(t, uint(1), r1.refCount)
+	require.Equal(t, time.Duration(2), r1.sessionDuration)
+	select {
+	case success := <-req.successCh:
+		require.Equal(t, r1, success)
+	case <-ctx.Done():
+		t.Errorf("timeout")
+	}
+	done()
+	fin := <-finished
+	require.Equal(t, req, fin)
+}
+
+func TestUpdateFreeSpace(t *testing.T) {
+	ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
+	defer done()
+	gpus := gpu.GpuInfoList{
+		{
+			Library: "a",
+			ID:      "1",
+		},
+		{
+			Library: "a",
+			ID:      "2",
+		},
+	}
+	gpus[0].TotalMemory = 1000
+	gpus[0].FreeMemory = 900
+	gpus[1].TotalMemory = 2000
+	gpus[1].FreeMemory = 1900
+	llm1 := &mockLlm{estimatedVRAM: 100}
+	llm2 := &mockLlm{estimatedVRAM: 200}
+	r1 := &runnerRef{llama: llm1, gpus: gpus}
+	r2 := &runnerRef{llama: llm2, gpus: gpus}
+
+	s := InitScheduler(ctx)
+	s.loaded["a"] = r1
+	s.loaded["b"] = r2
+
+	s.updateFreeSpace(gpus)
+	require.Equal(t, uint64(850), gpus[0].FreeMemory)
+	require.Equal(t, uint64(1850), gpus[1].FreeMemory)
+}
+
+func TestFindRunnerToUnload(t *testing.T) {
+	ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
+	defer done()
+	req := &LlmRequest{ctx: ctx}
+	r1 := &runnerRef{refCount: 1, sessionDuration: 1}
+	r2 := &runnerRef{sessionDuration: 2}
+
+	s := InitScheduler(ctx)
+	s.loaded["a"] = r1
+	s.loaded["b"] = r2
+
+	resp := s.findRunnerToUnload(req)
+	require.Equal(t, r2, resp)
+	r2.refCount = 1
+	resp = s.findRunnerToUnload(req)
+	require.Equal(t, r1, resp)
+}
+
+func TestNeedsReload(t *testing.T) {
+	ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
+	defer done()
+
+	llm := &mockLlm{}
+	runner := &runnerRef{
+		adapters:   []string{"adapter1"},
+		projectors: []string{"projector1"},
+		Options:    &api.Options{},
+		llama:      llm,
+	}
+	req := &LlmRequest{
+		model: &Model{
+			AdapterPaths:   []string{"adapter2"},
+			ProjectorPaths: []string{"projector2"},
+		},
+		opts: api.Options{},
+	}
+	resp := runner.needsReload(ctx, req)
+	require.True(t, resp)
+	req.model.AdapterPaths = runner.adapters
+	resp = runner.needsReload(ctx, req)
+	require.True(t, resp)
+	req.model.ProjectorPaths = runner.projectors
+	runner.loading = true
+	req.opts.NumBatch = 1234
+	resp = runner.needsReload(ctx, req)
+	require.True(t, resp)
+	req.opts.NumBatch = runner.Options.NumBatch
+	llm.pingResp = fmt.Errorf("foo")
+	resp = runner.needsReload(ctx, req)
+	require.True(t, resp)
+	llm.pingResp = nil
+	resp = runner.needsReload(ctx, req)
+	require.False(t, resp)
+	req.opts.NumGPU = 99
+	resp = runner.needsReload(ctx, req)
+	require.False(t, resp)
+}
+
+func TestUnloadAllRunners(t *testing.T) {
+	ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
+	defer done()
+
+	llm1 := &mockLlm{}
+	llm2 := &mockLlm{}
+	s := InitScheduler(ctx)
+	s.unloadAllRunners()
+
+	r1 := &runnerRef{llama: llm1}
+	r2 := &runnerRef{llama: llm2}
+
+	s.loaded["a"] = r1
+	s.loaded["b"] = r2
+	s.unloadAllRunners()
+
+	require.True(t, llm1.closeCalled)
+	require.True(t, llm2.closeCalled)
+}
+
+func TestUnload(t *testing.T) {
+	llm1 := &mockLlm{}
+	r1 := &runnerRef{llama: llm1}
+	r2 := &runnerRef{adapters: []string{"A"}}
+	r1.unload()
+	require.True(t, llm1.closeCalled)
+	r2.unload()
+	require.Nil(t, r2.adapters)
+}
+
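+// mockLlm is a test double for llm.LlamaServer that returns the canned responses configured in its fields.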
+type mockLlm struct {
+	pingResp          error
+	waitResp          error
+	completionResp    error
+	embeddingResp     []float64
+	embeddingRespErr  error
+	tokenizeResp      []int
+	tokenizeRespErr   error
+	detokenizeResp    string
+	detokenizeRespErr error
+	closeResp         error
+	closeCalled       bool
+	estimatedVRAM     uint64
+}
+
+func (s *mockLlm) Ping(ctx context.Context) error             { return s.pingResp }
+func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitResp }
+func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
+	return s.completionResp
+}
+func (s *mockLlm) Embedding(ctx context.Context, prompt string) ([]float64, error) {
+	return s.embeddingResp, s.embeddingRespErr
+}
+func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
+	return s.tokenizeResp, s.tokenizeRespErr
+}
+func (s *mockLlm) Detokenize(ctx context.Context, tokens []int) (string, error) {
+	return s.detokenizeResp, s.detokenizeRespErr
+}
+func (s *mockLlm) Close() error {
+	s.closeCalled = true
+	return s.closeResp
+}
+func (s *mockLlm) EstimatedVRAM() uint64 { return s.estimatedVRAM }