Browse Source

correct precedence of serve params (args over env over default)

cmiller01 1 year ago
parent
commit
93492f1e18
2 changed files with 130 additions and 16 deletions
  1. 27 16
      cmd/cmd.go
  2. 103 0
      cmd/cmd_test.go

+ 27 - 16
cmd/cmd.go

@@ -513,28 +513,39 @@ func generateBatch(cmd *cobra.Command, model string) error {
 	return nil
 }
 
-func RunServer(cmd *cobra.Command, _ []string) error {
-	host, err := cmd.Flags().GetString("host")
-	if err != nil {
-		return errors.New("host unset")
-	}
-	if os.Getenv("OLLAMA_HOST") != "" {
-		host = os.Getenv("OLLAMA_HOST")
-	}
-	port, err := cmd.Flags().GetString("port")
+// getRunServerParams takes a command and the environment variables and returns the correct params
+// given the order of precedence: command line args (highest), environment variables, defaults (lowest)
+func getRunServerParams(cmd *cobra.Command) (host, port string, extraOrigins []string, err error) {
+	host = os.Getenv("OLLAMA_HOST")
+	hostFlag := cmd.Flags().Lookup("host")
+	if hostFlag == nil {
+		return "", "", nil, errors.New("host unset")
+	}
+	if hostFlag.Changed || host == "" {
+		host = hostFlag.Value.String()
+	}
+	port = os.Getenv("OLLAMA_PORT")
+	portFlag := cmd.Flags().Lookup("port")
+	if portFlag == nil {
+		return "", "", nil, errors.New("port unset")
+	}
+	if portFlag.Changed || port == "" {
+		port = portFlag.Value.String()
+	}
+	extraOrigins, err = cmd.Flags().GetStringSlice("allowed-origins")
 	if err != nil {
-		return errors.New("port unset")
-	}
-
-	if os.Getenv("OLLAMA_PORT") != "" {
-		port = os.Getenv("OLLAMA_PORT")
+		return "", "", nil, err
 	}
+	return host, port, extraOrigins, nil
+}
 
-	ln, err := net.Listen("tcp", fmt.Sprintf("%s:%s", host, port))
+func RunServer(cmd *cobra.Command, _ []string) error {
+	host, port, extraOrigins, err := getRunServerParams(cmd)
 	if err != nil {
 		return err
 	}
-	extraOrigins, err := cmd.Flags().GetStringSlice("allowed-origins")
+
+	ln, err := net.Listen("tcp", fmt.Sprintf("%s:%s", host, port))
 	if err != nil {
 		return err
 	}

+ 103 - 0
cmd/cmd_test.go

@@ -0,0 +1,103 @@
+package cmd
+
+import (
+	"os"
+	"testing"
+)
+
+func TestGetRunServerParams(t *testing.T) {
+	t.Run("default values", func(t *testing.T) {
+		cmd := NewCLI()
+		serveCmd, _, err := cmd.Find([]string{"serve"})
+		if err != nil {
+			t.Errorf("expected serve command, got %s", err)
+		}
+		host, port, extraOrigins, err := getRunServerParams(serveCmd)
+		// assertions
+		if err != nil {
+			t.Errorf("unexpected error, got %s", err)
+		}
+		if host != "127.0.0.1" {
+			t.Errorf("unexpected host, got %s", host)
+		}
+		if port != "11434" {
+			t.Errorf("unexpected port, got %s", port)
+		}
+		if len(extraOrigins) != 0 {
+			t.Errorf("unexpected origins, got %s", extraOrigins)
+		}
+	})
+	t.Run("environment variables take precedence over default", func(t *testing.T) {
+		cmd := NewCLI()
+		serveCmd, _, err := cmd.Find([]string{"serve"})
+		if err != nil {
+			t.Errorf("expected serve command, got %s", err)
+		}
+		// setup environment variables
+		err = os.Setenv("OLLAMA_HOST", "0.0.0.0")
+		if err != nil {
+			t.Errorf("could not set env var")
+		}
+		err = os.Setenv("OLLAMA_PORT", "9999")
+		if err != nil {
+			t.Errorf("could not set env var")
+		}
+		defer func() {
+			os.Unsetenv("OLLAMA_HOST")
+			os.Unsetenv("OLLAMA_PORT")
+		}()
+
+		host, port, extraOrigins, err := getRunServerParams(serveCmd)
+		// assertions
+		if err != nil {
+			t.Errorf("unexpected error, got %s", err)
+		}
+		if host != "0.0.0.0" {
+			t.Errorf("unexpected host, got %s", host)
+		}
+		if port != "9999" {
+			t.Errorf("unexpected port, got %s", port)
+		}
+		if len(extraOrigins) != 0 {
+			t.Errorf("unexpected origins, got %s", extraOrigins)
+		}
+	})
+	t.Run("command line args take precedence over env vars", func(t *testing.T) {
+		cmd := NewCLI()
+		serveCmd, _, err := cmd.Find([]string{"serve"})
+		if err != nil {
+			t.Errorf("expected serve command, got %s", err)
+		}
+		// setup environment variables
+		err = os.Setenv("OLLAMA_HOST", "0.0.0.0")
+		if err != nil {
+			t.Errorf("could not set env var")
+		}
+		err = os.Setenv("OLLAMA_PORT", "9999")
+		if err != nil {
+			t.Errorf("could not set env var")
+		}
+		defer func() {
+			os.Unsetenv("OLLAMA_HOST")
+			os.Unsetenv("OLLAMA_PORT")
+		}()
+		// now set command flags
+		serveCmd.Flags().Set("host", "localhost")
+		serveCmd.Flags().Set("port", "8888")
+		serveCmd.Flags().Set("allowed-origins", "http://foo.example.com,http://192.168.1.1")
+
+		host, port, extraOrigins, err := getRunServerParams(serveCmd)
+		if err != nil {
+			t.Errorf("unexpected error, got %s", err)
+		}
+		if host != "localhost" {
+			t.Errorf("unexpected host, got %s", host)
+		}
+		if port != "8888" {
+			t.Errorf("unexpected port, got %s", port)
+		}
+		if len(extraOrigins) != 2 {
+			t.Errorf("expected two origins, got length %d", len(extraOrigins))
+		}
+	})
+}