Browse Source

Only set default keep_alive on initial model load

This change fixes the handling of keep_alive so that if client
request omits the setting, we only set this on initial load.  Once
the model is loaded, if new requests leave this unset, we'll keep
whatever keep_alive was there.
Daniel Hiltgen 10 months ago
parent
commit
955f2a4e03
5 changed files with 70 additions and 71 deletions
  1. 29 2
      envconfig/config.go
  2. 17 0
      envconfig/config_test.go
  3. 3 54
      server/routes.go
  4. 10 4
      server/sched.go
  5. 11 11
      server/sched_test.go

+ 29 - 2
envconfig/config.go

@@ -4,12 +4,14 @@ import (
 	"errors"
 	"fmt"
 	"log/slog"
+	"math"
 	"net"
 	"os"
 	"path/filepath"
 	"runtime"
 	"strconv"
 	"strings"
+	"time"
 )
 
 type OllamaHost struct {
@@ -34,7 +36,7 @@ var (
 	// Set via OLLAMA_HOST in the environment
 	Host *OllamaHost
 	// Set via OLLAMA_KEEP_ALIVE in the environment
-	KeepAlive string
+	KeepAlive time.Duration
 	// Set via OLLAMA_LLM_LIBRARY in the environment
 	LLMLibrary string
 	// Set via OLLAMA_MAX_LOADED_MODELS in the environment
@@ -132,6 +134,7 @@ func init() {
 	NumParallel = 0 // Autoselect
 	MaxRunners = 0  // Autoselect
 	MaxQueuedRequests = 512
+	KeepAlive = 5 * time.Minute
 
 	LoadConfig()
 }
@@ -266,7 +269,10 @@ func LoadConfig() {
 		}
 	}
 
-	KeepAlive = clean("OLLAMA_KEEP_ALIVE")
+	ka := clean("OLLAMA_KEEP_ALIVE")
+	if ka != "" {
+		loadKeepAlive(ka)
+	}
 
 	var err error
 	ModelsDir, err = getModelsDir()
@@ -344,3 +350,24 @@ func getOllamaHost() (*OllamaHost, error) {
 		Port:   port,
 	}, nil
 }
+
+func loadKeepAlive(ka string) {
+	v, err := strconv.Atoi(ka)
+	if err != nil {
+		d, err := time.ParseDuration(ka)
+		if err == nil {
+			if d < 0 {
+				KeepAlive = time.Duration(math.MaxInt64)
+			} else {
+				KeepAlive = d
+			}
+		}
+	} else {
+		d := time.Duration(v) * time.Second
+		if d < 0 {
+			KeepAlive = time.Duration(math.MaxInt64)
+		} else {
+			KeepAlive = d
+		}
+	}
+}

+ 17 - 0
envconfig/config_test.go

