Michael Yang пре 9 месеци
родитељ
комит
85d9d73a72
3 измењених фајлова са 90 додато и 41 уклоњено
  1. 27 23
      envconfig/config.go
  2. 61 16
      envconfig/config_test.go
  3. 2 2
      server/sched.go

+ 27 - 23
envconfig/config.go

@@ -1,7 +1,6 @@
 package envconfig
 
 import (
-	"errors"
 	"fmt"
 	"log/slog"
 	"math"
@@ -15,15 +14,12 @@ import (
 	"time"
 )
 
-var ErrInvalidHostPort = errors.New("invalid port specified in OLLAMA_HOST")
-
 // Host returns the scheme and host. Host can be configured via the OLLAMA_HOST environment variable.
 // Default is scheme "http" and host "127.0.0.1:11434"
 func Host() *url.URL {
 	defaultPort := "11434"
 
-	s := os.Getenv("OLLAMA_HOST")
-	s = strings.TrimSpace(strings.Trim(strings.TrimSpace(s), "\"'"))
+	s := strings.TrimSpace(Var("OLLAMA_HOST"))
 	scheme, hostport, ok := strings.Cut(s, "://")
 	switch {
 	case !ok:
@@ -48,6 +44,7 @@ func Host() *url.URL {
 	}
 
 	if n, err := strconv.ParseInt(port, 10, 32); err != nil || n > 65535 || n < 0 {
+		slog.Warn("invalid port, using default", "port", port, "default", defaultPort)
 		return &url.URL{
 			Scheme: scheme,
 			Host:   net.JoinHostPort(host, defaultPort),
@@ -62,7 +59,7 @@ func Host() *url.URL {
 
 // Origins returns a list of allowed origins. Origins can be configured via the OLLAMA_ORIGINS environment variable.
 func Origins() (origins []string) {
-	if s := getenv("OLLAMA_ORIGINS"); s != "" {
+	if s := Var("OLLAMA_ORIGINS"); s != "" {
 		origins = strings.Split(s, ",")
 	}
 
@@ -87,7 +84,7 @@ func Origins() (origins []string) {
 // Models returns the path to the models directory. Models directory can be configured via the OLLAMA_MODELS environment variable.
 // Default is $HOME/.ollama/models
 func Models() string {
-	if s, ok := os.LookupEnv("OLLAMA_MODELS"); ok {
+	if s := Var("OLLAMA_MODELS"); s != "" {
 		return s
 	}
 
@@ -104,7 +101,7 @@ func Models() string {
 // Default is 5 minutes.
 func KeepAlive() (keepAlive time.Duration) {
 	keepAlive = 5 * time.Minute
-	if s := os.Getenv("OLLAMA_KEEP_ALIVE"); s != "" {
+	if s := Var("OLLAMA_KEEP_ALIVE"); s != "" {
 		if d, err := time.ParseDuration(s); err == nil {
 			keepAlive = d
 		} else if n, err := strconv.ParseInt(s, 10, 64); err == nil {
@@ -121,7 +118,7 @@ func KeepAlive() (keepAlive time.Duration) {
 
 func Bool(k string) func() bool {
 	return func() bool {
-		if s := getenv(k); s != "" {
+		if s := Var(k); s != "" {
 			b, err := strconv.ParseBool(s)
 			if err != nil {
 				return true
@@ -151,7 +148,7 @@ var (
 
 func String(s string) func() string {
 	return func() string {
-		return getenv(s)
+		return Var(s)
 	}
 }
 
@@ -167,7 +164,7 @@ var (
 )
 
 func RunnersDir() (p string) {
-	if p := getenv("OLLAMA_RUNNERS_DIR"); p != "" {
+	if p := Var("OLLAMA_RUNNERS_DIR"); p != "" {
 		return p
 	}
 
@@ -213,22 +210,29 @@ func RunnersDir() (p string) {
 	return p
 }
 
-func Int(k string, n int) func() int {
-	return func() int {
-		if s := getenv(k); s != "" {
-			if n, err := strconv.ParseInt(s, 10, 64); err == nil && n >= 0 {
-				return int(n)
+func Uint(key string, defaultValue uint) func() uint {
+	return func() uint {
+		if s := Var(key); s != "" {
+			if n, err := strconv.ParseUint(s, 10, 64); err != nil {
+				slog.Warn("invalid environment variable, using default", "key", key, "value", s, "default", defaultValue)
+			} else {
+				return uint(n)
 			}
 		}
 
-		return n
+		return defaultValue
 	}
 }
 
 var (
-	NumParallel = Int("OLLAMA_NUM_PARALLEL", 0)
-	MaxRunners  = Int("OLLAMA_MAX_LOADED_MODELS", 0)
-	MaxQueue    = Int("OLLAMA_MAX_QUEUE", 512)
+	// NumParallel sets the number of parallel model requests. NumParallel can be configured via the OLLAMA_NUM_PARALLEL environment variable.
+	NumParallel = Uint("OLLAMA_NUM_PARALLEL", 0)
+	// MaxRunners sets the maximum number of loaded models. MaxRunners can be configured via the OLLAMA_MAX_LOADED_MODELS environment variable.
+	MaxRunners = Uint("OLLAMA_MAX_LOADED_MODELS", 0)
+	// MaxQueue sets the maximum number of queued requests. MaxQueue can be configured via the OLLAMA_MAX_QUEUE environment variable.
+	MaxQueue = Uint("OLLAMA_MAX_QUEUE", 512)
+	// MaxVRAM sets a maximum VRAM override in bytes. MaxVRAM can be configured via the OLLAMA_MAX_VRAM environment variable.
+	MaxVRAM = Uint("OLLAMA_MAX_VRAM", 0)
 )
 
 type EnvVar struct {
@@ -274,7 +278,7 @@ func Values() map[string]string {
 	return vals
 }
 
-// getenv returns an environment variable stripped of leading and trailing quotes or spaces
-func getenv(key string) string {
-	return strings.Trim(os.Getenv(key), "\"' ")
+// Var returns an environment variable stripped of leading and trailing quotes or spaces
+func Var(key string) string {
+	return strings.Trim(strings.TrimSpace(os.Getenv(key)), "\"'")
 }

+ 61 - 16
envconfig/config_test.go

@@ -30,6 +30,10 @@ func TestHost(t *testing.T) {
 		"extra quotes":        {"\"1.2.3.4\"", "1.2.3.4:11434"},
 		"extra space+quotes":  {" \" 1.2.3.4 \" ", "1.2.3.4:11434"},
 		"extra single quotes": {"'1.2.3.4'", "1.2.3.4:11434"},
+		"http":                {"http://1.2.3.4", "1.2.3.4:80"},
+		"http port":           {"http://1.2.3.4:4321", "1.2.3.4:4321"},
+		"https":               {"https://1.2.3.4", "1.2.3.4:443"},
+		"https port":          {"https://1.2.3.4:4321", "1.2.3.4:4321"},
 	}
 
 	for name, tt := range cases {
@@ -133,24 +137,45 @@ func TestOrigins(t *testing.T) {
 }
 
 func TestBool(t *testing.T) {
-	cases := map[string]struct {
-		value  string
-		expect bool
-	}{
-		"empty":     {"", false},
-		"true":      {"true", true},
-		"false":     {"false", false},
-		"1":         {"1", true},
-		"0":         {"0", false},
-		"random":    {"random", true},
-		"something": {"something", true},
+	cases := map[string]bool{
+		"":      false,
+		"true":  true,
+		"false": false,
+		"1":     true,
+		"0":     false,
+		// invalid values
+		"random":    true,
+		"something": true,
 	}
 
-	for name, tt := range cases {
-		t.Run(name, func(t *testing.T) {
-			t.Setenv("OLLAMA_BOOL", tt.value)
-			if b := Bool("OLLAMA_BOOL"); b() != tt.expect {
-				t.Errorf("%s: expected %t, got %t", name, tt.expect, b())
+	for k, v := range cases {
+		t.Run(k, func(t *testing.T) {
+			t.Setenv("OLLAMA_BOOL", k)
+			if b := Bool("OLLAMA_BOOL")(); b != v {
+				t.Errorf("%s: expected %t, got %t", k, v, b)
+			}
+		})
+	}
+}
+
+func TestUint(t *testing.T) {
+	cases := map[string]uint{
+		"0":    0,
+		"1":    1,
+		"1337": 1337,
+		// default values
+		"":       11434,
+		"-1":     11434,
+		"0o10":   11434,
+		"0x10":   11434,
+		"string": 11434,
+	}
+
+	for k, v := range cases {
+		t.Run(k, func(t *testing.T) {
+			t.Setenv("OLLAMA_UINT", k)
+			if i := Uint("OLLAMA_UINT", 11434)(); i != v {
+				t.Errorf("%s: expected %d, got %d", k, v, i)
 			}
 		})
 	}
@@ -188,3 +213,23 @@ func TestKeepAlive(t *testing.T) {
 		})
 	}
 }
+
+func TestVar(t *testing.T) {
+	cases := map[string]string{
+		"value":       "value",
+		" value ":     "value",
+		" 'value' ":   "value",
+		` "value" `:   "value",
+		" ' value ' ": " value ",
+		` " value " `: " value ",
+	}
+
+	for k, v := range cases {
+		t.Run(k, func(t *testing.T) {
+			t.Setenv("OLLAMA_VAR", k)
+			if s := Var("OLLAMA_VAR"); s != v {
+				t.Errorf("%s: expected %q, got %q", k, v, s)
+			}
+		})
+	}
+}

+ 2 - 2
server/sched.go

@@ -129,7 +129,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
 				slog.Debug("pending request cancelled or timed out, skipping scheduling")
 				continue
 			}
-			numParallel := envconfig.NumParallel()
+			numParallel := int(envconfig.NumParallel())
 			// TODO (jmorganca): multimodal models don't support parallel yet
 			// see https://github.com/ollama/ollama/issues/4165
 			if len(pending.model.ProjectorPaths) > 0 && numParallel != 1 {
@@ -151,7 +151,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
 						pending.useLoadedRunner(runner, s.finishedReqCh)
 						break
 					}
-				} else if envconfig.MaxRunners() > 0 && loadedCount >= envconfig.MaxRunners() {
+				} else if envconfig.MaxRunners() > 0 && loadedCount >= int(envconfig.MaxRunners()) {
 					slog.Debug("max runners achieved, unloading one to make room", "runner_count", loadedCount)
 					runnerToExpire = s.findRunnerToUnload()
 				} else {