Browse Source

Merge pull request #4264 from dhiltgen/show_gpu_visible_settings

Centralize GPU configuration vars
Daniel Hiltgen 10 months ago
parent
commit
2786dff5d3
3 changed files with 33 additions and 6 deletions
  1. 26 1
      envconfig/config.go
  2. 5 4
      gpu/amd_linux.go
  3. 2 1
      gpu/amd_windows.go

+ 26 - 1
envconfig/config.go

@@ -57,6 +57,17 @@ var (
 	SchedSpread bool
 	// Set via OLLAMA_TMPDIR in the environment
 	TmpDir string
+
+	// Set via CUDA_VISIBLE_DEVICES in the environment
+	CudaVisibleDevices string
+	// Set via HIP_VISIBLE_DEVICES in the environment
+	HipVisibleDevices string
+	// Set via ROCR_VISIBLE_DEVICES in the environment
+	RocrVisibleDevices string
+	// Set via GPU_DEVICE_ORDINAL in the environment
+	GpuDeviceOrdinal string
+	// Set via HSA_OVERRIDE_GFX_VERSION in the environment
+	HsaOverrideGfxVersion string
 )
 
 type EnvVar struct {
@@ -66,7 +77,7 @@ type EnvVar struct {
 }
 
 func AsMap() map[string]EnvVar {
-	return map[string]EnvVar{
+	ret := map[string]EnvVar{
 		"OLLAMA_DEBUG":             {"OLLAMA_DEBUG", Debug, "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
 		"OLLAMA_FLASH_ATTENTION":   {"OLLAMA_FLASH_ATTENTION", FlashAttention, "Enabled flash attention"},
 		"OLLAMA_HOST":              {"OLLAMA_HOST", Host, "IP Address for the ollama server (default 127.0.0.1:11434)"},
@@ -84,6 +95,14 @@ func AsMap() map[string]EnvVar {
 		"OLLAMA_SCHED_SPREAD":      {"OLLAMA_SCHED_SPREAD", SchedSpread, "Always schedule model across all GPUs"},
 		"OLLAMA_TMPDIR":            {"OLLAMA_TMPDIR", TmpDir, "Location for temporary files"},
 	}
+	if runtime.GOOS != "darwin" {
+		ret["CUDA_VISIBLE_DEVICES"] = EnvVar{"CUDA_VISIBLE_DEVICES", CudaVisibleDevices, "Set which NVIDIA devices are visible"}
+		ret["HIP_VISIBLE_DEVICES"] = EnvVar{"HIP_VISIBLE_DEVICES", HipVisibleDevices, "Set which AMD devices are visible"}
+		ret["ROCR_VISIBLE_DEVICES"] = EnvVar{"ROCR_VISIBLE_DEVICES", RocrVisibleDevices, "Set which AMD devices are visible"}
+		ret["GPU_DEVICE_ORDINAL"] = EnvVar{"GPU_DEVICE_ORDINAL", GpuDeviceOrdinal, "Set which AMD devices are visible"}
+		ret["HSA_OVERRIDE_GFX_VERSION"] = EnvVar{"HSA_OVERRIDE_GFX_VERSION", HsaOverrideGfxVersion, "Override the gfx used for all detected AMD GPUs"}
+	}
+	return ret
 }
 
 func Values() map[string]string {
@@ -256,6 +275,12 @@ func LoadConfig() {
 	if err != nil {
 		slog.Error("invalid setting", "OLLAMA_HOST", Host, "error", err, "using default port", Host.Port)
 	}
+
+	CudaVisibleDevices = clean("CUDA_VISIBLE_DEVICES")
+	HipVisibleDevices = clean("HIP_VISIBLE_DEVICES")
+	RocrVisibleDevices = clean("ROCR_VISIBLE_DEVICES")
+	GpuDeviceOrdinal = clean("GPU_DEVICE_ORDINAL")
+	HsaOverrideGfxVersion = clean("HSA_OVERRIDE_GFX_VERSION")
 }
 
 func getModelsDir() (string, error) {

+ 5 - 4
gpu/amd_linux.go

@@ -13,6 +13,7 @@ import (
 	"strconv"
 	"strings"
 
+	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/format"
 )
 
@@ -59,9 +60,9 @@ func AMDGetGPUInfo() []RocmGPUInfo {
 
 	// Determine if the user has already pre-selected which GPUs to look at, then ignore the others
 	var visibleDevices []string
-	hipVD := os.Getenv("HIP_VISIBLE_DEVICES")   // zero based index only
-	rocrVD := os.Getenv("ROCR_VISIBLE_DEVICES") // zero based index or UUID, but consumer cards seem to not support UUID
-	gpuDO := os.Getenv("GPU_DEVICE_ORDINAL")    // zero based index
+	hipVD := envconfig.HipVisibleDevices   // zero based index only
+	rocrVD := envconfig.RocrVisibleDevices // zero based index or UUID, but consumer cards seem to not support UUID
+	gpuDO := envconfig.GpuDeviceOrdinal    // zero based index
 	switch {
 	// TODO is this priorty order right?
 	case hipVD != "":
@@ -74,7 +75,7 @@ func AMDGetGPUInfo() []RocmGPUInfo {
 		visibleDevices = strings.Split(gpuDO, ",")
 	}
 
-	gfxOverride := os.Getenv("HSA_OVERRIDE_GFX_VERSION")
+	gfxOverride := envconfig.HsaOverrideGfxVersion
 	var supported []string
 	libDir := ""
 

+ 2 - 1
gpu/amd_windows.go

@@ -10,6 +10,7 @@ import (
 	"strconv"
 	"strings"
 
+	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/format"
 )
 
@@ -53,7 +54,7 @@ func AMDGetGPUInfo() []RocmGPUInfo {
 	}
 
 	var supported []string
-	gfxOverride := os.Getenv("HSA_OVERRIDE_GFX_VERSION")
+	gfxOverride := envconfig.HsaOverrideGfxVersion
 	if gfxOverride == "" {
 		supported, err = GetSupportedGFX(libDir)
 		if err != nil {