瀏覽代碼

move OLLAMA_HOST to envconfig (#5009)

Patrick Devine 10 月之前
父節點
當前提交
c69bc19e46
共有 6 個文件被更改,包括 119 次插入103 次删除
  1. 2 53
      api/client.go
  2. 2 39
      api/client_test.go
  3. 0 3
      api/types.go
  4. 1 7
      cmd/cmd.go
  5. 66 1
      envconfig/config.go
  6. 48 0
      envconfig/config_test.go

+ 2 - 53
api/client.go

@@ -23,11 +23,9 @@ import (
 	"net"
 	"net/http"
 	"net/url"
-	"os"
 	"runtime"
-	"strconv"
-	"strings"
 
+	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/version"
 )
@@ -65,10 +63,7 @@ func checkError(resp *http.Response, body []byte) error {
 // If the variable is not specified, a default ollama host and port will be
 // used.
 func ClientFromEnvironment() (*Client, error) {
-	ollamaHost, err := GetOllamaHost()
-	if err != nil {
-		return nil, err
-	}
+	ollamaHost := envconfig.Host
 
 	return &Client{
 		base: &url.URL{
@@ -79,52 +74,6 @@ func ClientFromEnvironment() (*Client, error) {
 	}, nil
 }
 
-type OllamaHost struct {
-	Scheme string
-	Host   string
-	Port   string
-}
-
-func GetOllamaHost() (OllamaHost, error) {
-	defaultPort := "11434"
-
-	hostVar := os.Getenv("OLLAMA_HOST")
-	hostVar = strings.TrimSpace(strings.Trim(strings.TrimSpace(hostVar), "\"'"))
-
-	scheme, hostport, ok := strings.Cut(hostVar, "://")
-	switch {
-	case !ok:
-		scheme, hostport = "http", hostVar
-	case scheme == "http":
-		defaultPort = "80"
-	case scheme == "https":
-		defaultPort = "443"
-	}
-
-	// trim trailing slashes
-	hostport = strings.TrimRight(hostport, "/")
-
-	host, port, err := net.SplitHostPort(hostport)
-	if err != nil {
-		host, port = "127.0.0.1", defaultPort
-		if ip := net.ParseIP(strings.Trim(hostport, "[]")); ip != nil {
-			host = ip.String()
-		} else if hostport != "" {
-			host = hostport
-		}
-	}
-
-	if portNum, err := strconv.ParseInt(port, 10, 32); err != nil || portNum > 65535 || portNum < 0 {
-		return OllamaHost{}, ErrInvalidHostPort
-	}
-
-	return OllamaHost{
-		Scheme: scheme,
-		Host:   host,
-		Port:   port,
-	}, nil
-}
-
 func NewClient(base *url.URL, http *http.Client) *Client {
 	return &Client{
 		base: base,

+ 2 - 39
api/client_test.go

@@ -1,11 +1,9 @@
 package api
 
 import (
-	"fmt"
-	"net"
 	"testing"
 
-	"github.com/stretchr/testify/assert"
+	"github.com/ollama/ollama/envconfig"
 )
 
 func TestClientFromEnvironment(t *testing.T) {
@@ -35,6 +33,7 @@ func TestClientFromEnvironment(t *testing.T) {
 	for k, v := range testCases {
 		t.Run(k, func(t *testing.T) {
 			t.Setenv("OLLAMA_HOST", v.value)
+			envconfig.LoadConfig()
 
 			client, err := ClientFromEnvironment()
 			if err != v.err {
@@ -46,40 +45,4 @@ func TestClientFromEnvironment(t *testing.T) {
 			}
 		})
 	}
-
-	hostTestCases := map[string]*testCase{
-		"empty":               {value: "", expect: "127.0.0.1:11434"},
-		"only address":        {value: "1.2.3.4", expect: "1.2.3.4:11434"},
-		"only port":           {value: ":1234", expect: ":1234"},
-		"address and port":    {value: "1.2.3.4:1234", expect: "1.2.3.4:1234"},
-		"hostname":            {value: "example.com", expect: "example.com:11434"},
-		"hostname and port":   {value: "example.com:1234", expect: "example.com:1234"},
-		"zero port":           {value: ":0", expect: ":0"},
-		"too large port":      {value: ":66000", err: ErrInvalidHostPort},
-		"too small port":      {value: ":-1", err: ErrInvalidHostPort},
-		"ipv6 localhost":      {value: "[::1]", expect: "[::1]:11434"},
-		"ipv6 world open":     {value: "[::]", expect: "[::]:11434"},
-		"ipv6 no brackets":    {value: "::1", expect: "[::1]:11434"},
-		"ipv6 + port":         {value: "[::1]:1337", expect: "[::1]:1337"},
-		"extra space":         {value: " 1.2.3.4 ", expect: "1.2.3.4:11434"},
-		"extra quotes":        {value: "\"1.2.3.4\"", expect: "1.2.3.4:11434"},
-		"extra space+quotes":  {value: " \" 1.2.3.4 \" ", expect: "1.2.3.4:11434"},
-		"extra single quotes": {value: "'1.2.3.4'", expect: "1.2.3.4:11434"},
-	}
-
-	for k, v := range hostTestCases {
-		t.Run(k, func(t *testing.T) {
-			t.Setenv("OLLAMA_HOST", v.value)
-
-			oh, err := GetOllamaHost()
-			if err != v.err {
-				t.Fatalf("expected %s, got %s", v.err, err)
-			}
-
-			if err == nil {
-				host := net.JoinHostPort(oh.Host, oh.Port)
-				assert.Equal(t, v.expect, host, fmt.Sprintf("%s: expected %s, got %s", k, v.expect, host))
-			}
-		})
-	}
 }

+ 0 - 3
api/types.go

@@ -2,7 +2,6 @@ package api
 
 import (
 	"encoding/json"
-	"errors"
 	"fmt"
 	"log/slog"
 	"math"
@@ -377,8 +376,6 @@ func (m *Metrics) Summary() {
 	}
 }
 
-var ErrInvalidHostPort = errors.New("invalid port specified in OLLAMA_HOST")
-
 func (opts *Options) FromMap(m map[string]interface{}) error {
 	valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
 	typeOpts := reflect.TypeOf(opts).Elem()   // types of the fields in the options struct

+ 1 - 7
cmd/cmd.go

@@ -960,17 +960,11 @@ func generate(cmd *cobra.Command, opts runOptions) error {
 }
 
 func RunServer(cmd *cobra.Command, _ []string) error {
-	// retrieve the OLLAMA_HOST environment variable
-	ollamaHost, err := api.GetOllamaHost()
-	if err != nil {
-		return err
-	}
-
 	if err := initializeKeypair(); err != nil {
 		return err
 	}
 
-	ln, err := net.Listen("tcp", net.JoinHostPort(ollamaHost.Host, ollamaHost.Port))
+	ln, err := net.Listen("tcp", net.JoinHostPort(envconfig.Host.Host, envconfig.Host.Port))
 	if err != nil {
 		return err
 	}

+ 66 - 1
envconfig/config.go

@@ -1,6 +1,7 @@
 package envconfig
 
 import (
+	"errors"
 	"fmt"
 	"log/slog"
 	"net"
@@ -11,6 +12,18 @@ import (
 	"strings"
 )
 
+type OllamaHost struct {
+	Scheme string
+	Host   string
+	Port   string
+}
+
+func (o OllamaHost) String() string {
+	return fmt.Sprintf("%s://%s:%s", o.Scheme, o.Host, o.Port)
+}
+
+var ErrInvalidHostPort = errors.New("invalid port specified in OLLAMA_HOST")
+
 var (
 	// Set via OLLAMA_ORIGINS in the environment
 	AllowOrigins []string
@@ -34,6 +47,8 @@ var (
 	NoPrune bool
 	// Set via OLLAMA_NUM_PARALLEL in the environment
 	NumParallel int
+	// Set via OLLAMA_HOST in the environment
+	Host *OllamaHost
 	// Set via OLLAMA_RUNNERS_DIR in the environment
 	RunnersDir string
 	// Set via OLLAMA_TMPDIR in the environment
@@ -50,7 +65,7 @@ func AsMap() map[string]EnvVar {
 	return map[string]EnvVar{
 		"OLLAMA_DEBUG":             {"OLLAMA_DEBUG", Debug, "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
 		"OLLAMA_FLASH_ATTENTION":   {"OLLAMA_FLASH_ATTENTION", FlashAttention, "Enabled flash attention"},
-		"OLLAMA_HOST":              {"OLLAMA_HOST", "", "IP Address for the ollama server (default 127.0.0.1:11434)"},
+		"OLLAMA_HOST":              {"OLLAMA_HOST", Host, "IP Address for the ollama server (default 127.0.0.1:11434)"},
 		"OLLAMA_KEEP_ALIVE":        {"OLLAMA_KEEP_ALIVE", KeepAlive, "The duration that models stay loaded in memory (default \"5m\")"},
 		"OLLAMA_LLM_LIBRARY":       {"OLLAMA_LLM_LIBRARY", LLMLibrary, "Set LLM library to bypass autodetection"},
 		"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models (default 1)"},
@@ -216,4 +231,54 @@ func LoadConfig() {
 	}
 
 	KeepAlive = clean("OLLAMA_KEEP_ALIVE")
+
+	var err error
+	Host, err = getOllamaHost()
+	if err != nil {
+		slog.Error("invalid setting", "OLLAMA_HOST", Host, "error", err, "using default port", Host.Port)
+	}
+}
+
+func getOllamaHost() (*OllamaHost, error) {
+	defaultPort := "11434"
+
+	hostVar := os.Getenv("OLLAMA_HOST")
+	hostVar = strings.TrimSpace(strings.Trim(strings.TrimSpace(hostVar), "\"'"))
+
+	scheme, hostport, ok := strings.Cut(hostVar, "://")
+	switch {
+	case !ok:
+		scheme, hostport = "http", hostVar
+	case scheme == "http":
+		defaultPort = "80"
+	case scheme == "https":
+		defaultPort = "443"
+	}
+
+	// trim trailing slashes
+	hostport = strings.TrimRight(hostport, "/")
+
+	host, port, err := net.SplitHostPort(hostport)
+	if err != nil {
+		host, port = "127.0.0.1", defaultPort
+		if ip := net.ParseIP(strings.Trim(hostport, "[]")); ip != nil {
+			host = ip.String()
+		} else if hostport != "" {
+			host = hostport
+		}
+	}
+
+	if portNum, err := strconv.ParseInt(port, 10, 32); err != nil || portNum > 65535 || portNum < 0 {
+		return &OllamaHost{
+			Scheme: scheme,
+			Host:   host,
+			Port:   defaultPort,
+		}, ErrInvalidHostPort
+	}
+
+	return &OllamaHost{
+		Scheme: scheme,
+		Host:   host,
+		Port:   port,
+	}, nil
 }

+ 48 - 0
envconfig/config_test.go

@@ -1,8 +1,11 @@
 package envconfig
 
 import (
+	"fmt"
+	"net"
 	"testing"
 
+	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 )
 
@@ -21,3 +24,48 @@ func TestConfig(t *testing.T) {
 	LoadConfig()
 	require.True(t, FlashAttention)
 }
+
+func TestClientFromEnvironment(t *testing.T) {
+	type testCase struct {
+		value  string
+		expect string
+		err    error
+	}
+
+	hostTestCases := map[string]*testCase{
+		"empty":               {value: "", expect: "127.0.0.1:11434"},
+		"only address":        {value: "1.2.3.4", expect: "1.2.3.4:11434"},
+		"only port":           {value: ":1234", expect: ":1234"},
+		"address and port":    {value: "1.2.3.4:1234", expect: "1.2.3.4:1234"},
+		"hostname":            {value: "example.com", expect: "example.com:11434"},
+		"hostname and port":   {value: "example.com:1234", expect: "example.com:1234"},
+		"zero port":           {value: ":0", expect: ":0"},
+		"too large port":      {value: ":66000", err: ErrInvalidHostPort},
+		"too small port":      {value: ":-1", err: ErrInvalidHostPort},
+		"ipv6 localhost":      {value: "[::1]", expect: "[::1]:11434"},
+		"ipv6 world open":     {value: "[::]", expect: "[::]:11434"},
+		"ipv6 no brackets":    {value: "::1", expect: "[::1]:11434"},
+		"ipv6 + port":         {value: "[::1]:1337", expect: "[::1]:1337"},
+		"extra space":         {value: " 1.2.3.4 ", expect: "1.2.3.4:11434"},
+		"extra quotes":        {value: "\"1.2.3.4\"", expect: "1.2.3.4:11434"},
+		"extra space+quotes":  {value: " \" 1.2.3.4 \" ", expect: "1.2.3.4:11434"},
+		"extra single quotes": {value: "'1.2.3.4'", expect: "1.2.3.4:11434"},
+	}
+
+	for k, v := range hostTestCases {
+		t.Run(k, func(t *testing.T) {
+			t.Setenv("OLLAMA_HOST", v.value)
+			LoadConfig()
+
+			oh, err := getOllamaHost()
+			if err != v.err {
+				t.Fatalf("expected %s, got %s", v.err, err)
+			}
+
+			if err == nil {
+				host := net.JoinHostPort(oh.Host, oh.Port)
+				assert.Equal(t, v.expect, host, fmt.Sprintf("%s: expected %s, got %s", k, v.expect, host))
+			}
+		})
+	}
+}