Michael Yang 10 ماه پیش
والد
کامیت
0f1910129f
5فایلهای تغییر یافته به همراه42 افزوده شده و 77 حذف شده
  1. 18 48
      envconfig/config.go
  2. 1 8
      integration/basic_test.go
  3. 6 8
      integration/max_queue_test.go
  4. 14 9
      server/sched.go
  5. 3 4
      server/sched_test.go

+ 18 - 48
envconfig/config.go

@@ -213,13 +213,22 @@ func RunnersDir() (p string) {
 	return p
 }
 
+func Int(k string, n int) func() int {
+	return func() int {
+		if s := getenv(k); s != "" {
+			if n, err := strconv.ParseInt(s, 10, 64); err == nil && n >= 0 {
+				return int(n)
+			}
+		}
+
+		return n
+	}
+}
+
 var (
-	// Set via OLLAMA_MAX_LOADED_MODELS in the environment
-	MaxRunners int
-	// Set via OLLAMA_MAX_QUEUE in the environment
-	MaxQueuedRequests int
-	// Set via OLLAMA_NUM_PARALLEL in the environment
-	NumParallel int
+	NumParallel = Int("OLLAMA_NUM_PARALLEL", 0)
+	MaxRunners  = Int("OLLAMA_MAX_LOADED_MODELS", 0)
+	MaxQueue    = Int("OLLAMA_MAX_QUEUE", 512)
 )
 
 type EnvVar struct {
@@ -235,12 +244,12 @@ func AsMap() map[string]EnvVar {
 		"OLLAMA_HOST":              {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"},
 		"OLLAMA_KEEP_ALIVE":        {"OLLAMA_KEEP_ALIVE", KeepAlive(), "The duration that models stay loaded in memory (default \"5m\")"},
 		"OLLAMA_LLM_LIBRARY":       {"OLLAMA_LLM_LIBRARY", LLMLibrary(), "Set LLM library to bypass autodetection"},
-		"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models per GPU"},
-		"OLLAMA_MAX_QUEUE":         {"OLLAMA_MAX_QUEUE", MaxQueuedRequests, "Maximum number of queued requests"},
+		"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners(), "Maximum number of loaded models per GPU"},
+		"OLLAMA_MAX_QUEUE":         {"OLLAMA_MAX_QUEUE", MaxQueue(), "Maximum number of queued requests"},
 		"OLLAMA_MODELS":            {"OLLAMA_MODELS", Models(), "The path to the models directory"},
 		"OLLAMA_NOHISTORY":         {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"},
 		"OLLAMA_NOPRUNE":           {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"},
-		"OLLAMA_NUM_PARALLEL":      {"OLLAMA_NUM_PARALLEL", NumParallel, "Maximum number of parallel requests"},
+		"OLLAMA_NUM_PARALLEL":      {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"},
 		"OLLAMA_ORIGINS":           {"OLLAMA_ORIGINS", Origins(), "A comma separated list of allowed origins"},
 		"OLLAMA_RUNNERS_DIR":       {"OLLAMA_RUNNERS_DIR", RunnersDir(), "Location for runners"},
 		"OLLAMA_SCHED_SPREAD":      {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
@@ -269,42 +278,3 @@ func Values() map[string]string {
 func getenv(key string) string {
 	return strings.Trim(os.Getenv(key), "\"' ")
 }
-
-func init() {
-	// default values
-	NumParallel = 0 // Autoselect
-	MaxRunners = 0  // Autoselect
-	MaxQueuedRequests = 512
-
-	LoadConfig()
-}
-
-func LoadConfig() {
-	if onp := getenv("OLLAMA_NUM_PARALLEL"); onp != "" {
-		val, err := strconv.Atoi(onp)
-		if err != nil {
-			slog.Error("invalid setting, ignoring", "OLLAMA_NUM_PARALLEL", onp, "error", err)
-		} else {
-			NumParallel = val
-		}
-	}
-
-	maxRunners := getenv("OLLAMA_MAX_LOADED_MODELS")
-	if maxRunners != "" {
-		m, err := strconv.Atoi(maxRunners)
-		if err != nil {
-			slog.Error("invalid setting, ignoring", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "error", err)
-		} else {
-			MaxRunners = m
-		}
-	}
-
-	if onp := os.Getenv("OLLAMA_MAX_QUEUE"); onp != "" {
-		p, err := strconv.Atoi(onp)
-		if err != nil || p <= 0 {
-			slog.Error("invalid setting, ignoring", "OLLAMA_MAX_QUEUE", onp, "error", err)
-		} else {
-			MaxQueuedRequests = p
-		}
-	}
-}

+ 1 - 8
integration/basic_test.go

@@ -45,14 +45,7 @@ func TestUnicodeModelDir(t *testing.T) {
 	defer os.RemoveAll(modelDir)
 	slog.Info("unicode", "OLLAMA_MODELS", modelDir)
 
-	oldModelsDir := os.Getenv("OLLAMA_MODELS")
-	if oldModelsDir == "" {
-		defer os.Unsetenv("OLLAMA_MODELS")
-	} else {
-		defer os.Setenv("OLLAMA_MODELS", oldModelsDir)
-	}
-	err = os.Setenv("OLLAMA_MODELS", modelDir)
-	require.NoError(t, err)
+	t.Setenv("OLLAMA_MODELS", modelDir)
 
 	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
 	defer cancel()

+ 6 - 8
integration/max_queue_test.go

@@ -5,7 +5,6 @@ package integration
 import (
 	"context"
 	"errors"
-	"fmt"
 	"log/slog"
 	"os"
 	"strconv"
@@ -14,8 +13,10 @@ import (
 	"testing"
 	"time"
 
-	"github.com/ollama/ollama/api"
 	"github.com/stretchr/testify/require"
+
+	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/envconfig"
 )
 
 func TestMaxQueue(t *testing.T) {
@@ -27,13 +28,10 @@ func TestMaxQueue(t *testing.T) {
 	// Note: This test can be quite slow when running in CPU mode, so keep the threadCount low unless your on GPU
 	// Also note that by default Darwin can't sustain > ~128 connections without adjusting limits
 	threadCount := 32
-	mq := os.Getenv("OLLAMA_MAX_QUEUE")
-	if mq != "" {
-		var err error
-		threadCount, err = strconv.Atoi(mq)
-		require.NoError(t, err)
+	if maxQueue := envconfig.MaxQueue(); maxQueue != 0 {
+		threadCount = maxQueue
 	} else {
-		os.Setenv("OLLAMA_MAX_QUEUE", fmt.Sprintf("%d", threadCount))
+		t.Setenv("OLLAMA_MAX_QUEUE", strconv.Itoa(threadCount))
 	}
 
 	req := api.GenerateRequest{

+ 14 - 9
server/sched.go

@@ -5,9 +5,11 @@ import (
 	"errors"
 	"fmt"
 	"log/slog"
+	"os"
 	"reflect"
 	"runtime"
 	"sort"
+	"strconv"
 	"strings"
 	"sync"
 	"time"
@@ -59,11 +61,12 @@ var defaultParallel = 4
 var ErrMaxQueue = fmt.Errorf("server busy, please try again.  maximum pending requests exceeded")
 
 func InitScheduler(ctx context.Context) *Scheduler {
+	maxQueue := envconfig.MaxQueue()
 	sched := &Scheduler{
-		pendingReqCh:  make(chan *LlmRequest, envconfig.MaxQueuedRequests),
-		finishedReqCh: make(chan *LlmRequest, envconfig.MaxQueuedRequests),
-		expiredCh:     make(chan *runnerRef, envconfig.MaxQueuedRequests),
-		unloadedCh:    make(chan interface{}, envconfig.MaxQueuedRequests),
+		pendingReqCh:  make(chan *LlmRequest, maxQueue),
+		finishedReqCh: make(chan *LlmRequest, maxQueue),
+		expiredCh:     make(chan *runnerRef, maxQueue),
+		unloadedCh:    make(chan interface{}, maxQueue),
 		loaded:        make(map[string]*runnerRef),
 		newServerFn:   llm.NewLlamaServer,
 		getGpuFn:      gpu.GetGPUInfo,
@@ -126,7 +129,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
 				slog.Debug("pending request cancelled or timed out, skipping scheduling")
 				continue
 			}
-			numParallel := envconfig.NumParallel
+			numParallel := envconfig.NumParallel()
 			// TODO (jmorganca): multimodal models don't support parallel yet
 			// see https://github.com/ollama/ollama/issues/4165
 			if len(pending.model.ProjectorPaths) > 0 && numParallel != 1 {
@@ -148,7 +151,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
 						pending.useLoadedRunner(runner, s.finishedReqCh)
 						break
 					}
-				} else if envconfig.MaxRunners > 0 && loadedCount >= envconfig.MaxRunners {
+				} else if envconfig.MaxRunners() > 0 && loadedCount >= envconfig.MaxRunners() {
 					slog.Debug("max runners achieved, unloading one to make room", "runner_count", loadedCount)
 					runnerToExpire = s.findRunnerToUnload()
 				} else {
@@ -161,7 +164,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
 						gpus = s.getGpuFn()
 					}
 
-					if envconfig.MaxRunners <= 0 {
+					if envconfig.MaxRunners() <= 0 {
 						// No user specified MaxRunners, so figure out what automatic setting to use
 						// If all GPUs have reliable free memory reporting, defaultModelsPerGPU * the number of GPUs
 						// if any GPU has unreliable free memory reporting, 1x the number of GPUs
@@ -173,11 +176,13 @@ func (s *Scheduler) processPending(ctx context.Context) {
 							}
 						}
 						if allReliable {
-							envconfig.MaxRunners = defaultModelsPerGPU * len(gpus)
+							// HACK
+							os.Setenv("OLLAMA_MAX_LOADED_MODELS", strconv.Itoa(defaultModelsPerGPU*len(gpus)))
 							slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", envconfig.MaxRunners, "gpu_count", len(gpus))
 						} else {
+							// HACK
+							os.Setenv("OLLAMA_MAX_LOADED_MODELS", strconv.Itoa(len(gpus)))
 							slog.Info("one or more GPUs detected that are unable to accurately report free memory - disabling default concurrency")
-							envconfig.MaxRunners = len(gpus)
 						}
 					}
 

+ 3 - 4
server/sched_test.go

@@ -12,7 +12,6 @@ import (
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/app/lifecycle"
-	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/gpu"
 	"github.com/ollama/ollama/llm"
@@ -272,7 +271,7 @@ func TestRequestsMultipleLoadedModels(t *testing.T) {
 	c.req.opts.NumGPU = 0                                       // CPU load, will be allowed
 	d := newScenarioRequest(t, ctx, "ollama-model-3c", 30, nil) // Needs prior unloaded
 
-	envconfig.MaxRunners = 1
+	t.Setenv("OLLAMA_MAX_LOADED_MODELS", "1")
 	s.newServerFn = a.newServer
 	slog.Info("a")
 	s.pendingReqCh <- a.req
@@ -291,7 +290,7 @@ func TestRequestsMultipleLoadedModels(t *testing.T) {
 	require.Len(t, s.loaded, 1)
 	s.loadedMu.Unlock()
 
-	envconfig.MaxRunners = 0
+	t.Setenv("OLLAMA_MAX_LOADED_MODELS", "0")
 	s.newServerFn = b.newServer
 	slog.Info("b")
 	s.pendingReqCh <- b.req
@@ -362,7 +361,7 @@ func TestGetRunner(t *testing.T) {
 	a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, &api.Duration{Duration: 2 * time.Millisecond})
 	b := newScenarioRequest(t, ctx, "ollama-model-1b", 10, &api.Duration{Duration: 2 * time.Millisecond})
 	c := newScenarioRequest(t, ctx, "ollama-model-1c", 10, &api.Duration{Duration: 2 * time.Millisecond})
-	envconfig.MaxQueuedRequests = 1
+	t.Setenv("OLLAMA_MAX_QUEUE", "1")
 	s := InitScheduler(ctx)
 	s.getGpuFn = getGpuFn
 	s.getCpuFn = getCpuFn