瀏覽代碼

Use flash attention flag for now (#4580)

* put flash attention behind flag for now

* add test

* remove print

* up timeout for scheduler tests
Jeffrey Morgan 11 月之前
父節點
當前提交
38255d2af1
共有 4 個文件被更改,包括 19 次插入和 6 次刪除
  1. 5 5
      llm/server.go
  2. 10 0
      server/envconfig/config.go
  3. 3 0
      server/envconfig/config_test.go
  4. 1 1
      server/sched_test.go

+ 5 - 5
llm/server.go

@@ -200,20 +200,20 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 		params = append(params, "--numa")
 		params = append(params, "--numa")
 	}
 	}
 
 
-	flashAttnSupported := true
+	flashAttnEnabled := envconfig.FlashAttention
 
 
 	// partial offloading does not support flash attention
 	// partial offloading does not support flash attention
-	if uint64(opts.NumGPU) < ggml.KV().BlockCount() + 1 {
-		flashAttnSupported = false
+	if uint64(opts.NumGPU) < ggml.KV().BlockCount()+1 {
+		flashAttnEnabled = false
 	}
 	}
 
 
 	// only cuda (compute capability 7+) and metal support flash attention
 	// only cuda (compute capability 7+) and metal support flash attention
 	for _, g := range gpus {
 	for _, g := range gpus {
 		if g.Library != "metal" && (g.Library != "cuda" || g.DriverMajor < 7) {
 		if g.Library != "metal" && (g.Library != "cuda" || g.DriverMajor < 7) {
-			flashAttnSupported = false
+			flashAttnEnabled = false
 		}
 		}
 	}
 	}
-	if flashAttnSupported {
+	if flashAttnEnabled {
 		params = append(params, "--flash-attn")
 		params = append(params, "--flash-attn")
 	}
 	}
 
 

+ 10 - 0
server/envconfig/config.go

@@ -31,6 +31,8 @@ var (
 	RunnersDir string
 	RunnersDir string
 	// Set via OLLAMA_TMPDIR in the environment
 	// Set via OLLAMA_TMPDIR in the environment
 	TmpDir string
 	TmpDir string
+	// Experimental flash attention
+	FlashAttention bool
 )
 )
 
 
 func AsMap() map[string]string {
 func AsMap() map[string]string {
@@ -45,6 +47,7 @@ func AsMap() map[string]string {
 		"OLLAMA_NUM_PARALLEL":      fmt.Sprintf("%v", NumParallel),
 		"OLLAMA_NUM_PARALLEL":      fmt.Sprintf("%v", NumParallel),
 		"OLLAMA_RUNNERS_DIR":       fmt.Sprintf("%v", RunnersDir),
 		"OLLAMA_RUNNERS_DIR":       fmt.Sprintf("%v", RunnersDir),
 		"OLLAMA_TMPDIR":            fmt.Sprintf("%v", TmpDir),
 		"OLLAMA_TMPDIR":            fmt.Sprintf("%v", TmpDir),
+		"OLLAMA_FLASH_ATTENTION":   fmt.Sprintf("%v", FlashAttention),
 	}
 	}
 }
 }
 
 
@@ -78,6 +81,13 @@ func LoadConfig() {
 		}
 		}
 	}
 	}
 
 
+	if fa := clean("OLLAMA_FLASH_ATTENTION"); fa != "" {
+		d, err := strconv.ParseBool(fa)
+		if err == nil {
+			FlashAttention = d
+		}
+	}
+
 	RunnersDir = clean("OLLAMA_RUNNERS_DIR")
 	RunnersDir = clean("OLLAMA_RUNNERS_DIR")
 	if runtime.GOOS == "windows" && RunnersDir == "" {
 	if runtime.GOOS == "windows" && RunnersDir == "" {
 		// On Windows we do not carry the payloads inside the main executable
 		// On Windows we do not carry the payloads inside the main executable

+ 3 - 0
server/envconfig/config_test.go

@@ -17,4 +17,7 @@ func TestConfig(t *testing.T) {
 	t.Setenv("OLLAMA_DEBUG", "1")
 	t.Setenv("OLLAMA_DEBUG", "1")
 	LoadConfig()
 	LoadConfig()
 	require.True(t, Debug)
 	require.True(t, Debug)
+	t.Setenv("OLLAMA_FLASH_ATTENTION", "1")
+	LoadConfig()
+	require.True(t, FlashAttention)
 }
 }

+ 1 - 1
server/sched_test.go

@@ -151,7 +151,7 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
 }
 }
 
 
 func TestRequests(t *testing.T) {
 func TestRequests(t *testing.T) {
-	ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
+	ctx, done := context.WithTimeout(context.Background(), time.Second)
 	defer done()
 	defer done()
 
 
 	// Same model, same request
 	// Same model, same request