Browse Source

keepalive

Michael Yang 10 months ago
parent
commit
8570c1c0ef
3 changed files with 55 additions and 47 deletions
  1. 21 30
      envconfig/config.go
  2. 33 16
      envconfig/config_test.go
  3. 1 1
      server/sched.go

+ 21 - 30
envconfig/config.go

@@ -99,6 +99,26 @@ func Models() string {
 	return filepath.Join(home, ".ollama", "models")
 }
 
+// KeepAlive returns the duration that models stay loaded in memory. KeepAlive can be configured via the OLLAMA_KEEP_ALIVE environment variable.
+// Negative values are treated as infinite. Zero is treated as no keep alive.
+// Default is 5 minutes.
+func KeepAlive() (keepAlive time.Duration) {
+	keepAlive = 5 * time.Minute
+	if s := os.Getenv("OLLAMA_KEEP_ALIVE"); s != "" {
+		if d, err := time.ParseDuration(s); err == nil {
+			keepAlive = d
+		} else if n, err := strconv.ParseInt(s, 10, 64); err == nil {
+			keepAlive = time.Duration(n) * time.Second
+		}
+	}
+
+	if keepAlive < 0 {
+		return time.Duration(math.MaxInt64)
+	}
+
+	return keepAlive
+}
+
 func Bool(k string) func() bool {
 	return func() bool {
 		if s := getenv(k); s != "" {
@@ -130,8 +150,6 @@ var (
 )
 
 var (
-	// Set via OLLAMA_KEEP_ALIVE in the environment
-	KeepAlive time.Duration
 	// Set via OLLAMA_LLM_LIBRARY in the environment
 	LLMLibrary string
 	// Set via OLLAMA_MAX_LOADED_MODELS in the environment
@@ -168,7 +186,7 @@ func AsMap() map[string]EnvVar {
 		"OLLAMA_DEBUG":             {"OLLAMA_DEBUG", Debug(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
 		"OLLAMA_FLASH_ATTENTION":   {"OLLAMA_FLASH_ATTENTION", FlashAttention(), "Enabled flash attention"},
 		"OLLAMA_HOST":              {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"},
-		"OLLAMA_KEEP_ALIVE":        {"OLLAMA_KEEP_ALIVE", KeepAlive, "The duration that models stay loaded in memory (default \"5m\")"},
+		"OLLAMA_KEEP_ALIVE":        {"OLLAMA_KEEP_ALIVE", KeepAlive(), "The duration that models stay loaded in memory (default \"5m\")"},
 		"OLLAMA_LLM_LIBRARY":       {"OLLAMA_LLM_LIBRARY", LLMLibrary, "Set LLM library to bypass autodetection"},
 		"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models per GPU"},
 		"OLLAMA_MAX_QUEUE":         {"OLLAMA_MAX_QUEUE", MaxQueuedRequests, "Maximum number of queued requests"},
@@ -210,7 +228,6 @@ func init() {
 	NumParallel = 0 // Autoselect
 	MaxRunners = 0  // Autoselect
 	MaxQueuedRequests = 512
-	KeepAlive = 5 * time.Minute
 
 	LoadConfig()
 }
@@ -284,35 +301,9 @@ func LoadConfig() {
 		}
 	}
 
-	ka := getenv("OLLAMA_KEEP_ALIVE")
-	if ka != "" {
-		loadKeepAlive(ka)
-	}
-
 	CudaVisibleDevices = getenv("CUDA_VISIBLE_DEVICES")
 	HipVisibleDevices = getenv("HIP_VISIBLE_DEVICES")
 	RocrVisibleDevices = getenv("ROCR_VISIBLE_DEVICES")
 	GpuDeviceOrdinal = getenv("GPU_DEVICE_ORDINAL")
 	HsaOverrideGfxVersion = getenv("HSA_OVERRIDE_GFX_VERSION")
 }
-
-func loadKeepAlive(ka string) {
-	v, err := strconv.Atoi(ka)
-	if err != nil {
-		d, err := time.ParseDuration(ka)
-		if err == nil {
-			if d < 0 {
-				KeepAlive = time.Duration(math.MaxInt64)
-			} else {
-				KeepAlive = d
-			}
-		}
-	} else {
-		d := time.Duration(v) * time.Second
-		if d < 0 {
-			KeepAlive = time.Duration(math.MaxInt64)
-		} else {
-			KeepAlive = d
-		}
-	}
-}

+ 33 - 16
envconfig/config_test.go

@@ -21,22 +21,6 @@ func TestSmoke(t *testing.T) {
 
 	t.Setenv("OLLAMA_FLASH_ATTENTION", "1")
 	require.True(t, FlashAttention())
-
-	t.Setenv("OLLAMA_KEEP_ALIVE", "")
-	LoadConfig()
-	require.Equal(t, 5*time.Minute, KeepAlive)
-	t.Setenv("OLLAMA_KEEP_ALIVE", "3")
-	LoadConfig()
-	require.Equal(t, 3*time.Second, KeepAlive)
-	t.Setenv("OLLAMA_KEEP_ALIVE", "1h")
-	LoadConfig()
-	require.Equal(t, 1*time.Hour, KeepAlive)
-	t.Setenv("OLLAMA_KEEP_ALIVE", "-1s")
-	LoadConfig()
-	require.Equal(t, time.Duration(math.MaxInt64), KeepAlive)
-	t.Setenv("OLLAMA_KEEP_ALIVE", "-1")
-	LoadConfig()
-	require.Equal(t, time.Duration(math.MaxInt64), KeepAlive)
 }
 
 func TestHost(t *testing.T) {
@@ -186,3 +170,36 @@ func TestBool(t *testing.T) {
 		})
 	}
 }
+
+func TestKeepAlive(t *testing.T) {
+	cases := map[string]time.Duration{
+		"":       5 * time.Minute,
+		"1s":     time.Second,
+		"1m":     time.Minute,
+		"1h":     time.Hour,
+		"5m0s":   5 * time.Minute,
+		"1h2m3s": 1*time.Hour + 2*time.Minute + 3*time.Second,
+		"0":      time.Duration(0),
+		"60":     60 * time.Second,
+		"120":    2 * time.Minute,
+		"3600":   time.Hour,
+		"-0":     time.Duration(0),
+		"-1":     time.Duration(math.MaxInt64),
+		"-1m":    time.Duration(math.MaxInt64),
+		// invalid values
+		" ":   5 * time.Minute,
+		"???": 5 * time.Minute,
+		"1d":  5 * time.Minute,
+		"1y":  5 * time.Minute,
+		"1w":  5 * time.Minute,
+	}
+
+	for tt, expect := range cases {
+		t.Run(tt, func(t *testing.T) {
+			t.Setenv("OLLAMA_KEEP_ALIVE", tt)
+			if actual := KeepAlive(); actual != expect {
+				t.Errorf("%s: expected %s, got %s", tt, expect, actual)
+			}
+		})
+	}
+}

+ 1 - 1
server/sched.go

@@ -401,7 +401,7 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList,
 	if numParallel < 1 {
 		numParallel = 1
 	}
-	sessionDuration := envconfig.KeepAlive
+	sessionDuration := envconfig.KeepAlive()
 	if req.sessionDuration != nil {
 		sessionDuration = req.sessionDuration.Duration
 	}