Преглед изворни кода

Harden GPU mgmt library lookup

When there are multiple management libraries installed on a system
not every one will be compatible with the current driver.  This change
improves our management library algorithm to build up a set of discovered
libraries based on glob patterns, and then try all of them until we're able to
load one without error.
Daniel Hiltgen пре 1 година
родитељ
комит
3c49c3ab0d
5 измењених фајлова са 167 додато и 64 уклоњено
  1. 153 17
      gpu/gpu.go
  2. 5 24
      gpu/gpu_info_cuda.c
  3. 1 1
      gpu/gpu_info_cuda.h
  4. 7 21
      gpu/gpu_info_rocm.c
  5. 1 1
      gpu/gpu_info_rocm.h

+ 153 - 17
gpu/gpu.go

@@ -13,7 +13,10 @@ import "C"
 import (
 	"fmt"
 	"log"
+	"os"
+	"path/filepath"
 	"runtime"
+	"strings"
 	"sync"
 	"unsafe"
 )
@@ -29,31 +32,79 @@ var gpuHandles *handles = nil
 // With our current CUDA compile flags, 5.2 and older will not work properly
 const CudaComputeMajorMin = 6
 
+// Possible locations for the nvidia-ml library
+var CudaLinuxGlobs = []string{
+	"/usr/local/cuda/lib64/libnvidia-ml.so*",
+	"/usr/lib/x86_64-linux-gnu/nvidia/current/libnvidia-ml.so*",
+	"/usr/lib/x86_64-linux-gnu/libnvidia-ml.so*",
+	"/usr/lib/wsl/lib/libnvidia-ml.so*",
+	"/opt/cuda/lib64/libnvidia-ml.so*",
+	"/usr/lib*/libnvidia-ml.so*",
+	"/usr/local/lib*/libnvidia-ml.so*",
+	"/usr/lib/aarch64-linux-gnu/nvidia/current/libnvidia-ml.so*",
+	"/usr/lib/aarch64-linux-gnu/libnvidia-ml.so*",
+}
+
+var CudaWindowsGlobs = []string{
+	"c:\\Windows\\System32\\nvml.dll",
+}
+
+var RocmLinuxGlobs = []string{
+	"/opt/rocm*/lib*/librocm_smi64.so*",
+}
+
+var RocmWindowsGlobs = []string{
+	"c:\\Windows\\System32\\rocm_smi64.dll",
+}
+
 // Note: gpuMutex must already be held
 func initGPUHandles() {
+
 	// TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing
+
+	var cudaMgmtName string
+	var cudaMgmtPatterns []string
+	var rocmMgmtName string
+	var rocmMgmtPatterns []string
+	switch runtime.GOOS {
+	case "windows":
+		cudaMgmtName = "nvml.dll"
+		cudaMgmtPatterns = make([]string, len(CudaWindowsGlobs))
+		copy(cudaMgmtPatterns, CudaWindowsGlobs)
+		rocmMgmtName = "rocm_smi64.dll"
+		rocmMgmtPatterns = make([]string, len(RocmWindowsGlobs))
+		copy(rocmMgmtPatterns, RocmWindowsGlobs)
+	case "linux":
+		cudaMgmtName = "libnvidia-ml.so"
+		cudaMgmtPatterns = make([]string, len(CudaLinuxGlobs))
+		copy(cudaMgmtPatterns, CudaLinuxGlobs)
+		rocmMgmtName = "librocm_smi64.so"
+		rocmMgmtPatterns = make([]string, len(RocmLinuxGlobs))
+		copy(rocmMgmtPatterns, RocmLinuxGlobs)
+	default:
+		return
+	}
+
 	log.Printf("Detecting GPU type")
 	gpuHandles = &handles{nil, nil}
-	var resp C.cuda_init_resp_t
-	C.cuda_init(&resp)
-	if resp.err != nil {
-		log.Printf("CUDA not detected: %s", C.GoString(resp.err))
-		C.free(unsafe.Pointer(resp.err))
+	cudaLibPaths := FindGPULibs(cudaMgmtName, cudaMgmtPatterns)
+	if len(cudaLibPaths) > 0 {
+		cuda := LoadCUDAMgmt(cudaLibPaths)
+		if cuda != nil {
+			log.Printf("Nvidia GPU detected")
+			gpuHandles.cuda = cuda
+			return
+		}
+	}
 
-		var resp C.rocm_init_resp_t
-		C.rocm_init(&resp)
-		if resp.err != nil {
-			log.Printf("ROCm not detected: %s", C.GoString(resp.err))
-			C.free(unsafe.Pointer(resp.err))
-		} else {
+	rocmLibPaths := FindGPULibs(rocmMgmtName, rocmMgmtPatterns)
+	if len(rocmLibPaths) > 0 {
+		rocm := LoadROCMMgmt(rocmLibPaths)
+		if rocm != nil {
 			log.Printf("Radeon GPU detected")
-			rocm := resp.rh
-			gpuHandles.rocm = &rocm
+			gpuHandles.rocm = rocm
+			return
 		}
-	} else {
-		log.Printf("Nvidia GPU detected")
-		cuda := resp.ch
-		gpuHandles.cuda = &cuda
 	}
 }
 
@@ -143,3 +194,88 @@ func CheckVRAM() (int64, error) {
 
 	return 0, fmt.Errorf("no GPU detected") // TODO - better handling of CPU based memory determiniation
 }
+
+func FindGPULibs(baseLibName string, patterns []string) []string {
+	// Multiple GPU libraries may exist, and some may not work, so keep trying until we exhaust them
+	var ldPaths []string
+	gpuLibPaths := []string{}
+	log.Printf("Searching for GPU management library %s", baseLibName)
+
+	switch runtime.GOOS {
+	case "windows":
+		ldPaths = strings.Split(os.Getenv("PATH"), ";")
+	case "linux":
+		ldPaths = strings.Split(os.Getenv("LD_LIBRARY_PATH"), ":")
+	default:
+		return gpuLibPaths
+	}
+	// Start with whatever we find in the PATH/LD_LIBRARY_PATH
+	for _, ldPath := range ldPaths {
+		d, err := filepath.Abs(ldPath)
+		if err != nil {
+			continue
+		}
+		patterns = append(patterns, filepath.Join(d, baseLibName+"*"))
+	}
+	for _, pattern := range patterns {
+		// Ignore glob discovery errors
+		matches, _ := filepath.Glob(pattern)
+		for _, match := range matches {
+			// Resolve any links so we don't try the same lib multiple times
+			// and weed out any dups across globs
+			libPath := match
+			tmp := match
+			var err error
+			for ; err == nil; tmp, err = os.Readlink(libPath) {
+				if !filepath.IsAbs(tmp) {
+					tmp = filepath.Join(filepath.Dir(libPath), tmp)
+				}
+				libPath = tmp
+			}
+			new := true
+			for _, cmp := range gpuLibPaths {
+				if cmp == libPath {
+					new = false
+					break
+				}
+			}
+			if new {
+				gpuLibPaths = append(gpuLibPaths, libPath)
+			}
+		}
+	}
+	log.Printf("Discovered GPU libraries: %v", gpuLibPaths)
+	return gpuLibPaths
+}
+
+func LoadCUDAMgmt(cudaLibPaths []string) *C.cuda_handle_t {
+	var resp C.cuda_init_resp_t
+	for _, libPath := range cudaLibPaths {
+		lib := C.CString(libPath)
+		defer C.free(unsafe.Pointer(lib))
+		C.cuda_init(lib, &resp)
+		if resp.err != nil {
+			log.Printf("Unable to load CUDA management library %s: %s", libPath, C.GoString(resp.err))
+			C.free(unsafe.Pointer(resp.err))
+		} else {
+			return &resp.ch
+		}
+	}
+	return nil
+}
+
+func LoadROCMMgmt(rocmLibPaths []string) *C.rocm_handle_t {
+	var resp C.rocm_init_resp_t
+	for _, libPath := range rocmLibPaths {
+		lib := C.CString(libPath)
+		defer C.free(unsafe.Pointer(lib))
+		C.rocm_init(lib, &resp)
+		if resp.err != nil {
+			log.Printf("Unable to load ROCm management library %s: %s", libPath, C.GoString(resp.err))
+			C.free(unsafe.Pointer(resp.err))
+		} else {
+			return &resp.rh
+		}
+	}
+	return nil
+}

+ 5 - 24
gpu/gpu_info_cuda.c

@@ -4,26 +4,9 @@
 
 #include <string.h>
 
-#ifndef _WIN32
-const char *cuda_lib_paths[] = {
-    "libnvidia-ml.so",
-    "/usr/local/cuda/lib64/libnvidia-ml.so",
-    "/usr/lib/x86_64-linux-gnu/nvidia/current/libnvidia-ml.so",
-    "/usr/lib/x86_64-linux-gnu/libnvidia-ml.so.1",
-    "/usr/lib/wsl/lib/libnvidia-ml.so.1",  // TODO Maybe glob?
-    NULL,
-};
-#else
-const char *cuda_lib_paths[] = {
-    "nvml.dll",
-    "",
-    NULL,
-};
-#endif
-
 #define CUDA_LOOKUP_SIZE 6
 
-void cuda_init(cuda_init_resp_t *resp) {
+void cuda_init(char *cuda_lib_path, cuda_init_resp_t *resp) {
   nvmlReturn_t ret;
   resp->err = NULL;
   const int buflen = 256;
@@ -42,16 +25,12 @@ void cuda_init(cuda_init_resp_t *resp) {
       {"nvmlDeviceGetCudaComputeCapability", (void *)&resp->ch.getComputeCapability},
   };
 
-  for (i = 0; cuda_lib_paths[i] != NULL && resp->ch.handle == NULL; i++) {
-    resp->ch.handle = LOAD_LIBRARY(cuda_lib_paths[i], RTLD_LAZY);
-  }
+  resp->ch.handle = LOAD_LIBRARY(cuda_lib_path, RTLD_LAZY);
   if (!resp->ch.handle) {
-    // TODO improve error message, as the LOAD_ERR will have typically have the
-    // final path that was checked which might be confusing.
     char *msg = LOAD_ERR();
     snprintf(buf, buflen,
              "Unable to load %s library to query for Nvidia GPUs: %s",
-             cuda_lib_paths[0], msg);
+             cuda_lib_path, msg);
     free(msg);
     resp->err = strdup(buf);
     return;
@@ -73,6 +52,8 @@ void cuda_init(cuda_init_resp_t *resp) {
 
   ret = (*resp->ch.initFn)();
   if (ret != NVML_SUCCESS) {
+    UNLOAD_LIBRARY(resp->ch.handle);
+    resp->ch.handle = NULL;
     snprintf(buf, buflen, "nvml vram init failure: %d", ret);
     resp->err = strdup(buf);
   }

+ 1 - 1
gpu/gpu_info_cuda.h

@@ -36,7 +36,7 @@ typedef struct cuda_compute_capability {
   int minor;
 } cuda_compute_capability_t;
 
-void cuda_init(cuda_init_resp_t *resp);
+void cuda_init(char *cuda_lib_path, cuda_init_resp_t *resp);
 void cuda_check_vram(cuda_handle_t ch, mem_info_t *resp);
 void cuda_compute_capability(cuda_handle_t ch, cuda_compute_capability_t *cc);
 

+ 7 - 21
gpu/gpu_info_rocm.c

@@ -4,22 +4,7 @@
 
 #include <string.h>
 
-#ifndef _WIN32
-const char *rocm_lib_paths[] = {
-    "librocm_smi64.so",
-    "/opt/rocm/lib/librocm_smi64.so",
-    NULL,
-};
-#else
-// TODO untested
-const char *rocm_lib_paths[] = {
-    "rocm_smi64.dll",
-    "/opt/rocm/lib/rocm_smi64.dll",
-    NULL,
-};
-#endif
-
-void rocm_init(rocm_init_resp_t *resp) {
+void rocm_init(char *rocm_lib_path, rocm_init_resp_t *resp) {
   rsmi_status_t ret;
   resp->err = NULL;
   const int buflen = 256;
@@ -36,14 +21,12 @@ void rocm_init(rocm_init_resp_t *resp) {
       // { "rsmi_dev_id_get", (void*)&resp->rh.getHandle },
   };
 
-  for (i = 0; rocm_lib_paths[i] != NULL && resp->rh.handle == NULL; i++) {
-    resp->rh.handle = LOAD_LIBRARY(rocm_lib_paths[i], RTLD_LAZY);
-  }
+  resp->rh.handle = LOAD_LIBRARY(rocm_lib_path, RTLD_LAZY);
   if (!resp->rh.handle) {
     char *msg = LOAD_ERR();
     snprintf(buf, buflen,
              "Unable to load %s library to query for Radeon GPUs: %s\n",
-             rocm_lib_paths[0], msg);
+             rocm_lib_path, msg);
     free(msg);
     resp->err = strdup(buf);
     return;
@@ -53,6 +36,7 @@ void rocm_init(rocm_init_resp_t *resp) {
     *l[i].p = LOAD_SYMBOL(resp->rh.handle, l[i].s);
     if (!l[i].p) {
       UNLOAD_LIBRARY(resp->rh.handle);
+      resp->rh.handle = NULL;
       char *msg = LOAD_ERR();
       snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s,
                msg);
@@ -64,6 +48,8 @@ void rocm_init(rocm_init_resp_t *resp) {
 
   ret = (*resp->rh.initFn)(0);
   if (ret != RSMI_STATUS_SUCCESS) {
+    UNLOAD_LIBRARY(resp->rh.handle);
+    resp->rh.handle = NULL;
     snprintf(buf, buflen, "rocm vram init failure: %d", ret);
     resp->err = strdup(buf);
   }
@@ -83,7 +69,7 @@ void rocm_check_vram(rocm_handle_t h, mem_info_t *resp) {
   int i;
 
   if (h.handle == NULL) {
-    resp->err = strdup("nvml handle sn't initialized");
+    resp->err = strdup("rocm handle not initialized");
     return;
   }
 

+ 1 - 1
gpu/gpu_info_rocm.h

@@ -29,7 +29,7 @@ typedef struct rocm_init_resp {
   rocm_handle_t rh;
 } rocm_init_resp_t;
 
-void rocm_init(rocm_init_resp_t *resp);
+void rocm_init(char *rocm_lib_path, rocm_init_resp_t *resp);
 void rocm_check_vram(rocm_handle_t rh, mem_info_t *resp);
 
 #endif  // __GPU_INFO_ROCM_H__