@@ -2,8 +2,10 @@ package envconfig
 
 import (
 	"fmt"
+	"math"
 	"net"
 	"testing"
+	"time"
 
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
@@ -23,6 +25,21 @@ func TestConfig(t *testing.T) {
 	t.Setenv("OLLAMA_FLASH_ATTENTION", "1")
 	LoadConfig()
 	require.True(t, FlashAttention)
+	t.Setenv("OLLAMA_KEEP_ALIVE", "")
+	LoadConfig()
+	require.Equal(t, 5*time.Minute, KeepAlive)
+	t.Setenv("OLLAMA_KEEP_ALIVE", "3")
+	LoadConfig()
+	require.Equal(t, 3*time.Second, KeepAlive)
+	t.Setenv("OLLAMA_KEEP_ALIVE", "1h")
+	LoadConfig()
+	require.Equal(t, 1*time.Hour, KeepAlive)
+	t.Setenv("OLLAMA_KEEP_ALIVE", "-1s")
+	LoadConfig()
+	require.Equal(t, time.Duration(math.MaxInt64), KeepAlive)
+	t.Setenv("OLLAMA_KEEP_ALIVE", "-1")
+	LoadConfig()
+	require.Equal(t, time.Duration(math.MaxInt64), KeepAlive)
 }
 
 func TestClientFromEnvironment(t *testing.T) {

+ 3 - 54
server/routes.go

@@ -9,7 +9,6 @@ import (
 	"io"
 	"io/fs"
 	"log/slog"
-	"math"
 	"net"
 	"net/http"
 	"net/netip"
@@ -17,7 +16,6 @@ import (
 	"os/signal"
 	"path/filepath"
 	"slices"
-	"strconv"
 	"strings"
 	"syscall"
 	"time"
@@ -56,8 +54,6 @@ func init() {
 	gin.SetMode(mode)
 }
 
-var defaultSessionDuration = 5 * time.Minute
-
 func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
 	opts := api.DefaultOptions()
 	if err := opts.FromMap(model.Options); err != nil {
@@ -133,14 +129,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		return
 	}
 
-	var sessionDuration time.Duration
-	if req.KeepAlive == nil {
-		sessionDuration = getDefaultSessionDuration()
-	} else {
-		sessionDuration = req.KeepAlive.Duration
-	}
-
-	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
+	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
 	var runner *runnerRef
 	select {
 	case runner = <-rCh:
@@ -320,32 +309,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 	streamResponse(c, ch)
 }
 
-func getDefaultSessionDuration() time.Duration {
-	if envconfig.KeepAlive != "" {
-		v, err := strconv.Atoi(envconfig.KeepAlive)
-		if err != nil {
-			d, err := time.ParseDuration(envconfig.KeepAlive)
-			if err != nil {
-				return defaultSessionDuration
-			}
-
-			if d < 0 {
-				return time.Duration(math.MaxInt64)
-			}
-
-			return d
-		}
-
-		d := time.Duration(v) * time.Second
-		if d < 0 {
-			return time.Duration(math.MaxInt64)
-		}
-		return d
-	}
-
-	return defaultSessionDuration
-}
-
 func (s *Server) EmbeddingsHandler(c *gin.Context) {
 	var req api.EmbeddingRequest
 	err := c.ShouldBindJSON(&req)
@@ -380,14 +343,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 		return
 	}
 
-	var sessionDuration time.Duration
-	if req.KeepAlive == nil {
-		sessionDuration = getDefaultSessionDuration()
-	} else {
-		sessionDuration = req.KeepAlive.Duration
-	}
-
-	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
+	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
 	var runner *runnerRef
 	select {
 	case runner = <-rCh:
@@ -1318,14 +1274,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		return
 	}
 
-	var sessionDuration time.Duration
-	if req.KeepAlive == nil {
-		sessionDuration = getDefaultSessionDuration()
-	} else {
-		sessionDuration = req.KeepAlive.Duration
-	}
-
-	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
+	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
 	var runner *runnerRef
 	select {
 	case runner = <-rCh:

+ 10 - 4
server/sched.go

@@ -24,7 +24,7 @@ type LlmRequest struct {
 	model           *Model
 	opts            api.Options
 	origNumCtx      int // Track the initial ctx request
-	sessionDuration time.Duration
+	sessionDuration *api.Duration
 	successCh       chan *runnerRef
 	errCh           chan error
 	schedAttempts   uint
@@ -75,7 +75,7 @@ func InitScheduler(ctx context.Context) *Scheduler {
 }
 
 // context must be canceled to decrement ref count and release the runner
-func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration time.Duration) (chan *runnerRef, chan error) {
+func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration *api.Duration) (chan *runnerRef, chan error) {
 	if opts.NumCtx < 4 {
 		opts.NumCtx = 4
 	}
@@ -389,7 +389,9 @@ func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *Llm
 		runner.expireTimer.Stop()
 		runner.expireTimer = nil
 	}
-	runner.sessionDuration = pending.sessionDuration
+	if pending.sessionDuration != nil {
+		runner.sessionDuration = pending.sessionDuration.Duration
+	}
 	pending.successCh <- runner
 	go func() {
 		<-pending.ctx.Done()
@@ -402,6 +404,10 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList,
 	if numParallel < 1 {
 		numParallel = 1
 	}
+	sessionDuration := envconfig.KeepAlive
+	if req.sessionDuration != nil {
+		sessionDuration = req.sessionDuration.Duration
+	}
 	llama, err := s.newServerFn(gpus, req.model.ModelPath, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, numParallel)
 	if err != nil {
 		// some older models are not compatible with newer versions of llama.cpp
@@ -419,7 +425,7 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList,
 		modelPath:       req.model.ModelPath,
 		llama:           llama,
 		Options:         &req.opts,
-		sessionDuration: req.sessionDuration,
+		sessionDuration: sessionDuration,
 		gpus:            gpus,
 		estimatedVRAM:   llama.EstimatedVRAM(),
 		estimatedTotal:  llama.EstimatedTotal(),

+ 11 - 11
server/sched_test.go

@@ -44,7 +44,7 @@ func TestLoad(t *testing.T) {
 		opts:            api.DefaultOptions(),
 		successCh:       make(chan *runnerRef, 1),
 		errCh:           make(chan error, 1),
-		sessionDuration: 2,
+		sessionDuration: &api.Duration{Duration: 2 * time.Second},
 	}
 	// Fail to load model first
 	s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
@@ -142,7 +142,7 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
 		ctx:             scenario.ctx,
 		model:           model,
 		opts:            api.DefaultOptions(),
-		sessionDuration: 5 * time.Millisecond,
+		sessionDuration: &api.Duration{Duration: 5 * time.Millisecond},
 		successCh:       make(chan *runnerRef, 1),
 		errCh:           make(chan error, 1),
 	}
@@ -156,18 +156,18 @@ func TestRequests(t *testing.T) {
 
 	// Same model, same request
 	scenario1a := newScenario(t, ctx, "ollama-model-1", 10)
-	scenario1a.req.sessionDuration = 5 * time.Millisecond
+	scenario1a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
 	scenario1b := newScenario(t, ctx, "ollama-model-1", 11)
 	scenario1b.req.model = scenario1a.req.model
 	scenario1b.ggml = scenario1a.ggml
-	scenario1b.req.sessionDuration = 0
+	scenario1b.req.sessionDuration = &api.Duration{Duration: 0}
 
 	// simple reload of same model
 	scenario2a := newScenario(t, ctx, "ollama-model-1", 20)
 	tmpModel := *scenario1a.req.model
 	scenario2a.req.model = &tmpModel
 	scenario2a.ggml = scenario1a.ggml
-	scenario2a.req.sessionDuration = 5 * time.Millisecond
+	scenario2a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
 
 	// Multiple loaded models
 	scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte)
@@ -318,11 +318,11 @@ func TestGetRunner(t *testing.T) {
 	defer done()
 
 	scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
-	scenario1a.req.sessionDuration = 0
+	scenario1a.req.sessionDuration = &api.Duration{Duration: 0}
 	scenario1b := newScenario(t, ctx, "ollama-model-1b", 10)
-	scenario1b.req.sessionDuration = 0
+	scenario1b.req.sessionDuration = &api.Duration{Duration: 0}
 	scenario1c := newScenario(t, ctx, "ollama-model-1c", 10)
-	scenario1c.req.sessionDuration = 0
+	scenario1c.req.sessionDuration = &api.Duration{Duration: 0}
 	envconfig.MaxQueuedRequests = 1
 	s := InitScheduler(ctx)
 	s.getGpuFn = func() gpu.GpuInfoList {
@@ -402,7 +402,7 @@ func TestPrematureExpired(t *testing.T) {
 	case <-ctx.Done():
 		t.Fatal("timeout")
 	}
-	time.Sleep(scenario1a.req.sessionDuration)
+	time.Sleep(scenario1a.req.sessionDuration.Duration)
 	scenario1a.ctxDone()
 	time.Sleep(20 * time.Millisecond)
 	require.LessOrEqual(t, len(s.finishedReqCh), 1)
@@ -423,7 +423,7 @@ func TestUseLoadedRunner(t *testing.T) {
 		ctx:             ctx,
 		opts:            api.DefaultOptions(),
 		successCh:       make(chan *runnerRef, 1),
-		sessionDuration: 2,
+		sessionDuration: &api.Duration{Duration: 2},
 	}
 	finished := make(chan *LlmRequest)
 	llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
@@ -614,7 +614,7 @@ func TestAlreadyCanceled(t *testing.T) {
 	dctx, done2 := context.WithCancel(ctx)
 	done2()
 	scenario1a := newScenario(t, dctx, "ollama-model-1", 10)
-	scenario1a.req.sessionDuration = 0
+	scenario1a.req.sessionDuration = &api.Duration{Duration: 0}
 	s := InitScheduler(ctx)
 	slog.Info("scenario1a")
 	s.pendingReqCh <- scenario1a.req