Browse Source

partial offloading: allow flash attention and disable mmap (#4734)

* partial offloading: allow flash attention and disable mmap

* allow mmap with num_gpu=0
Jeffrey Morgan 11 months ago
parent
commit
a50a87a7b8
1 changed file with 21 additions and 18 deletions
  1. 21 18
      llm/server.go

+ 21 - 18
llm/server.go

@@ -191,35 +191,38 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 		params = append(params, "--memory-f32")
 	}
 
-	if opts.UseMLock {
-		params = append(params, "--mlock")
-	}
-
-	if !opts.UseMMap {
-		params = append(params, "--no-mmap")
-	}
-
-	if opts.UseNUMA {
-		params = append(params, "--numa")
-	}
-
 	flashAttnEnabled := envconfig.FlashAttention
 
-	// partial offloading does not support flash attention
-	if uint64(opts.NumGPU) < ggml.KV().BlockCount()+1 {
-		flashAttnEnabled = false
-	}
-
-	// only cuda (compute capability 7+) and metal support flash attention
 	for _, g := range gpus {
+		// only cuda (compute capability 7+) and metal support flash attention
 		if g.Library != "metal" && (g.Library != "cuda" || g.DriverMajor < 7) {
 			flashAttnEnabled = false
 		}
+
+		// mmap has issues with partial offloading on metal
+		if g.Library == "metal" &&
+			uint64(opts.NumGPU) > 0 &&
+			uint64(opts.NumGPU) < ggml.KV().BlockCount()+1 {
+			opts.UseMMap = false
+		}
 	}
+
 	if flashAttnEnabled {
 		params = append(params, "--flash-attn")
 	}
 
+	if !opts.UseMMap {
+		params = append(params, "--no-mmap")
+	}
+
+	if opts.UseMLock {
+		params = append(params, "--mlock")
+	}
+
+	if opts.UseNUMA {
+		params = append(params, "--numa")
+	}
+
 	numParallel := envconfig.NumParallel
 
 	// TODO (jmorganca): multimodal models don't support parallel yet