Browse Source

better checking for OLLAMA_HOST variable (#3661)

Patrick Devine 1 year ago
parent
commit
9009bedf13
4 changed files with 83 additions and 15 deletions
  1. 35 8
      api/client.go
  2. 43 1
      api/client_test.go
  3. 1 0
      api/types.go
  4. 4 6
      cmd/cmd.go

+ 35 - 8
api/client.go

@@ -18,6 +18,7 @@ import (
 	"net/url"
 	"net/url"
 	"os"
 	"os"
 	"runtime"
 	"runtime"
+	"strconv"
 	"strings"
 	"strings"
 
 
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/format"
@@ -57,12 +58,36 @@ func checkError(resp *http.Response, body []byte) error {
 // If the variable is not specified, a default ollama host and port will be
 // If the variable is not specified, a default ollama host and port will be
 // used.
 // used.
 func ClientFromEnvironment() (*Client, error) {
 func ClientFromEnvironment() (*Client, error) {
+	ollamaHost, err := GetOllamaHost()
+	if err != nil {
+		return nil, err
+	}
+
+	return &Client{
+		base: &url.URL{
+			Scheme: ollamaHost.Scheme,
+			Host:   net.JoinHostPort(ollamaHost.Host, ollamaHost.Port),
+		},
+		http: http.DefaultClient,
+	}, nil
+}
+
+type OllamaHost struct {
+	Scheme string
+	Host   string
+	Port   string
+}
+
+func GetOllamaHost() (OllamaHost, error) {
 	defaultPort := "11434"
 	defaultPort := "11434"
 
 
-	scheme, hostport, ok := strings.Cut(os.Getenv("OLLAMA_HOST"), "://")
+	hostVar := os.Getenv("OLLAMA_HOST")
+	hostVar = strings.TrimSpace(strings.Trim(strings.TrimSpace(hostVar), "\"'"))
+
+	scheme, hostport, ok := strings.Cut(hostVar, "://")
 	switch {
 	switch {
 	case !ok:
 	case !ok:
-		scheme, hostport = "http", os.Getenv("OLLAMA_HOST")
+		scheme, hostport = "http", hostVar
 	case scheme == "http":
 	case scheme == "http":
 		defaultPort = "80"
 		defaultPort = "80"
 	case scheme == "https":
 	case scheme == "https":
@@ -82,12 +107,14 @@ func ClientFromEnvironment() (*Client, error) {
 		}
 		}
 	}
 	}
 
 
-	return &Client{
-		base: &url.URL{
-			Scheme: scheme,
-			Host:   net.JoinHostPort(host, port),
-		},
-		http: http.DefaultClient,
+	if portNum, err := strconv.ParseInt(port, 10, 32); err != nil || portNum > 65535 || portNum < 0 {
+		return OllamaHost{}, ErrInvalidHostPort
+	}
+
+	return OllamaHost{
+		Scheme: scheme,
+		Host:   host,
+		Port:   port,
 	}, nil
 	}, nil
 }
 }
 
 

+ 43 - 1
api/client_test.go

@@ -1,6 +1,12 @@
 package api
 package api
 
 
-import "testing"
+import (
+	"fmt"
+	"net"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
 
 
 func TestClientFromEnvironment(t *testing.T) {
 func TestClientFromEnvironment(t *testing.T) {
 	type testCase struct {
 	type testCase struct {
@@ -40,4 +46,40 @@ func TestClientFromEnvironment(t *testing.T) {
 			}
 			}
 		})
 		})
 	}
 	}
+
+	hostTestCases := map[string]*testCase{
+		"empty":               {value: "", expect: "127.0.0.1:11434"},
+		"only address":        {value: "1.2.3.4", expect: "1.2.3.4:11434"},
+		"only port":           {value: ":1234", expect: ":1234"},
+		"address and port":    {value: "1.2.3.4:1234", expect: "1.2.3.4:1234"},
+		"hostname":            {value: "example.com", expect: "example.com:11434"},
+		"hostname and port":   {value: "example.com:1234", expect: "example.com:1234"},
+		"zero port":           {value: ":0", expect: ":0"},
+		"too large port":      {value: ":66000", err: ErrInvalidHostPort},
+		"too small port":      {value: ":-1", err: ErrInvalidHostPort},
+		"ipv6 localhost":      {value: "[::1]", expect: "[::1]:11434"},
+		"ipv6 world open":     {value: "[::]", expect: "[::]:11434"},
+		"ipv6 no brackets":    {value: "::1", expect: "[::1]:11434"},
+		"ipv6 + port":         {value: "[::1]:1337", expect: "[::1]:1337"},
+		"extra space":         {value: " 1.2.3.4 ", expect: "1.2.3.4:11434"},
+		"extra quotes":        {value: "\"1.2.3.4\"", expect: "1.2.3.4:11434"},
+		"extra space+quotes":  {value: " \" 1.2.3.4 \" ", expect: "1.2.3.4:11434"},
+		"extra single quotes": {value: "'1.2.3.4'", expect: "1.2.3.4:11434"},
+	}
+
+	for k, v := range hostTestCases {
+		t.Run(k, func(t *testing.T) {
+			t.Setenv("OLLAMA_HOST", v.value)
+
+			oh, err := GetOllamaHost()
+			if err != v.err {
+				t.Fatalf("expected %s, got %s", v.err, err)
+			}
+
+			if err == nil {
+				host := net.JoinHostPort(oh.Host, oh.Port)
+				assert.Equal(t, v.expect, host, fmt.Sprintf("%s: expected %s, got %s", k, v.expect, host))
+			}
+		})
+	}
 }
 }

+ 1 - 0
api/types.go

@@ -309,6 +309,7 @@ func (m *Metrics) Summary() {
 }
 }
 
 
 var ErrInvalidOpts = errors.New("invalid options")
 var ErrInvalidOpts = errors.New("invalid options")
+var ErrInvalidHostPort = errors.New("invalid port specified in OLLAMA_HOST")
 
 
 func (opts *Options) FromMap(m map[string]interface{}) error {
 func (opts *Options) FromMap(m map[string]interface{}) error {
 	valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
 	valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct

+ 4 - 6
cmd/cmd.go

@@ -831,19 +831,17 @@ func generate(cmd *cobra.Command, opts runOptions) error {
 }
 }
 
 
 func RunServer(cmd *cobra.Command, _ []string) error {
 func RunServer(cmd *cobra.Command, _ []string) error {
-	host, port, err := net.SplitHostPort(strings.Trim(os.Getenv("OLLAMA_HOST"), "\"'"))
+	// retrieve the OLLAMA_HOST environment variable
+	ollamaHost, err := api.GetOllamaHost()
 	if err != nil {
 	if err != nil {
-		host, port = "127.0.0.1", "11434"
-		if ip := net.ParseIP(strings.Trim(os.Getenv("OLLAMA_HOST"), "[]")); ip != nil {
-			host = ip.String()
-		}
+		return err
 	}
 	}
 
 
 	if err := initializeKeypair(); err != nil {
 	if err := initializeKeypair(); err != nil {
 		return err
 		return err
 	}
 	}
 
 
-	ln, err := net.Listen("tcp", net.JoinHostPort(host, port))
+	ln, err := net.Listen("tcp", net.JoinHostPort(ollamaHost.Host, ollamaHost.Port))
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}