Browse Source

calculate overhead based number of gpu devices (#1875)

Jeffrey Morgan 1 năm trước cách đây
mục cha
commit
c336693f07
8 tập tin đã thay đổi với 13 bổ sung6 xóa
  1. 3 1
      gpu/gpu.go
  2. 1 0
      gpu/gpu_darwin.go
  3. 1 0
      gpu/gpu_info.h
  4. 2 0
      gpu/gpu_info_cpu.c
  5. 2 4
      gpu/gpu_info_cuda.c
  6. 2 0
      gpu/gpu_info_rocm.c
  7. 1 1
      gpu/gpu_test.go
  8. 1 0
      gpu/types.go

+ 3 - 1
gpu/gpu.go

@@ -110,6 +110,8 @@ func GetGPUInfo() GpuInfo {
 		C.free(unsafe.Pointer(memInfo.err))
 		return resp
 	}
+
+	resp.DeviceCount = uint32(memInfo.count)
 	resp.FreeMemory = uint64(memInfo.free)
 	resp.TotalMemory = uint64(memInfo.total)
 	return resp
@@ -132,7 +134,7 @@ func CheckVRAM() (int64, error) {
 	gpuInfo := GetGPUInfo()
 	if gpuInfo.FreeMemory > 0 && (gpuInfo.Library == "cuda" || gpuInfo.Library == "rocm") {
 		// leave 10% or 384Mi of VRAM free for unaccounted for overhead
-		overhead := gpuInfo.FreeMemory / 10
+		overhead := gpuInfo.FreeMemory * uint64(gpuInfo.DeviceCount) / 10
 		if overhead < 384*1024*1024 {
 			overhead = 384 * 1024 * 1024
 		}

+ 1 - 0
gpu/gpu_darwin.go

@@ -42,6 +42,7 @@ func getCPUMem() (memInfo, error) {
 	return memInfo{
 		TotalMemory: 0,
 		FreeMemory:  0,
+		DeviceCount: 0,
 	}, nil
 }
 

+ 1 - 0
gpu/gpu_info.h

@@ -34,6 +34,7 @@ extern "C" {
 typedef struct mem_info {
   uint64_t total;
   uint64_t free;
+  unsigned int count;
   char *err;  // If non-nill, caller responsible for freeing
 } mem_info_t;
 

+ 2 - 0
gpu/gpu_info_cpu.c

@@ -8,6 +8,7 @@ void cpu_check_ram(mem_info_t *resp) {
   MEMORYSTATUSEX info;
   info.dwLength = sizeof(info);
   if (GlobalMemoryStatusEx(&info) != 0) {
+    resp->count = 1;
     resp->total = info.ullTotalPhys;
     resp->free = info.ullAvailPhys;
   } else {
@@ -26,6 +27,7 @@ void cpu_check_ram(mem_info_t *resp) {
   if (sysinfo(&info) != 0) {
     resp->err = strdup(strerror(errno));
   } else {
+    resp->count = 1;
     resp->total = info.totalram * info.mem_unit;
     resp->free = info.freeram * info.mem_unit;
   }

+ 2 - 4
gpu/gpu_info_cuda.c

@@ -94,8 +94,7 @@ void cuda_check_vram(cuda_handle_t h, mem_info_t *resp) {
     return;
   }
 
-  unsigned int devices;
-  ret = (*h.getCount)(&devices);
+  ret = (*h.getCount)(&resp->count);
   if (ret != NVML_SUCCESS) {
     snprintf(buf, buflen, "unable to get device count: %d", ret);
     resp->err = strdup(buf);
@@ -104,8 +103,7 @@ void cuda_check_vram(cuda_handle_t h, mem_info_t *resp) {
 
   resp->total = 0;
   resp->free = 0;
-
-  for (i = 0; i < devices; i++) {
+  for (i = 0; i < resp->count; i++) {
     ret = (*h.getHandle)(i, &device);
     if (ret != NVML_SUCCESS) {
       snprintf(buf, buflen, "unable to get device handle %d: %d", i, ret);

+ 2 - 0
gpu/gpu_info_rocm.c

@@ -110,6 +110,8 @@ void rocm_check_vram(rocm_handle_t h, mem_info_t *resp) {
     return;
   }
 
+  // TODO: set this to the actual number of devices
+  resp->count = 1;
   resp->total = totalMem;
   resp->free = totalMem - usedMem;
   return;

+ 1 - 1
gpu/gpu_test.go

@@ -18,6 +18,7 @@ func TestBasicGetGPUInfo(t *testing.T) {
 	case "linux", "windows":
 		assert.Greater(t, info.TotalMemory, uint64(0))
 		assert.Greater(t, info.FreeMemory, uint64(0))
+		assert.Greater(t, info.DeviceCount, uint64(0))
 	default:
 		return
 	}
@@ -35,7 +36,6 @@ func TestCPUMemInfo(t *testing.T) {
 	default:
 		return
 	}
-
 }
 
 // TODO - add some logic to figure out card type through other means and actually verify we got back what we expected

+ 1 - 0
gpu/types.go

@@ -3,6 +3,7 @@ package gpu
 type memInfo struct {
 	TotalMemory uint64 `json:"total_memory,omitempty"`
 	FreeMemory  uint64 `json:"free_memory,omitempty"`
+	DeviceCount uint32 `json:"device_count,omitempty"`
 }
 
 // Beginning of an `ollama info` command