Browse Source

clean up cli flags

Jeffrey Morgan 1 year ago
parent
commit
040a5b9750
3 changed files with 8 additions and 10 deletions
  1. 4 4
      cmd/cmd.go
  2. 1 1
      cmd/cmd_test.go
  3. 3 5
      server/routes.go

+ 4 - 4
cmd/cmd.go

@@ -538,7 +538,7 @@ func getRunServerParams(cmd *cobra.Command) (host, port string, extraOrigins []s
 	if portFlag.Changed || port == "" {
 		port = portFlag.Value.String()
 	}
-	extraOrigins, err = cmd.Flags().GetStringSlice("allowed-origins")
+	extraOrigins, err = cmd.Flags().GetStringSlice("origins")
 	if err != nil {
 		return "", "", nil, err
 	}
@@ -546,7 +546,7 @@ func getRunServerParams(cmd *cobra.Command) (host, port string, extraOrigins []s
 }
 
 func RunServer(cmd *cobra.Command, _ []string) error {
-	host, port, extraOrigins, err := getRunServerParams(cmd)
+	host, port, origins, err := getRunServerParams(cmd)
 	if err != nil {
 		return err
 	}
@@ -556,7 +556,7 @@ func RunServer(cmd *cobra.Command, _ []string) error {
 		return err
 	}
 
-	return server.Serve(ln, extraOrigins)
+	return server.Serve(ln, origins)
 }
 
 func startMacApp(client *api.Client) error {
@@ -650,7 +650,7 @@ func NewCLI() *cobra.Command {
 
 	serveCmd.Flags().String("port", "11434", "Port to listen on, may also use OLLAMA_PORT environment variable")
 	serveCmd.Flags().String("host", "127.0.0.1", "Host listen address, may also use OLLAMA_HOST environment variable")
-	serveCmd.Flags().StringSlice("allowed-origins", []string{}, "Additional allowed CORS origins (outside of localhost), specify as comma-separated list")
+	serveCmd.Flags().StringSlice("origins", nil, "Additional allowed CORS origins as comma-separated list")
 
 	pullCmd := &cobra.Command{
 		Use:     "pull MODEL",

+ 1 - 1
cmd/cmd_test.go

@@ -84,7 +84,7 @@ func TestGetRunServerParams(t *testing.T) {
 		// now set command flags
 		serveCmd.Flags().Set("host", "localhost")
 		serveCmd.Flags().Set("port", "8888")
-		serveCmd.Flags().Set("allowed-origins", "http://foo.example.com,http://192.168.1.1")
+		serveCmd.Flags().Set("origins", "http://foo.example.com,http://192.168.1.1")
 
 		host, port, extraOrigins, err := getRunServerParams(serveCmd)
 		if err != nil {

+ 3 - 5
server/routes.go

@@ -391,10 +391,10 @@ func CopyModelHandler(c *gin.Context) {
 	}
 }
 
-func Serve(ln net.Listener, extraOrigins []string) error {
+func Serve(ln net.Listener, origins []string) error {
 	config := cors.DefaultConfig()
 	config.AllowWildcard = true
-	allowedOrigins := []string{
+	config.AllowOrigins = append(origins, []string{
 		"http://localhost",
 		"http://localhost:*",
 		"https://localhost",
@@ -407,9 +407,7 @@ func Serve(ln net.Listener, extraOrigins []string) error {
 		"http://0.0.0.0:*",
 		"https://0.0.0.0",
 		"https://0.0.0.0:*",
-	}
-	allowedOrigins = append(allowedOrigins, extraOrigins...)
-	config.AllowOrigins = allowedOrigins
+	}...)
 
 	r := gin.Default()
 	r.Use(cors.New(config))