|
@@ -50,10 +50,10 @@ func New(workDir, model string, adapters, projectors []string, opts api.Options)
|
|
|
// fp16 k,v matrices require = n_ctx * n_layer * n_embd / n_head * n_head_kv * 2 bytes each * 2 key and value
|
|
|
kv := 2 * 2 * int64(opts.NumCtx) * int64(ggml.NumLayers()) * int64(ggml.NumEmbed()) * int64(ggml.NumHeadKv()) / int64(ggml.NumHead())
|
|
|
|
|
|
- // this amount is the overhead + tensors in memory
|
|
|
- // TODO: get this from the llama.cpp's graph calculations instead of
|
|
|
- // estimating it's 1/6 * kv_cache_size * num_gqa
|
|
|
- graph := int64(ggml.NumGQA()) * kv / 6
|
|
|
+ // rough estimation for scratch space based on context size, batch size and number of layers in the model
|
|
|
+ // TODO: instead call llama.cpp's alloc functions to measure required memory
|
|
|
+ // TODO: account for quantization levels
|
|
|
+ scratch := 8*int64(opts.NumCtx)*int64(opts.NumBatch)*int64(ggml.NumLayers()) + 1536*1024*1024 // 1536MiB overhead
|
|
|
|
|
|
info := gpu.GetGPUInfo()
|
|
|
switch runtime.GOOS {
|
|
@@ -62,7 +62,7 @@ func New(workDir, model string, adapters, projectors []string, opts api.Options)
|
|
|
break
|
|
|
}
|
|
|
|
|
|
- if size+kv+graph > vram {
|
|
|
+ if size+kv+scratch > vram {
|
|
|
slog.Info("not enough vram available, falling back to CPU only")
|
|
|
info.Library = "cpu"
|
|
|
info.Variant = gpu.GetCPUVariant()
|
|
@@ -99,13 +99,13 @@ func New(workDir, model string, adapters, projectors []string, opts api.Options)
|
|
|
maxlayers := int64(ggml.NumLayers()) + 1
|
|
|
devices := int64(info.DeviceCount)
|
|
|
avg := vram / devices
|
|
|
- layers := maxlayers * (avg - graph) / (kv + size/devices)
|
|
|
+ layers := maxlayers * (avg - scratch) / (kv + size/devices)
|
|
|
if layers > maxlayers {
|
|
|
layers = maxlayers
|
|
|
}
|
|
|
|
|
|
// 1 + 2 must fit on the main gpu
|
|
|
- min := graph + kv*layers/maxlayers
|
|
|
+ min := scratch + kv*layers/maxlayers
|
|
|
if layers <= 0 || min > avg {
|
|
|
slog.Info("not enough vram available, falling back to CPU only")
|
|
|
info.Library = "cpu"
|