Michael Yang 10 months ago
parent
commit
d1a5227cad
3 changed files with 119 additions and 30 deletions
  1. 25 27
      envconfig/config.go
  2. 93 2
      envconfig/config_test.go
  3. 1 1
      server/routes.go

+ 25 - 27
envconfig/config.go

@@ -75,9 +75,31 @@ func Host() *url.URL {
 	}
 }
 
+// Origins returns a list of allowed origins. Origins can be configured via the OLLAMA_ORIGINS environment variable.
+func Origins() (origins []string) {
+	if s := clean("OLLAMA_ORIGINS"); s != "" {
+		origins = strings.Split(s, ",")
+	}
+
+	for _, origin := range []string{"localhost", "127.0.0.1", "0.0.0.0"} {
+		origins = append(origins,
+			fmt.Sprintf("http://%s", origin),
+			fmt.Sprintf("https://%s", origin),
+			fmt.Sprintf("http://%s", net.JoinHostPort(origin, "*")),
+			fmt.Sprintf("https://%s", net.JoinHostPort(origin, "*")),
+		)
+	}
+
+	origins = append(origins,
+		"app://*",
+		"file://*",
+		"tauri://*",
+	)
+
+	return origins
+}
+
 var (
-	// Set via OLLAMA_ORIGINS in the environment
-	AllowOrigins []string
 	// Experimental flash attention
 	FlashAttention bool
 	// Set via OLLAMA_KEEP_ALIVE in the environment
@@ -136,7 +158,7 @@ func AsMap() map[string]EnvVar {
 		"OLLAMA_NOHISTORY":         {"OLLAMA_NOHISTORY", NoHistory, "Do not preserve readline history"},
 		"OLLAMA_NOPRUNE":           {"OLLAMA_NOPRUNE", NoPrune, "Do not prune model blobs on startup"},
 		"OLLAMA_NUM_PARALLEL":      {"OLLAMA_NUM_PARALLEL", NumParallel, "Maximum number of parallel requests"},
-		"OLLAMA_ORIGINS":           {"OLLAMA_ORIGINS", AllowOrigins, "A comma separated list of allowed origins"},
+		"OLLAMA_ORIGINS":           {"OLLAMA_ORIGINS", Origins(), "A comma separated list of allowed origins"},
 		"OLLAMA_RUNNERS_DIR":       {"OLLAMA_RUNNERS_DIR", RunnersDir, "Location for runners"},
 		"OLLAMA_SCHED_SPREAD":      {"OLLAMA_SCHED_SPREAD", SchedSpread, "Always schedule model across all GPUs"},
 		"OLLAMA_TMPDIR":            {"OLLAMA_TMPDIR", TmpDir, "Location for temporary files"},
@@ -160,12 +182,6 @@ func Values() map[string]string {
 	return vals
 }
 
-var defaultAllowOrigins = []string{
-	"localhost",
-	"127.0.0.1",
-	"0.0.0.0",
-}
-
 // Clean quotes and spaces from the value
 func clean(key string) string {
 	return strings.Trim(os.Getenv(key), "\"' ")
@@ -255,24 +271,6 @@ func LoadConfig() {
 		NoPrune = true
 	}
 
-	if origins := clean("OLLAMA_ORIGINS"); origins != "" {
-		AllowOrigins = strings.Split(origins, ",")
-	}
-	for _, allowOrigin := range defaultAllowOrigins {
-		AllowOrigins = append(AllowOrigins,
-			fmt.Sprintf("http://%s", allowOrigin),
-			fmt.Sprintf("https://%s", allowOrigin),
-			fmt.Sprintf("http://%s", net.JoinHostPort(allowOrigin, "*")),
-			fmt.Sprintf("https://%s", net.JoinHostPort(allowOrigin, "*")),
-		)
-	}
-
-	AllowOrigins = append(AllowOrigins,
-		"app://*",
-		"file://*",
-		"tauri://*",
-	)
-
 	maxRunners := clean("OLLAMA_MAX_LOADED_MODELS")
 	if maxRunners != "" {
 		m, err := strconv.Atoi(maxRunners)

+ 93 - 2
envconfig/config_test.go

@@ -5,10 +5,11 @@ import (
 	"testing"
 	"time"
 
+	"github.com/google/go-cmp/cmp"
 	"github.com/stretchr/testify/require"
 )
 
-func TestConfig(t *testing.T) {
+func TestSmoke(t *testing.T) {
 	t.Setenv("OLLAMA_DEBUG", "")
 	require.False(t, Debug())
 
@@ -38,7 +39,7 @@ func TestConfig(t *testing.T) {
 	require.Equal(t, time.Duration(math.MaxInt64), KeepAlive)
 }
 
-func TestClientFromEnvironment(t *testing.T) {
+func TestHost(t *testing.T) {
 	cases := map[string]struct {
 		value  string
 		expect string
@@ -71,3 +72,93 @@ func TestClientFromEnvironment(t *testing.T) {
 		})
 	}
 }
+
+func TestOrigins(t *testing.T) {
+	cases := []struct {
+		value  string
+		expect []string
+	}{
+		{"", []string{
+			"http://localhost",
+			"https://localhost",
+			"http://localhost:*",
+			"https://localhost:*",
+			"http://127.0.0.1",
+			"https://127.0.0.1",
+			"http://127.0.0.1:*",
+			"https://127.0.0.1:*",
+			"http://0.0.0.0",
+			"https://0.0.0.0",
+			"http://0.0.0.0:*",
+			"https://0.0.0.0:*",
+			"app://*",
+			"file://*",
+			"tauri://*",
+		}},
+		{"http://10.0.0.1", []string{
+			"http://10.0.0.1",
+			"http://localhost",
+			"https://localhost",
+			"http://localhost:*",
+			"https://localhost:*",
+			"http://127.0.0.1",
+			"https://127.0.0.1",
+			"http://127.0.0.1:*",
+			"https://127.0.0.1:*",
+			"http://0.0.0.0",
+			"https://0.0.0.0",
+			"http://0.0.0.0:*",
+			"https://0.0.0.0:*",
+			"app://*",
+			"file://*",
+			"tauri://*",
+		}},
+		{"http://172.16.0.1,https://192.168.0.1", []string{
+			"http://172.16.0.1",
+			"https://192.168.0.1",
+			"http://localhost",
+			"https://localhost",
+			"http://localhost:*",
+			"https://localhost:*",
+			"http://127.0.0.1",
+			"https://127.0.0.1",
+			"http://127.0.0.1:*",
+			"https://127.0.0.1:*",
+			"http://0.0.0.0",
+			"https://0.0.0.0",
+			"http://0.0.0.0:*",
+			"https://0.0.0.0:*",
+			"app://*",
+			"file://*",
+			"tauri://*",
+		}},
+		{"http://totally.safe,http://definitely.legit", []string{
+			"http://totally.safe",
+			"http://definitely.legit",
+			"http://localhost",
+			"https://localhost",
+			"http://localhost:*",
+			"https://localhost:*",
+			"http://127.0.0.1",
+			"https://127.0.0.1",
+			"http://127.0.0.1:*",
+			"https://127.0.0.1:*",
+			"http://0.0.0.0",
+			"https://0.0.0.0",
+			"http://0.0.0.0:*",
+			"https://0.0.0.0:*",
+			"app://*",
+			"file://*",
+			"tauri://*",
+		}},
+	}
+	for _, tt := range cases {
+		t.Run(tt.value, func(t *testing.T) {
+			t.Setenv("OLLAMA_ORIGINS", tt.value)
+
+			if diff := cmp.Diff(Origins(), tt.expect); diff != "" {
+				t.Errorf("%s: mismatch (-want +got):\n%s", tt.value, diff)
+			}
+		})
+	}
+}

+ 1 - 1
server/routes.go

@@ -1048,7 +1048,7 @@ func (s *Server) GenerateRoutes() http.Handler {
 	for _, prop := range openAIProperties {
 		config.AllowHeaders = append(config.AllowHeaders, "x-stainless-"+prop)
 	}
-	config.AllowOrigins = envconfig.AllowOrigins
+	config.AllowOrigins = envconfig.Origins()
 
 	r := gin.Default()
 	r.Use(