Forráskód Böngészése

Refactor intel gpu discovery

Daniel Hiltgen 11 hónapja
szülő
commit
4e2b7e181d
4 módosított fájl, 297 hozzáadás és 162 törlés
  1. 118 58
      gpu/gpu.go
  2. 164 101
      gpu/gpu_info_oneapi.c
  3. 13 2
      gpu/gpu_info_oneapi.h
  4. 2 1
      gpu/types.go

+ 118 - 58
gpu/gpu.go

@@ -16,7 +16,6 @@ import (
 	"os"
 	"os"
 	"path/filepath"
 	"path/filepath"
 	"runtime"
 	"runtime"
-	"strconv"
 	"strings"
 	"strings"
 	"sync"
 	"sync"
 	"unsafe"
 	"unsafe"
@@ -25,16 +24,21 @@ import (
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/format"
 )
 )
 
 
-type handles struct {
+type cudaHandles struct {
 	deviceCount int
 	deviceCount int
 	cudart      *C.cudart_handle_t
 	cudart      *C.cudart_handle_t
 	nvcuda      *C.nvcuda_handle_t
 	nvcuda      *C.nvcuda_handle_t
+}
+
+type oneapiHandles struct {
 	oneapi      *C.oneapi_handle_t
 	oneapi      *C.oneapi_handle_t
+	deviceCount int
 }
 }
 
 
 const (
 const (
 	cudaMinimumMemory = 457 * format.MebiByte
 	cudaMinimumMemory = 457 * format.MebiByte
 	rocmMinimumMemory = 457 * format.MebiByte
 	rocmMinimumMemory = 457 * format.MebiByte
+	// TODO OneAPI minimum memory
 )
 )
 
 
 var (
 var (
@@ -107,19 +111,19 @@ var OneapiLinuxGlobs = []string{
 var CudaTegra string = os.Getenv("JETSON_JETPACK")
 var CudaTegra string = os.Getenv("JETSON_JETPACK")
 
 
 // Note: gpuMutex must already be held
 // Note: gpuMutex must already be held
-func initCudaHandles() *handles {
+func initCudaHandles() *cudaHandles {
 
 
 	// TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing
 	// TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing
 
 
-	gpuHandles := &handles{}
+	cHandles := &cudaHandles{}
 	// Short Circuit if we already know which library to use
 	// Short Circuit if we already know which library to use
 	if nvcudaLibPath != "" {
 	if nvcudaLibPath != "" {
-		gpuHandles.deviceCount, gpuHandles.nvcuda, _ = LoadNVCUDAMgmt([]string{nvcudaLibPath})
-		return gpuHandles
+		cHandles.deviceCount, cHandles.nvcuda, _ = LoadNVCUDAMgmt([]string{nvcudaLibPath})
+		return cHandles
 	}
 	}
 	if cudartLibPath != "" {
 	if cudartLibPath != "" {
-		gpuHandles.deviceCount, gpuHandles.cudart, _ = LoadCUDARTMgmt([]string{cudartLibPath})
-		return gpuHandles
+		cHandles.deviceCount, cHandles.cudart, _ = LoadCUDARTMgmt([]string{cudartLibPath})
+		return cHandles
 	}
 	}
 
 
 	slog.Debug("searching for GPU discovery libraries for NVIDIA")
 	slog.Debug("searching for GPU discovery libraries for NVIDIA")
@@ -127,8 +131,6 @@ func initCudaHandles() *handles {
 	var cudartMgmtPatterns []string
 	var cudartMgmtPatterns []string
 	var nvcudaMgmtName string
 	var nvcudaMgmtName string
 	var nvcudaMgmtPatterns []string
 	var nvcudaMgmtPatterns []string
-	var oneapiMgmtName string
-	var oneapiMgmtPatterns []string
 
 
 	tmpDir, _ := PayloadsDir()
 	tmpDir, _ := PayloadsDir()
 	switch runtime.GOOS {
 	switch runtime.GOOS {
@@ -140,8 +142,6 @@ func initCudaHandles() *handles {
 		// Aligned with driver, we can't carry as payloads
 		// Aligned with driver, we can't carry as payloads
 		nvcudaMgmtName = "nvcuda.dll"
 		nvcudaMgmtName = "nvcuda.dll"
 		nvcudaMgmtPatterns = NvcudaWindowsGlobs
 		nvcudaMgmtPatterns = NvcudaWindowsGlobs
-		oneapiMgmtName = "ze_intel_gpu64.dll"
-		oneapiMgmtPatterns = OneapiWindowsGlobs
 	case "linux":
 	case "linux":
 		cudartMgmtName = "libcudart.so*"
 		cudartMgmtName = "libcudart.so*"
 		if tmpDir != "" {
 		if tmpDir != "" {
@@ -152,10 +152,8 @@ func initCudaHandles() *handles {
 		// Aligned with driver, we can't carry as payloads
 		// Aligned with driver, we can't carry as payloads
 		nvcudaMgmtName = "libcuda.so*"
 		nvcudaMgmtName = "libcuda.so*"
 		nvcudaMgmtPatterns = NvcudaLinuxGlobs
 		nvcudaMgmtPatterns = NvcudaLinuxGlobs
-		oneapiMgmtName = "libze_intel_gpu.so"
-		oneapiMgmtPatterns = OneapiLinuxGlobs
 	default:
 	default:
-		return gpuHandles
+		return cHandles
 	}
 	}
 
 
 	nvcudaLibPaths := FindGPULibs(nvcudaMgmtName, nvcudaMgmtPatterns)
 	nvcudaLibPaths := FindGPULibs(nvcudaMgmtName, nvcudaMgmtPatterns)
@@ -163,10 +161,10 @@ func initCudaHandles() *handles {
 		deviceCount, nvcuda, libPath := LoadNVCUDAMgmt(nvcudaLibPaths)
 		deviceCount, nvcuda, libPath := LoadNVCUDAMgmt(nvcudaLibPaths)
 		if nvcuda != nil {
 		if nvcuda != nil {
 			slog.Debug("detected GPUs", "count", deviceCount, "library", libPath)
 			slog.Debug("detected GPUs", "count", deviceCount, "library", libPath)
-			gpuHandles.nvcuda = nvcuda
-			gpuHandles.deviceCount = deviceCount
+			cHandles.nvcuda = nvcuda
+			cHandles.deviceCount = deviceCount
 			nvcudaLibPath = libPath
 			nvcudaLibPath = libPath
-			return gpuHandles
+			return cHandles
 		}
 		}
 	}
 	}
 
 
@@ -175,26 +173,45 @@ func initCudaHandles() *handles {
 		deviceCount, cudart, libPath := LoadCUDARTMgmt(cudartLibPaths)
 		deviceCount, cudart, libPath := LoadCUDARTMgmt(cudartLibPaths)
 		if cudart != nil {
 		if cudart != nil {
 			slog.Debug("detected GPUs", "library", libPath, "count", deviceCount)
 			slog.Debug("detected GPUs", "library", libPath, "count", deviceCount)
-			gpuHandles.cudart = cudart
-			gpuHandles.deviceCount = deviceCount
+			cHandles.cudart = cudart
+			cHandles.deviceCount = deviceCount
 			cudartLibPath = libPath
 			cudartLibPath = libPath
-			return gpuHandles
+			return cHandles
 		}
 		}
 	}
 	}
 
 
+	return cHandles
+}
+
+// Note: gpuMutex must already be held
+func initOneAPIHandles() *oneapiHandles {
+	oHandles := &oneapiHandles{}
+	var oneapiMgmtName string
+	var oneapiMgmtPatterns []string
+
+	// Short Circuit if we already know which library to use
+	if oneapiLibPath != "" {
+		oHandles.deviceCount, oHandles.oneapi, _ = LoadOneapiMgmt([]string{oneapiLibPath})
+		return oHandles
+	}
+
+	switch runtime.GOOS {
+	case "windows":
+		oneapiMgmtName = "ze_intel_gpu64.dll"
+		oneapiMgmtPatterns = OneapiWindowsGlobs
+	case "linux":
+		oneapiMgmtName = "libze_intel_gpu.so"
+		oneapiMgmtPatterns = OneapiLinuxGlobs
+	default:
+		return oHandles
+	}
+
 	oneapiLibPaths := FindGPULibs(oneapiMgmtName, oneapiMgmtPatterns)
 	oneapiLibPaths := FindGPULibs(oneapiMgmtName, oneapiMgmtPatterns)
 	if len(oneapiLibPaths) > 0 {
 	if len(oneapiLibPaths) > 0 {
-		deviceCount, oneapi, libPath := LoadOneapiMgmt(oneapiLibPaths)
-		if oneapi != nil {
-			slog.Debug("detected Intel GPUs", "library", libPath, "count", deviceCount)
-			gpuHandles.oneapi = oneapi
-			gpuHandles.deviceCount = deviceCount
-			oneapiLibPath = libPath
-			return gpuHandles
-		}
+		oHandles.deviceCount, oHandles.oneapi, oneapiLibPath = LoadOneapiMgmt(oneapiLibPaths)
 	}
 	}
 
 
-	return gpuHandles
+	return oHandles
 }
 }
 
 
 func GetGPUInfo() GpuInfoList {
 func GetGPUInfo() GpuInfoList {
@@ -203,16 +220,22 @@ func GetGPUInfo() GpuInfoList {
 	gpuMutex.Lock()
 	gpuMutex.Lock()
 	defer gpuMutex.Unlock()
 	defer gpuMutex.Unlock()
 	needRefresh := true
 	needRefresh := true
-	var gpuHandles *handles
+	var cHandles *cudaHandles
+	var oHandles *oneapiHandles
 	defer func() {
 	defer func() {
-		if gpuHandles == nil {
-			return
-		}
-		if gpuHandles.cudart != nil {
-			C.cudart_release(*gpuHandles.cudart)
+		if cHandles != nil {
+			if cHandles.cudart != nil {
+				C.cudart_release(*cHandles.cudart)
+			}
+			if cHandles.nvcuda != nil {
+				C.nvcuda_release(*cHandles.nvcuda)
+			}
 		}
 		}
-		if gpuHandles.nvcuda != nil {
-			C.nvcuda_release(*gpuHandles.nvcuda)
+		if oHandles != nil {
+			if oHandles.oneapi != nil {
+				// TODO - is this needed?
+				C.oneapi_release(*oHandles.oneapi)
+			}
 		}
 		}
 	}()
 	}()
 
 
@@ -253,13 +276,11 @@ func GetGPUInfo() GpuInfoList {
 		}
 		}
 
 
 		// Load ALL libraries
 		// Load ALL libraries
-		gpuHandles = initCudaHandles()
-
-		// TODO needs a refactoring pass to init oneapi handles
+		cHandles = initCudaHandles()
 
 
 		// NVIDIA
 		// NVIDIA
-		for i := range gpuHandles.deviceCount {
-			if gpuHandles.cudart != nil || gpuHandles.nvcuda != nil {
+		for i := range cHandles.deviceCount {
+			if cHandles.cudart != nil || cHandles.nvcuda != nil {
 				gpuInfo := CudaGPUInfo{
 				gpuInfo := CudaGPUInfo{
 					GpuInfo: GpuInfo{
 					GpuInfo: GpuInfo{
 						Library: "cuda",
 						Library: "cuda",
@@ -268,12 +289,12 @@ func GetGPUInfo() GpuInfoList {
 				}
 				}
 				var driverMajor int
 				var driverMajor int
 				var driverMinor int
 				var driverMinor int
-				if gpuHandles.cudart != nil {
-					C.cudart_bootstrap(*gpuHandles.cudart, C.int(i), &memInfo)
+				if cHandles.cudart != nil {
+					C.cudart_bootstrap(*cHandles.cudart, C.int(i), &memInfo)
 				} else {
 				} else {
-					C.nvcuda_bootstrap(*gpuHandles.nvcuda, C.int(i), &memInfo)
-					driverMajor = int(gpuHandles.nvcuda.driver_major)
-					driverMinor = int(gpuHandles.nvcuda.driver_minor)
+					C.nvcuda_bootstrap(*cHandles.nvcuda, C.int(i), &memInfo)
+					driverMajor = int(cHandles.nvcuda.driver_major)
+					driverMinor = int(cHandles.nvcuda.driver_minor)
 				}
 				}
 				if memInfo.err != nil {
 				if memInfo.err != nil {
 					slog.Info("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
 					slog.Info("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
@@ -297,20 +318,35 @@ func GetGPUInfo() GpuInfoList {
 				// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
 				// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
 				cudaGPUs = append(cudaGPUs, gpuInfo)
 				cudaGPUs = append(cudaGPUs, gpuInfo)
 			}
 			}
-			if gpuHandles.oneapi != nil {
+		}
+
+		// Intel
+		oHandles = initOneAPIHandles()
+		for d := 0; oHandles.oneapi != nil && d < int(oHandles.oneapi.num_drivers); d++ {
+			if oHandles.oneapi == nil {
+				// shouldn't happen
+				slog.Warn("nil oneapi handle with driver count", "count", int(oHandles.oneapi.num_drivers))
+				continue
+			}
+			devCount := C.oneapi_get_device_count(*oHandles.oneapi, C.int(d))
+			for i := 0; i < int(devCount); i++ {
 				gpuInfo := OneapiGPUInfo{
 				gpuInfo := OneapiGPUInfo{
 					GpuInfo: GpuInfo{
 					GpuInfo: GpuInfo{
 						Library: "oneapi",
 						Library: "oneapi",
 					},
 					},
-					index: i,
+					driverIndex: d,
+					gpuIndex:    i,
 				}
 				}
 				// TODO - split bootstrapping from updating free memory
 				// TODO - split bootstrapping from updating free memory
-				C.oneapi_check_vram(*gpuHandles.oneapi, &memInfo)
+				C.oneapi_check_vram(*oHandles.oneapi, C.int(d), C.int(i), &memInfo)
+				// TODO - convert this to MinimumMemory based on testing...
 				var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend.
 				var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend.
 				memInfo.free = C.uint64_t(totalFreeMem)
 				memInfo.free = C.uint64_t(totalFreeMem)
 				gpuInfo.TotalMemory = uint64(memInfo.total)
 				gpuInfo.TotalMemory = uint64(memInfo.total)
 				gpuInfo.FreeMemory = uint64(memInfo.free)
 				gpuInfo.FreeMemory = uint64(memInfo.free)
-				gpuInfo.ID = strconv.Itoa(i)
+				gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
+				gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
+				// TODO dependency path?
 				oneapiGPUs = append(oneapiGPUs, gpuInfo)
 				oneapiGPUs = append(oneapiGPUs, gpuInfo)
 			}
 			}
 		}
 		}
@@ -325,14 +361,14 @@ func GetGPUInfo() GpuInfoList {
 	if needRefresh {
 	if needRefresh {
 		// TODO - CPU system memory tracking/refresh
 		// TODO - CPU system memory tracking/refresh
 		var memInfo C.mem_info_t
 		var memInfo C.mem_info_t
-		if gpuHandles == nil && len(cudaGPUs) > 0 {
-			gpuHandles = initCudaHandles()
+		if cHandles == nil && len(cudaGPUs) > 0 {
+			cHandles = initCudaHandles()
 		}
 		}
 		for i, gpu := range cudaGPUs {
 		for i, gpu := range cudaGPUs {
-			if gpuHandles.cudart != nil {
-				C.cudart_bootstrap(*gpuHandles.cudart, C.int(gpu.index), &memInfo)
+			if cHandles.cudart != nil {
+				C.cudart_bootstrap(*cHandles.cudart, C.int(gpu.index), &memInfo)
 			} else {
 			} else {
-				C.nvcuda_get_free(*gpuHandles.nvcuda, C.int(gpu.index), &memInfo.free)
+				C.nvcuda_get_free(*cHandles.nvcuda, C.int(gpu.index), &memInfo.free)
 			}
 			}
 			if memInfo.err != nil {
 			if memInfo.err != nil {
 				slog.Warn("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
 				slog.Warn("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
@@ -346,6 +382,23 @@ func GetGPUInfo() GpuInfoList {
 			slog.Debug("updating cuda free memory", "gpu", gpu.ID, "name", gpu.Name, "before", format.HumanBytes2(gpu.FreeMemory), "now", format.HumanBytes2(uint64(memInfo.free)))
 			slog.Debug("updating cuda free memory", "gpu", gpu.ID, "name", gpu.Name, "before", format.HumanBytes2(gpu.FreeMemory), "now", format.HumanBytes2(uint64(memInfo.free)))
 			cudaGPUs[i].FreeMemory = uint64(memInfo.free)
 			cudaGPUs[i].FreeMemory = uint64(memInfo.free)
 		}
 		}
+
+		if oHandles == nil && len(oneapiGPUs) > 0 {
+			oHandles = initOneAPIHandles()
+		}
+		for i, gpu := range oneapiGPUs {
+			if oHandles.oneapi == nil {
+				// shouldn't happen
+				slog.Warn("nil oneapi handle with device count", "count", oHandles.deviceCount)
+				continue
+			}
+			C.oneapi_check_vram(*oHandles.oneapi, C.int(gpu.driverIndex), C.int(gpu.gpuIndex), &memInfo)
+			// TODO - convert this to MinimumMemory based on testing...
+			var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend.
+			memInfo.free = C.uint64_t(totalFreeMem)
+			oneapiGPUs[i].FreeMemory = uint64(memInfo.free)
+		}
+
 		err := RocmGPUInfoList(rocmGPUs).RefreshFreeMemory()
 		err := RocmGPUInfoList(rocmGPUs).RefreshFreeMemory()
 		if err != nil {
 		if err != nil {
 			slog.Debug("problem refreshing ROCm free memory", "error", err)
 			slog.Debug("problem refreshing ROCm free memory", "error", err)
@@ -359,6 +412,9 @@ func GetGPUInfo() GpuInfoList {
 	for _, gpu := range rocmGPUs {
 	for _, gpu := range rocmGPUs {
 		resp = append(resp, gpu.GpuInfo)
 		resp = append(resp, gpu.GpuInfo)
 	}
 	}
+	for _, gpu := range oneapiGPUs {
+		resp = append(resp, gpu.GpuInfo)
+	}
 	if len(resp) == 0 {
 	if len(resp) == 0 {
 		resp = append(resp, cpus[0].GpuInfo)
 		resp = append(resp, cpus[0].GpuInfo)
 	}
 	}
@@ -476,6 +532,7 @@ func LoadNVCUDAMgmt(nvcudaLibPaths []string) (int, *C.nvcuda_handle_t, string) {
 
 
 func LoadOneapiMgmt(oneapiLibPaths []string) (int, *C.oneapi_handle_t, string) {
 func LoadOneapiMgmt(oneapiLibPaths []string) (int, *C.oneapi_handle_t, string) {
 	var resp C.oneapi_init_resp_t
 	var resp C.oneapi_init_resp_t
+	num_devices := 0
 	resp.oh.verbose = getVerboseState()
 	resp.oh.verbose = getVerboseState()
 	for _, libPath := range oneapiLibPaths {
 	for _, libPath := range oneapiLibPaths {
 		lib := C.CString(libPath)
 		lib := C.CString(libPath)
@@ -485,7 +542,10 @@ func LoadOneapiMgmt(oneapiLibPaths []string) (int, *C.oneapi_handle_t, string) {
 			slog.Debug("Unable to load oneAPI management library", "library", libPath, "error", C.GoString(resp.err))
 			slog.Debug("Unable to load oneAPI management library", "library", libPath, "error", C.GoString(resp.err))
 			C.free(unsafe.Pointer(resp.err))
 			C.free(unsafe.Pointer(resp.err))
 		} else {
 		} else {
-			return int(resp.num_devices), &resp.oh, libPath
+			for i := 0; i < int(resp.oh.num_drivers); i++ {
+				num_devices += int(C.oneapi_get_device_count(resp.oh, C.int(i)))
+			}
+			return num_devices, &resp.oh, libPath
 		}
 		}
 	}
 	}
 	return 0, nil, ""
 	return 0, nil, ""

+ 164 - 101
gpu/gpu_info_oneapi.c

@@ -8,9 +8,13 @@ void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp)
 {
 {
   ze_result_t ret;
   ze_result_t ret;
   resp->err = NULL;
   resp->err = NULL;
+  resp->oh.devices = NULL;
+  resp->oh.num_devices = NULL;
+  resp->oh.drivers = NULL;
+  resp->oh.num_drivers = 0;
   const int buflen = 256;
   const int buflen = 256;
   char buf[buflen + 1];
   char buf[buflen + 1];
-  int i;
+  int i, d, count;
   struct lookup
   struct lookup
   {
   {
     char *s;
     char *s;
@@ -66,19 +70,65 @@ void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp)
   ret = (*resp->oh.zesInit)(0);
   ret = (*resp->oh.zesInit)(0);
   if (ret != ZE_RESULT_SUCCESS)
   if (ret != ZE_RESULT_SUCCESS)
   {
   {
-    LOG(resp->oh.verbose, "zesInit err: %d\n", ret);
-    UNLOAD_LIBRARY(resp->oh.handle);
-    resp->oh.handle = NULL;
-    snprintf(buf, buflen, "oneapi vram init failure: %d", ret);
+    LOG(resp->oh.verbose, "zesInit err: %x\n", ret);
+    snprintf(buf, buflen, "oneapi vram init failure: %x", ret);
     resp->err = strdup(buf);
     resp->err = strdup(buf);
+    oneapi_release(resp->oh);
+    return;
   }
   }
 
 
-  (*resp->oh.zesDriverGet)(&resp->num_devices, NULL);
+  count = 0;
+  ret = (*resp->oh.zesDriverGet)(&resp->oh.num_drivers, NULL);
+  if (ret != ZE_RESULT_SUCCESS)
+  {
+    LOG(resp->oh.verbose, "zesDriverGet err: %x\n", ret);
+    snprintf(buf, buflen, "unable to get driver count: %x", ret);
+    resp->err = strdup(buf);
+    oneapi_release(resp->oh);
+    return;
+  }
+  LOG(resp->oh.verbose, "oneapi driver count: %d\n", resp->oh.num_drivers);
+  resp->oh.drivers = malloc(resp->oh.num_drivers * sizeof(zes_driver_handle_t));
+  resp->oh.num_devices = malloc(resp->oh.num_drivers * sizeof(uint32_t));
+  memset(&resp->oh.num_devices[0], 0, resp->oh.num_drivers * sizeof(uint32_t));
+  resp->oh.devices = malloc(resp->oh.num_drivers * sizeof(zes_device_handle_t*));
+  ret = (*resp->oh.zesDriverGet)(&resp->oh.num_drivers, &resp->oh.drivers[0]);
+  if (ret != ZE_RESULT_SUCCESS)
+  {
+    LOG(resp->oh.verbose, "zesDriverGet err: %x\n", ret);
+    snprintf(buf, buflen, "unable to get driver count: %x", ret);
+    resp->err = strdup(buf);
+    oneapi_release(resp->oh);
+    return;
+  }
+
+  for (d = 0; d < resp->oh.num_drivers; d++) {
+    ret = (*resp->oh.zesDeviceGet)(resp->oh.drivers[d], &resp->oh.num_devices[d], NULL);
+    if (ret != ZE_RESULT_SUCCESS)
+    {
+      LOG(resp->oh.verbose, "zesDeviceGet err: %x\n", ret);
+      snprintf(buf, buflen, "unable to get device count: %x", ret);
+      resp->err = strdup(buf);
+      oneapi_release(resp->oh);
+      return;
+    }
+    resp->oh.devices[d] = malloc(resp->oh.num_devices[d] * sizeof(zes_device_handle_t));
+    ret = (*resp->oh.zesDeviceGet)(resp->oh.drivers[d], &resp->oh.num_devices[d], resp->oh.devices[d]);
+    if (ret != ZE_RESULT_SUCCESS)
+    {
+      LOG(resp->oh.verbose, "zesDeviceGet err: %x\n", ret);
+      snprintf(buf, buflen, "unable to get device count: %x", ret);
+      resp->err = strdup(buf);
+      oneapi_release(resp->oh);
+      return;
+    }
+    count += resp->oh.num_devices[d];
+  }
 
 
   return;
   return;
 }
 }
 
 
-void oneapi_check_vram(oneapi_handle_t h, mem_info_t *resp)
+void oneapi_check_vram(oneapi_handle_t h, int driver, int device, mem_info_t *resp)
 {
 {
   ze_result_t ret;
   ze_result_t ret;
   resp->err = NULL;
   resp->err = NULL;
@@ -93,122 +143,135 @@ void oneapi_check_vram(oneapi_handle_t h, mem_info_t *resp)
     resp->err = strdup("Level-Zero handle not initialized");
     resp->err = strdup("Level-Zero handle not initialized");
     return;
     return;
   }
   }
+  
+  if (driver > h.num_drivers || device > h.num_devices[driver]) {
+    resp->err = strdup("driver of device index out of bounds");
+    return;
+  }
 
 
-  uint32_t driversCount = 0;
-  ret = (*h.zesDriverGet)(&driversCount, NULL);
+  resp->total = 0;
+  resp->free = 0;
+
+  zes_device_ext_properties_t ext_props;
+  ext_props.stype = ZES_STRUCTURE_TYPE_DEVICE_EXT_PROPERTIES;
+  ext_props.pNext = NULL;
+
+  zes_device_properties_t props;
+  props.stype = ZES_STRUCTURE_TYPE_DEVICE_PROPERTIES;
+  props.pNext = &ext_props;
+
+  ret = (*h.zesDeviceGetProperties)(h.devices[driver][device], &props);
   if (ret != ZE_RESULT_SUCCESS)
   if (ret != ZE_RESULT_SUCCESS)
   {
   {
-    snprintf(buf, buflen, "unable to get driver count: %d", ret);
+    snprintf(buf, buflen, "unable to get device properties: %d", ret);
     resp->err = strdup(buf);
     resp->err = strdup(buf);
     return;
     return;
   }
   }
-  LOG(h.verbose, "discovered %d Level-Zero drivers\n", driversCount);
 
 
-  zes_driver_handle_t *allDrivers =
-      malloc(driversCount * sizeof(zes_driver_handle_t));
-  (*h.zesDriverGet)(&driversCount, allDrivers);
+  snprintf(&resp->gpu_name[0], GPU_NAME_LEN, props.modelName);
 
 
-  resp->total = 0;
-  resp->free = 0;
+  // TODO this needs to map to ONEAPI_DEVICE_SELECTOR syntax
+  // (this is probably wrong...)
+  // TODO - the driver isn't included - what if there are multiple drivers?
+  snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", device);
 
 
-  for (d = 0; d < driversCount; d++)
+  if (h.verbose)
   {
   {
-    uint32_t deviceCount = 0;
-    ret = (*h.zesDeviceGet)(allDrivers[d], &deviceCount, NULL);
+    // When in verbose mode, report more information about
+    // the card we discover.
+    LOG(h.verbose, "[%d:%d] oneAPI device name: %s\n", driver, device,
+        props.modelName);
+    LOG(h.verbose, "[%d:%d] oneAPI brand: %s\n", driver, device,
+        props.brandName);
+    LOG(h.verbose, "[%d:%d] oneAPI vendor: %s\n", driver, device,
+        props.vendorName);
+    LOG(h.verbose, "[%d:%d] oneAPI S/N: %s\n", driver, device,
+        props.serialNumber);
+    LOG(h.verbose, "[%d:%d] oneAPI board number: %s\n", driver, device,
+        props.boardNumber);
+  }
+
+  // TODO
+  // Compute Capability equivalent in resp->major, resp->minor, resp->patch
+
+  uint32_t memCount = 0;
+  ret = (*h.zesDeviceEnumMemoryModules)(h.devices[driver][device], &memCount, NULL);
+  if (ret != ZE_RESULT_SUCCESS)
+  {
+    snprintf(buf, buflen,
+              "unable to enumerate Level-Zero memory modules: %x", ret);
+    resp->err = strdup(buf);
+    return;
+  }
+
+  LOG(h.verbose, "discovered %d Level-Zero memory modules\n", memCount);
+
+  zes_mem_handle_t *mems = malloc(memCount * sizeof(zes_mem_handle_t));
+  (*h.zesDeviceEnumMemoryModules)(h.devices[driver][device], &memCount, mems);
+
+  for (m = 0; m < memCount; m++)
+  {
+    zes_mem_state_t state;
+    state.stype = ZES_STRUCTURE_TYPE_MEM_STATE;
+    state.pNext = NULL;
+    ret = (*h.zesMemoryGetState)(mems[m], &state);
     if (ret != ZE_RESULT_SUCCESS)
     if (ret != ZE_RESULT_SUCCESS)
     {
     {
-      snprintf(buf, buflen, "unable to get device count: %d", ret);
+      snprintf(buf, buflen, "unable to get memory state: %x", ret);
       resp->err = strdup(buf);
       resp->err = strdup(buf);
-      free(allDrivers);
+      free(mems);
       return;
       return;
     }
     }
 
 
-    LOG(h.verbose, "discovered %d Level-Zero devices\n", deviceCount);
+    resp->total += state.size;
+    resp->free += state.free;
+  }
 
 
-    zes_device_handle_t *devices =
-        malloc(deviceCount * sizeof(zes_device_handle_t));
-    (*h.zesDeviceGet)(allDrivers[d], &deviceCount, devices);
+  free(mems);
+}
 
 
-    for (i = 0; i < deviceCount; i++)
+void oneapi_release(oneapi_handle_t h)
+{
+  int d;
+  LOG(h.verbose, "releasing oneapi library\n");
+  for (d = 0; d < h.num_drivers; d++)
+  {
+    if (h.devices != NULL && h.devices[d] != NULL)
     {
     {
-      zes_device_ext_properties_t ext_props;
-      ext_props.stype = ZES_STRUCTURE_TYPE_DEVICE_EXT_PROPERTIES;
-      ext_props.pNext = NULL;
-
-      zes_device_properties_t props;
-      props.stype = ZES_STRUCTURE_TYPE_DEVICE_PROPERTIES;
-      props.pNext = &ext_props;
-
-      ret = (*h.zesDeviceGetProperties)(devices[i], &props);
-      if (ret != ZE_RESULT_SUCCESS)
-      {
-        snprintf(buf, buflen, "unable to get device properties: %d", ret);
-        resp->err = strdup(buf);
-        free(allDrivers);
-        free(devices);
-        return;
-      }
-
-      if (h.verbose)
-      {
-        // When in verbose mode, report more information about
-        // the card we discover.
-        LOG(h.verbose, "[%d] oneAPI device name: %s\n", i,
-            props.modelName);
-        LOG(h.verbose, "[%d] oneAPI brand: %s\n", i,
-            props.brandName);
-        LOG(h.verbose, "[%d] oneAPI vendor: %s\n", i,
-            props.vendorName);
-        LOG(h.verbose, "[%d] oneAPI S/N: %s\n", i,
-            props.serialNumber);
-        LOG(h.verbose, "[%d] oneAPI board number: %s\n", i,
-            props.boardNumber);
-      }
-
-      uint32_t memCount = 0;
-      ret = (*h.zesDeviceEnumMemoryModules)(devices[i], &memCount, NULL);
-      if (ret != ZE_RESULT_SUCCESS)
-      {
-        snprintf(buf, buflen,
-                 "unable to enumerate Level-Zero memory modules: %d", ret);
-        resp->err = strdup(buf);
-        free(allDrivers);
-        free(devices);
-        return;
-      }
-
-      LOG(h.verbose, "discovered %d Level-Zero memory modules\n", memCount);
-
-      zes_mem_handle_t *mems = malloc(memCount * sizeof(zes_mem_handle_t));
-      (*h.zesDeviceEnumMemoryModules)(devices[i], &memCount, mems);
-
-      for (m = 0; m < memCount; m++)
-      {
-        zes_mem_state_t state;
-        state.stype = ZES_STRUCTURE_TYPE_MEM_STATE;
-        state.pNext = NULL;
-        ret = (*h.zesMemoryGetState)(mems[m], &state);
-        if (ret != ZE_RESULT_SUCCESS)
-        {
-          snprintf(buf, buflen, "unable to get memory state: %d", ret);
-          resp->err = strdup(buf);
-          free(allDrivers);
-          free(devices);
-          free(mems);
-          return;
-        }
-
-        resp->total += state.size;
-        resp->free += state.free;
-      }
-
-      free(mems);
+      free(h.devices[d]);
     }
     }
-
-    free(devices);
   }
   }
+  if (h.devices != NULL)
+  {
+    free(h.devices);
+    h.devices = NULL;
+  }
+  if (h.num_devices != NULL)
+  {
+    free(h.num_devices);
+    h.num_devices = NULL;
+  }
+  if (h.drivers != NULL)
+  {
+    free(h.drivers);
+    h.drivers = NULL;
+  }
+  h.num_drivers = 0;
+  UNLOAD_LIBRARY(h.handle);
+  h.handle = NULL;
+}
 
 
-  free(allDrivers);
+int oneapi_get_device_count(oneapi_handle_t h, int driver) 
+{
+  if (h.handle == NULL || h.num_devices == NULL) 
+  {
+    return 0;
+  }
+  if (driver > h.num_drivers)
+  {
+    return 0;
+  }
+  return (int)h.num_devices[driver];
 }
 }
 
 
 #endif // __APPLE__
 #endif // __APPLE__

+ 13 - 2
gpu/gpu_info_oneapi.h

@@ -175,6 +175,16 @@ typedef struct oneapi_handle
 {
 {
   void *handle;
   void *handle;
   uint16_t verbose;
   uint16_t verbose;
+
+  uint32_t num_drivers;
+  zes_driver_handle_t *drivers; 
+  uint32_t *num_devices;
+  zes_device_handle_t **devices; 
+
+  // TODO Driver major, minor information
+  // int driver_major;
+  // int driver_minor;
+
   ze_result_t (*zesInit)(int);
   ze_result_t (*zesInit)(int);
   ze_result_t (*zesDriverGet)(uint32_t *pCount, zes_driver_handle_t *phDrivers);
   ze_result_t (*zesDriverGet)(uint32_t *pCount, zes_driver_handle_t *phDrivers);
   ze_result_t (*zesDeviceGet)(zes_driver_handle_t hDriver, uint32_t *pCount,
   ze_result_t (*zesDeviceGet)(zes_driver_handle_t hDriver, uint32_t *pCount,
@@ -194,7 +204,6 @@ typedef struct oneapi_handle
 typedef struct oneapi_init_resp
 typedef struct oneapi_init_resp
 {
 {
   char *err; // If err is non-null handle is invalid
   char *err; // If err is non-null handle is invalid
-  int num_devices;
   oneapi_handle_t oh;
   oneapi_handle_t oh;
 } oneapi_init_resp_t;
 } oneapi_init_resp_t;
 
 
@@ -205,7 +214,9 @@ typedef struct oneapi_version_resp
 } oneapi_version_resp_t;
 } oneapi_version_resp_t;
 
 
 void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp);
 void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp);
-void oneapi_check_vram(oneapi_handle_t rh, mem_info_t *resp);
+void oneapi_check_vram(oneapi_handle_t h, int driver, int device, mem_info_t *resp);
+void oneapi_release(oneapi_handle_t h);
+int oneapi_get_device_count(oneapi_handle_t h, int driver);
 
 
 #endif // __GPU_INFO_INTEL_H__
 #endif // __GPU_INFO_INTEL_H__
 #endif // __APPLE__
 #endif // __APPLE__

+ 2 - 1
gpu/types.go

@@ -57,7 +57,8 @@ type RocmGPUInfoList []RocmGPUInfo
 
 
 type OneapiGPUInfo struct {
 type OneapiGPUInfo struct {
 	GpuInfo
 	GpuInfo
-	index int // device index
+	driverIndex int // nolint: unused
+	gpuIndex    int // nolint: unused
 }
 }
 type OneapiGPUInfoList []OneapiGPUInfo
 type OneapiGPUInfoList []OneapiGPUInfo