Browse Source

ml/backend/ggml: handle tensor split

Michael Yang 2 months ago
parent
commit
b5312f30e8
1 changed files with 32 additions and 13 deletions
  1. 32 13
      ml/backend/ggml/ggml.go

+ 32 - 13
ml/backend/ggml/ggml.go

@@ -93,16 +93,8 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
 		}
 	}
 
-	var sum uint64
-	var cumsum []uint64
-
 	var gpuDeviceBufferTypes []deviceBufferType
 	for _, d := range gpus {
-		var free, total C.size_t
-		C.ggml_backend_dev_memory(d, &free, &total)
-		sum += uint64(free)
-		cumsum = append(cumsum, sum)
-
 		bt := C.ggml_backend_dev_buffer_type(d)
 		gpuDeviceBufferTypes = append(gpuDeviceBufferTypes, deviceBufferType{
 			d:   d,
@@ -110,9 +102,33 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
 		})
 	}
 
-	splits := make([]float64, len(cumsum))
+	splits := make([]float32, len(gpus))
+	if func() bool {
+		for _, s := range params.TensorSplit {
+			if s != 0 {
+				return true
+			}
+		}
+
+		return false
+	}() {
+		splits = params.TensorSplit
+	} else {
+		for i := range splits {
+			var free, total C.size_t
+			C.ggml_backend_dev_memory(gpus[i], &free, &total)
+			splits[i] = float32(free)
+		}
+	}
+
+	var sum float32
+	for i := range splits {
+		sum += splits[i]
+		splits[i] = sum
+	}
+
 	for i := range splits {
-		splits[i] = float64(cumsum[i]) / float64(sum)
+		splits[i] /= sum
 	}
 
 	cpuDeviceBufferTypes := deviceBufferType{C.ggml_backend_dev_by_type(C.GGML_BACKEND_DEVICE_TYPE_CPU), cpuBufferTypes}
@@ -130,9 +146,12 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
 			return cpuDeviceBufferTypes
 		}
 
-		return gpuDeviceBufferTypes[slices.IndexFunc(splits, func(f float64) bool {
-			return float64(i)/float64(blocks+1) < f
-		})]
+		index := slices.IndexFunc(splits, func(f float32) bool { return float32(i)/float32(blocks+1) < f })
+		if index < 0 || index >= len(gpuDeviceBufferTypes) {
+			return cpuDeviceBufferTypes
+		}
+
+		return gpuDeviceBufferTypes[index]
 	}
 
 	layers := make([]deviceBufferType, blocks)