Pārlūkot izejas kodu

Merge pull request #5447 from dhiltgen/fix_keepalive

Only set default keep_alive on initial model load
Daniel Hiltgen 10 mēneši atpakaļ
vecāks
revīzija
8072e205ff
5 mainītis faili ar 70 papildinājumiem un 71 dzēšanām
  1. 29 2
      envconfig/config.go
  2. 17 0
      envconfig/config_test.go
  3. 3 54
      server/routes.go
  4. 10 4
      server/sched.go
  5. 11 11
      server/sched_test.go

+ 29 - 2
envconfig/config.go

@@ -4,12 +4,14 @@ import (
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"log/slog"
 	"log/slog"
+	"math"
 	"net"
 	"net"
 	"os"
 	"os"
 	"path/filepath"
 	"path/filepath"
 	"runtime"
 	"runtime"
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
+	"time"
 )
 )
 
 
 type OllamaHost struct {
 type OllamaHost struct {
@@ -34,7 +36,7 @@ var (
 	// Set via OLLAMA_HOST in the environment
 	// Set via OLLAMA_HOST in the environment
 	Host *OllamaHost
 	Host *OllamaHost
 	// Set via OLLAMA_KEEP_ALIVE in the environment
 	// Set via OLLAMA_KEEP_ALIVE in the environment
-	KeepAlive string
+	KeepAlive time.Duration
 	// Set via OLLAMA_LLM_LIBRARY in the environment
 	// Set via OLLAMA_LLM_LIBRARY in the environment
 	LLMLibrary string
 	LLMLibrary string
 	// Set via OLLAMA_MAX_LOADED_MODELS in the environment
 	// Set via OLLAMA_MAX_LOADED_MODELS in the environment
@@ -132,6 +134,7 @@ func init() {
 	NumParallel = 0 // Autoselect
 	NumParallel = 0 // Autoselect
 	MaxRunners = 0  // Autoselect
 	MaxRunners = 0  // Autoselect
 	MaxQueuedRequests = 512
 	MaxQueuedRequests = 512
+	KeepAlive = 5 * time.Minute
 
 
 	LoadConfig()
 	LoadConfig()
 }
 }
@@ -266,7 +269,10 @@ func LoadConfig() {
 		}
 		}
 	}
 	}
 
 
-	KeepAlive = clean("OLLAMA_KEEP_ALIVE")
+	ka := clean("OLLAMA_KEEP_ALIVE")
+	if ka != "" {
+		loadKeepAlive(ka)
+	}
 
 
 	var err error
 	var err error
 	ModelsDir, err = getModelsDir()
 	ModelsDir, err = getModelsDir()
@@ -344,3 +350,24 @@ func getOllamaHost() (*OllamaHost, error) {
 		Port:   port,
 		Port:   port,
 	}, nil
 	}, nil
 }
 }
+
+func loadKeepAlive(ka string) {
+	v, err := strconv.Atoi(ka)
+	if err != nil {
+		d, err := time.ParseDuration(ka)
+		if err == nil {
+			if d < 0 {
+				KeepAlive = time.Duration(math.MaxInt64)
+			} else {
+				KeepAlive = d
+			}
+		}
+	} else {
+		d := time.Duration(v) * time.Second
+		if d < 0 {
+			KeepAlive = time.Duration(math.MaxInt64)
+		} else {
+			KeepAlive = d
+		}
+	}
+}

+ 17 - 0
envconfig/config_test.go

