Selaa lähdekoodia

feat: add support for flash_attn (#4120)

* feat: enable flash attention if supported

* feat: enable flash attention if supported

* feat: enable flash attention if supported

* feat: add flash_attn support
Sam 11 kuukautta sitten
vanhempi
commit
e15307fdf4
2 muutettua tiedostoa jossa 28 lisäystä ja 3 poistoa
  1. 11 3
      llm/ext_server/server.cpp
  2. 17 0
      llm/server.go

+ 11 - 3
llm/ext_server/server.cpp

@@ -2104,6 +2104,7 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
     printf("  --embedding               enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
     printf("  -np N, --parallel N       number of slots for process requests (default: %d)\n", params.n_parallel);
     printf("  -cb, --cont-batching      enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
+    printf("  -fa, --flash-attn         enable Flash Attention (default: %s)\n", params.flash_attn ? "enabled" : "disabled");
     printf("  -spf FNAME, --system-prompt-file FNAME\n");
     printf("                            set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications.\n");
     printf("  -ctk TYPE, --cache-type-k TYPE\n");
@@ -2501,7 +2502,8 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
         {
             params.use_mmap = false;
         }
-        else if (arg == "--numa") {
+        else if (arg == "--numa")
+        {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
@@ -2521,6 +2523,10 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
         {
             params.cont_batching = true;
         }
+        else if (arg == "-fa" || arg == "--flash-attn")
+        {
+            params.flash_attn = true;
+        }
         else if (arg == "-np" || arg == "--parallel")
         {
             if (++i >= argc)
@@ -2529,7 +2535,8 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
                 break;
             }
             params.n_parallel = std::stoi(argv[i]);
-        } else if (arg == "-n" || arg == "--n-predict")
+        }
+        else if (arg == "-n" || arg == "--n-predict")
         {
             if (++i >= argc)
             {
@@ -2537,7 +2544,8 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
                 break;
             }
             params.n_predict = std::stoi(argv[i]);
-        } else if (arg == "-spf" || arg == "--system-prompt-file")
+        }
+        else if (arg == "-spf" || arg == "--system-prompt-file")
         {
             if (++i >= argc)
             {

+ 17 - 0
llm/server.go

@@ -200,6 +200,23 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 		params = append(params, "--numa")
 	}
 
+	flashAttnSupported := true
+
+	// partial offloading does not support flash attention
+	if uint64(opts.NumGPU) < ggml.KV().BlockCount() + 1 {
+		flashAttnSupported = false
+	}
+
+	// only cuda (compute capability 7+) and metal support flash attention
+	for _, g := range gpus {
+		if g.Library != "metal" && (g.Library != "cuda" || g.DriverMajor < 7) {
+			flashAttnSupported = false
+		}
+	}
+	if flashAttnSupported {
+		params = append(params, "--flash-attn")
+	}
+
 	numParallel := envconfig.NumParallel
 
 	// TODO (jmorganca): multimodal models don't support parallel yet