Forráskód Böngészése

add allowed host middleware and remove `workDir` middleware (#3018)

Jeffrey Morgan 1 éve
szülő
commit
fc8c044584
2 módosított fájl, 61 hozzáadás és 26 törlés
  1. 60 17
      server/routes.go
  2. 1 9
      server/routes_test.go

+ 60 - 17
server/routes.go

@@ -10,6 +10,7 @@ import (
 	"log/slog"
 	"log/slog"
 	"net"
 	"net"
 	"net/http"
 	"net/http"
+	"net/netip"
 	"os"
 	"os"
 	"os/signal"
 	"os/signal"
 	"path/filepath"
 	"path/filepath"
@@ -35,7 +36,7 @@ import (
 var mode string = gin.DebugMode
 var mode string = gin.DebugMode
 
 
 type Server struct {
 type Server struct {
-	WorkDir string
+	addr net.Addr
 }
 }
 
 
 func init() {
 func init() {
@@ -904,15 +905,64 @@ var defaultAllowOrigins = []string{
 	"0.0.0.0",
 	"0.0.0.0",
 }
 }
 
 
-func NewServer() (*Server, error) {
-	workDir, err := os.MkdirTemp("", "ollama")
-	if err != nil {
-		return nil, err
+func allowedHost(host string) bool {
+	if host == "" || host == "localhost" {
+		return true
+	}
+
+	if hostname, err := os.Hostname(); err == nil && host == hostname {
+		return true
+	}
+
+	var tlds = []string{
+		".localhost",
+		".local",
+		".internal",
+	}
+
+	for _, tld := range tlds {
+		if strings.HasSuffix(host, "."+tld) {
+			return true
+		}
 	}
 	}
 
 
-	return &Server{
-		WorkDir: workDir,
-	}, nil
+	return false
+}
+
+func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
+	return func(c *gin.Context) {
+		if addr == nil {
+			c.Next()
+			return
+		}
+
+		if !netip.MustParseAddrPort(addr.String()).Addr().IsLoopback() {
+			c.Next()
+			return
+		}
+
+		if addrPort, _ := netip.ParseAddrPort(c.Request.Host); addrPort.Addr().IsLoopback() {
+			c.Next()
+			return
+		}
+
+		if addr, _ := netip.ParseAddr(c.Request.Host); addr.IsLoopback() {
+			c.Next()
+			return
+		}
+
+		host, _, err := net.SplitHostPort(c.Request.Host)
+		if err != nil {
+			host = c.Request.Host
+		}
+
+		if allowedHost(host) {
+			c.Next()
+			return
+		}
+
+		c.AbortWithStatus(http.StatusForbidden)
+	}
 }
 }
 
 
 func (s *Server) GenerateRoutes() http.Handler {
 func (s *Server) GenerateRoutes() http.Handler {
@@ -938,10 +988,7 @@ func (s *Server) GenerateRoutes() http.Handler {
 	r := gin.Default()
 	r := gin.Default()
 	r.Use(
 	r.Use(
 		cors.New(config),
 		cors.New(config),
-		func(c *gin.Context) {
-			c.Set("workDir", s.WorkDir)
-			c.Next()
-		},
+		allowedHostsMiddleware(s.addr),
 	)
 	)
 
 
 	r.POST("/api/pull", PullModelHandler)
 	r.POST("/api/pull", PullModelHandler)
@@ -1010,10 +1057,7 @@ func Serve(ln net.Listener) error {
 		}
 		}
 	}
 	}
 
 
-	s, err := NewServer()
-	if err != nil {
-		return err
-	}
+	s := &Server{addr: ln.Addr()}
 	r := s.GenerateRoutes()
 	r := s.GenerateRoutes()
 
 
 	slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version))
 	slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version))
@@ -1029,7 +1073,6 @@ func Serve(ln net.Listener) error {
 		if loaded.runner != nil {
 		if loaded.runner != nil {
 			loaded.runner.Close()
 			loaded.runner.Close()
 		}
 		}
-		os.RemoveAll(s.WorkDir)
 		os.Exit(0)
 		os.Exit(0)
 	}()
 	}()
 
 

+ 1 - 9
server/routes_test.go

@@ -21,12 +21,6 @@ import (
 	"github.com/jmorganca/ollama/version"
 	"github.com/jmorganca/ollama/version"
 )
 )
 
 
-func setupServer(t *testing.T) (*Server, error) {
-	t.Helper()
-
-	return NewServer()
-}
-
 func Test_Routes(t *testing.T) {
 func Test_Routes(t *testing.T) {
 	type testCase struct {
 	type testCase struct {
 		Name     string
 		Name     string
@@ -207,9 +201,7 @@ func Test_Routes(t *testing.T) {
 		},
 		},
 	}
 	}
 
 
-	s, err := setupServer(t)
-	assert.Nil(t, err)
-
+	s := Server{}
 	router := s.GenerateRoutes()
 	router := s.GenerateRoutes()
 
 
 	httpSrv := httptest.NewServer(router)
 	httpSrv := httptest.NewServer(router)