@@ -2,8 +2,10 @@ package envconfig
 
 
 import (
 import (
 	"fmt"
 	"fmt"
+	"math"
 	"net"
 	"net"
 	"testing"
 	"testing"
+	"time"
 
 
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 	"github.com/stretchr/testify/require"
@@ -23,6 +25,21 @@ func TestConfig(t *testing.T) {
 	t.Setenv("OLLAMA_FLASH_ATTENTION", "1")
 	t.Setenv("OLLAMA_FLASH_ATTENTION", "1")
 	LoadConfig()
 	LoadConfig()
 	require.True(t, FlashAttention)
 	require.True(t, FlashAttention)
+	t.Setenv("OLLAMA_KEEP_ALIVE", "")
+	LoadConfig()
+	require.Equal(t, 5*time.Minute, KeepAlive)
+	t.Setenv("OLLAMA_KEEP_ALIVE", "3")
+	LoadConfig()
+	require.Equal(t, 3*time.Second, KeepAlive)
+	t.Setenv("OLLAMA_KEEP_ALIVE", "1h")
+	LoadConfig()
+	require.Equal(t, 1*time.Hour, KeepAlive)
+	t.Setenv("OLLAMA_KEEP_ALIVE", "-1s")
+	LoadConfig()
+	require.Equal(t, time.Duration(math.MaxInt64), KeepAlive)
+	t.Setenv("OLLAMA_KEEP_ALIVE", "-1")
+	LoadConfig()
+	require.Equal(t, time.Duration(math.MaxInt64), KeepAlive)
 }
 }
 
 
 func TestClientFromEnvironment(t *testing.T) {
 func TestClientFromEnvironment(t *testing.T) {

+ 3 - 54
server/routes.go

@@ -9,7 +9,6 @@ import (
 	"io"
 	"io"
 	"io/fs"
 	"io/fs"
 	"log/slog"
 	"log/slog"
-	"math"
 	"net"
 	"net"
 	"net/http"
 	"net/http"
 	"net/netip"
 	"net/netip"
@@ -17,7 +16,6 @@ import (
 	"os/signal"
 	"os/signal"
 	"path/filepath"
 	"path/filepath"
 	"slices"
 	"slices"
-	"strconv"
 	"strings"
 	"strings"
 	"syscall"
 	"syscall"
 	"time"
 	"time"
@@ -56,8 +54,6 @@ func init() {
 	gin.SetMode(mode)
 	gin.SetMode(mode)
 }
 }
 
 
-var defaultSessionDuration = 5 * time.Minute
-
 func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
 func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
 	opts := api.DefaultOptions()
 	opts := api.DefaultOptions()
 	if err := opts.FromMap(model.Options); err != nil {
 	if err := opts.FromMap(model.Options); err != nil {
@@ -133,14 +129,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		return
 		return
 	}
 	}
 
 
-	var sessionDuration time.Duration
-	if req.KeepAlive == nil {
-		sessionDuration = getDefaultSessionDuration()
-	} else {
-		sessionDuration = req.KeepAlive.Duration
-	}
-
-	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
+	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
 	var runner *runnerRef
 	var runner *runnerRef
 	select {
 	select {
 	case runner = <-rCh:
 	case runner = <-rCh:
@@ -320,32 +309,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 	streamResponse(c, ch)
 	streamResponse(c, ch)
 }
 }
 
 
-func getDefaultSessionDuration() time.Duration {
-	if envconfig.KeepAlive != "" {
-		v, err := strconv.Atoi(envconfig.KeepAlive)
-		if err != nil {
-			d, err := time.ParseDuration(envconfig.KeepAlive)
-			if err != nil {
-				return defaultSessionDuration
-			}
-
-			if d < 0 {
-				return time.Duration(math.MaxInt64)
-			}
-
-			return d
-		}
-
-		d := time.Duration(v) * time.Second
-		if d < 0 {
-			return time.Duration(math.MaxInt64)
-		}
-		return d
-	}
-
-	return defaultSessionDuration
-}
-
 func (s *Server) EmbeddingsHandler(c *gin.Context) {
 func (s *Server) EmbeddingsHandler(c *gin.Context) {
 	var req api.EmbeddingRequest
 	var req api.EmbeddingRequest
 	err := c.ShouldBindJSON(&req)
 	err := c.ShouldBindJSON(&req)
@@ -380,14 +343,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 		return
 		return
 	}
 	}
 
 
-	var sessionDuration time.Duration
-	if req.KeepAlive == nil {
-		sessionDuration = getDefaultSessionDuration()
-	} else {
-		sessionDuration = req.KeepAlive.Duration
-	}
-
-	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
+	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
 	var runner *runnerRef
 	var runner *runnerRef
 	select {
 	select {
 	case runner = <-rCh:
 	case runner = <-rCh:
@@ -1318,14 +1274,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		return
 		return
 	}
 	}
 
 
-	var sessionDuration time.Duration
-	if req.KeepAlive == nil {
-		sessionDuration = getDefaultSessionDuration()
-	} else {
-		sessionDuration = req.KeepAlive.Duration
-	}
-
-	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
+	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
 	var runner *runnerRef
 	var runner *runnerRef
 	select {
 	select {
 	case runner = <-rCh:
 	case runner = <-rCh:

+ 10 - 4
server/sched.go

@@ -24,7 +24,7 @@ type LlmRequest struct {
 	model           *Model
 	model           *Model
 	opts            api.Options
 	opts            api.Options
 	origNumCtx      int // Track the initial ctx request
 	origNumCtx      int // Track the initial ctx request
-	sessionDuration time.Duration
+	sessionDuration *api.Duration
 	successCh       chan *runnerRef
 	successCh       chan *runnerRef
 	errCh           chan error
 	errCh           chan error
 	schedAttempts   uint
 	schedAttempts   uint
@@ -75,7 +75,7 @@ func InitScheduler(ctx context.Context) *Scheduler {
 }
 }
 
 
 // context must be canceled to decrement ref count and release the runner
 // context must be canceled to decrement ref count and release the runner
-func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration time.Duration) (chan *runnerRef, chan error) {
+func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration *api.Duration) (chan *runnerRef, chan error) {
 	if opts.NumCtx < 4 {
 	if opts.NumCtx < 4 {
 		opts.NumCtx = 4
 		opts.NumCtx = 4
 	}
 	}
@@ -389,7 +389,9 @@ func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *Llm
 		runner.expireTimer.Stop()
 		runner.expireTimer.Stop()
 		runner.expireTimer = nil
 		runner.expireTimer = nil
 	}
 	}
-	runner.sessionDuration = pending.sessionDuration
+	if pending.sessionDuration != nil {
+		runner.sessionDuration = pending.sessionDuration.Duration
+	}
 	pending.successCh <- runner
 	pending.successCh <- runner
 	go func() {
 	go func() {
 		<-pending.ctx.Done()
 		<-pending.ctx.Done()
@@ -402,6 +404,10 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList,
 	if numParallel < 1 {
 	if numParallel < 1 {
 		numParallel = 1
 		numParallel = 1
 	}
 	}
+	sessionDuration := envconfig.KeepAlive
+	if req.sessionDuration != nil {
+		sessionDuration = req.sessionDuration.Duration
+	}
 	llama, err := s.newServerFn(gpus, req.model.ModelPath, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, numParallel)
 	llama, err := s.newServerFn(gpus, req.model.ModelPath, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, numParallel)
 	if err != nil {
 	if err != nil {
 		// some older models are not compatible with newer versions of llama.cpp
 		// some older models are not compatible with newer versions of llama.cpp
@@ -419,7 +425,7 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList,
 		modelPath:       req.model.ModelPath,
 		modelPath:       req.model.ModelPath,
 		llama:           llama,
 		llama:           llama,
 		Options:         &req.opts,
 		Options:         &req.opts,
-		sessionDuration: req.sessionDuration,
+		sessionDuration: sessionDuration,
 		gpus:            gpus,
 		gpus:            gpus,
 		estimatedVRAM:   llama.EstimatedVRAM(),
 		estimatedVRAM:   llama.EstimatedVRAM(),
 		estimatedTotal:  llama.EstimatedTotal(),
 		estimatedTotal:  llama.EstimatedTotal(),

+ 11 - 11
server/sched_test.go

@@ -44,7 +44,7 @@ func TestLoad(t *testing.T) {
 		opts:            api.DefaultOptions(),
 		opts:            api.DefaultOptions(),
 		successCh:       make(chan *runnerRef, 1),
 		successCh:       make(chan *runnerRef, 1),
 		errCh:           make(chan error, 1),
 		errCh:           make(chan error, 1),
-		sessionDuration: 2,
+		sessionDuration: &api.Duration{Duration: 2 * time.Second},
 	}
 	}
 	// Fail to load model first
 	// Fail to load model first
 	s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
 	s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
@@ -142,7 +142,7 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
 		ctx:             scenario.ctx,
 		ctx:             scenario.ctx,
 		model:           model,
 		model:           model,
 		opts:            api.DefaultOptions(),
 		opts:            api.DefaultOptions(),
-		sessionDuration: 5 * time.Millisecond,
+		sessionDuration: &api.Duration{Duration: 5 * time.Millisecond},
 		successCh:       make(chan *runnerRef, 1),
 		successCh:       make(chan *runnerRef, 1),
 		errCh:           make(chan error, 1),
 		errCh:           make(chan error, 1),
 	}
 	}
@@ -156,18 +156,18 @@ func TestRequests(t *testing.T) {
 
 
 	// Same model, same request
 	// Same model, same request
 	scenario1a := newScenario(t, ctx, "ollama-model-1", 10)
 	scenario1a := newScenario(t, ctx, "ollama-model-1", 10)
-	scenario1a.req.sessionDuration = 5 * time.Millisecond
+	scenario1a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
 	scenario1b := newScenario(t, ctx, "ollama-model-1", 11)
 	scenario1b := newScenario(t, ctx, "ollama-model-1", 11)
 	scenario1b.req.model = scenario1a.req.model
 	scenario1b.req.model = scenario1a.req.model
 	scenario1b.ggml = scenario1a.ggml
 	scenario1b.ggml = scenario1a.ggml
-	scenario1b.req.sessionDuration = 0
+	scenario1b.req.sessionDuration = &api.Duration{Duration: 0}
 
 
 	// simple reload of same model
 	// simple reload of same model
 	scenario2a := newScenario(t, ctx, "ollama-model-1", 20)
 	scenario2a := newScenario(t, ctx, "ollama-model-1", 20)
 	tmpModel := *scenario1a.req.model
 	tmpModel := *scenario1a.req.model
 	scenario2a.req.model = &tmpModel
 	scenario2a.req.model = &tmpModel
 	scenario2a.ggml = scenario1a.ggml
 	scenario2a.ggml = scenario1a.ggml
-	scenario2a.req.sessionDuration = 5 * time.Millisecond
+	scenario2a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
 
 
 	// Multiple loaded models
 	// Multiple loaded models
 	scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte)
 	scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte)
@@ -318,11 +318,11 @@ func TestGetRunner(t *testing.T) {
 	defer done()
 	defer done()
 
 
 	scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
 	scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
-	scenario1a.req.sessionDuration = 0
+	scenario1a.req.sessionDuration = &api.Duration{Duration: 0}
 	scenario1b := newScenario(t, ctx, "ollama-model-1b", 10)
 	scenario1b := newScenario(t, ctx, "ollama-model-1b", 10)
-	scenario1b.req.sessionDuration = 0
+	scenario1b.req.sessionDuration = &api.Duration{Duration: 0}
 	scenario1c := newScenario(t, ctx, "ollama-model-1c", 10)
 	scenario1c := newScenario(t, ctx, "ollama-model-1c", 10)
-	scenario1c.req.sessionDuration = 0
+	scenario1c.req.sessionDuration = &api.Duration{Duration: 0}
 	envconfig.MaxQueuedRequests = 1
 	envconfig.MaxQueuedRequests = 1
 	s := InitScheduler(ctx)
 	s := InitScheduler(ctx)
 	s.getGpuFn = func() gpu.GpuInfoList {
 	s.getGpuFn = func() gpu.GpuInfoList {
@@ -402,7 +402,7 @@ func TestPrematureExpired(t *testing.T) {
 	case <-ctx.Done():
 	case <-ctx.Done():
 		t.Fatal("timeout")
 		t.Fatal("timeout")
 	}
 	}
-	time.Sleep(scenario1a.req.sessionDuration)
+	time.Sleep(scenario1a.req.sessionDuration.Duration)
 	scenario1a.ctxDone()
 	scenario1a.ctxDone()
 	time.Sleep(20 * time.Millisecond)
 	time.Sleep(20 * time.Millisecond)
 	require.LessOrEqual(t, len(s.finishedReqCh), 1)
 	require.LessOrEqual(t, len(s.finishedReqCh), 1)
@@ -423,7 +423,7 @@ func TestUseLoadedRunner(t *testing.T) {
 		ctx:             ctx,
 		ctx:             ctx,
 		opts:            api.DefaultOptions(),
 		opts:            api.DefaultOptions(),
 		successCh:       make(chan *runnerRef, 1),
 		successCh:       make(chan *runnerRef, 1),
-		sessionDuration: 2,
+		sessionDuration: &api.Duration{Duration: 2},
 	}
 	}
 	finished := make(chan *LlmRequest)
 	finished := make(chan *LlmRequest)
 	llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
 	llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
@@ -614,7 +614,7 @@ func TestAlreadyCanceled(t *testing.T) {
 	dctx, done2 := context.WithCancel(ctx)
 	dctx, done2 := context.WithCancel(ctx)
 	done2()
 	done2()
 	scenario1a := newScenario(t, dctx, "ollama-model-1", 10)
 	scenario1a := newScenario(t, dctx, "ollama-model-1", 10)
-	scenario1a.req.sessionDuration = 0
+	scenario1a.req.sessionDuration = &api.Duration{Duration: 0}
 	s := InitScheduler(ctx)
 	s := InitScheduler(ctx)
 	slog.Info("scenario1a")
 	slog.Info("scenario1a")
 	s.pendingReqCh <- scenario1a.req
 	s.pendingReqCh <- scenario1a.req