|
@@ -35,7 +35,6 @@ const (
|
|
|
)
|
|
|
|
|
|
var gpuMutex sync.Mutex
|
|
|
-var gpuHandles *handles = nil
|
|
|
|
|
|
// With our current CUDA compile flags, GPUs with a compute capability older than 5.0 will not work properly
|
|
|
var CudaComputeMin = [2]C.int{5, 0}
|
|
@@ -85,11 +84,11 @@ var CudartWindowsGlobs = []string{
|
|
|
var CudaTegra string = os.Getenv("JETSON_JETPACK")
|
|
|
|
|
|
// Note: gpuMutex must already be held
|
|
|
-func initGPUHandles() {
|
|
|
+func initGPUHandles() *handles {
|
|
|
|
|
|
// TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing
|
|
|
|
|
|
- gpuHandles = &handles{nil, nil}
|
|
|
+ gpuHandles := &handles{nil, nil}
|
|
|
var nvmlMgmtName string
|
|
|
var nvmlMgmtPatterns []string
|
|
|
var cudartMgmtName string
|
|
@@ -116,7 +115,7 @@ func initGPUHandles() {
|
|
|
}
|
|
|
cudartMgmtPatterns = append(cudartMgmtPatterns, CudartLinuxGlobs...)
|
|
|
default:
|
|
|
- return
|
|
|
+ return gpuHandles
|
|
|
}
|
|
|
|
|
|
slog.Info("Detecting GPU type")
|
|
@@ -126,7 +125,7 @@ func initGPUHandles() {
|
|
|
if cudart != nil {
|
|
|
slog.Info("Nvidia GPU detected via cudart")
|
|
|
gpuHandles.cudart = cudart
|
|
|
- return
|
|
|
+ return gpuHandles
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -137,10 +136,10 @@ func initGPUHandles() {
|
|
|
if nvml != nil {
|
|
|
slog.Info("Nvidia GPU detected via nvidia-ml")
|
|
|
gpuHandles.nvml = nvml
|
|
|
- return
|
|
|
+ return gpuHandles
|
|
|
}
|
|
|
}
|
|
|
-
|
|
|
+ return gpuHandles
|
|
|
}
|
|
|
|
|
|
func GetGPUInfo() GpuInfo {
|
|
@@ -148,9 +147,16 @@ func GetGPUInfo() GpuInfo {
|
|
|
// GPUs so we can report warnings if we see Nvidia/AMD but fail to load the libraries
|
|
|
gpuMutex.Lock()
|
|
|
defer gpuMutex.Unlock()
|
|
|
- if gpuHandles == nil {
|
|
|
- initGPUHandles()
|
|
|
- }
|
|
|
+
|
|
|
+ gpuHandles := initGPUHandles()
|
|
|
+ defer func() {
|
|
|
+ if gpuHandles.nvml != nil {
|
|
|
+ C.nvml_release(*gpuHandles.nvml)
|
|
|
+ }
|
|
|
+ if gpuHandles.cudart != nil {
|
|
|
+ C.cudart_release(*gpuHandles.cudart)
|
|
|
+ }
|
|
|
+ }()
|
|
|
|
|
|
// All our GPU builds on x86 have AVX enabled, so fallback to CPU if we don't detect at least AVX
|
|
|
cpuVariant := GetCPUVariant()
|