|
@@ -10,6 +10,7 @@ import (
|
|
|
"log/slog"
|
|
|
"net"
|
|
|
"net/http"
|
|
|
+ "net/netip"
|
|
|
"os"
|
|
|
"os/signal"
|
|
|
"path/filepath"
|
|
@@ -35,7 +36,7 @@ import (
|
|
|
var mode string = gin.DebugMode
|
|
|
|
|
|
type Server struct {
|
|
|
- WorkDir string
|
|
|
+ addr net.Addr
|
|
|
}
|
|
|
|
|
|
func init() {
|
|
@@ -904,15 +905,64 @@ var defaultAllowOrigins = []string{
|
|
|
"0.0.0.0",
|
|
|
}
|
|
|
|
|
|
-func NewServer() (*Server, error) {
|
|
|
- workDir, err := os.MkdirTemp("", "ollama")
|
|
|
- if err != nil {
|
|
|
- return nil, err
|
|
|
+func allowedHost(host string) bool {
|
|
|
+ if host == "" || host == "localhost" {
|
|
|
+ return true
|
|
|
+ }
|
|
|
+
|
|
|
+ if hostname, err := os.Hostname(); err == nil && host == hostname {
|
|
|
+ return true
|
|
|
+ }
|
|
|
+
|
|
|
+ var tlds = []string{
|
|
|
+ ".localhost",
|
|
|
+ ".local",
|
|
|
+ ".internal",
|
|
|
+ }
|
|
|
+
|
|
|
+ for _, tld := range tlds {
|
|
|
+ if strings.HasSuffix(host, "."+tld) {
|
|
|
+ return true
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
- return &Server{
|
|
|
- WorkDir: workDir,
|
|
|
- }, nil
|
|
|
+ return false
|
|
|
+}
|
|
|
+
|
|
|
+func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
|
|
|
+ return func(c *gin.Context) {
|
|
|
+ if addr == nil {
|
|
|
+ c.Next()
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ if !netip.MustParseAddrPort(addr.String()).Addr().IsLoopback() {
|
|
|
+ c.Next()
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ if addrPort, _ := netip.ParseAddrPort(c.Request.Host); addrPort.Addr().IsLoopback() {
|
|
|
+ c.Next()
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ if addr, _ := netip.ParseAddr(c.Request.Host); addr.IsLoopback() {
|
|
|
+ c.Next()
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ host, _, err := net.SplitHostPort(c.Request.Host)
|
|
|
+ if err != nil {
|
|
|
+ host = c.Request.Host
|
|
|
+ }
|
|
|
+
|
|
|
+ if allowedHost(host) {
|
|
|
+ c.Next()
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ c.AbortWithStatus(http.StatusForbidden)
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
func (s *Server) GenerateRoutes() http.Handler {
|
|
@@ -938,10 +988,7 @@ func (s *Server) GenerateRoutes() http.Handler {
|
|
|
r := gin.Default()
|
|
|
r.Use(
|
|
|
cors.New(config),
|
|
|
- func(c *gin.Context) {
|
|
|
- c.Set("workDir", s.WorkDir)
|
|
|
- c.Next()
|
|
|
- },
|
|
|
+ allowedHostsMiddleware(s.addr),
|
|
|
)
|
|
|
|
|
|
r.POST("/api/pull", PullModelHandler)
|
|
@@ -1010,10 +1057,7 @@ func Serve(ln net.Listener) error {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- s, err := NewServer()
|
|
|
- if err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
+ s := &Server{addr: ln.Addr()}
|
|
|
r := s.GenerateRoutes()
|
|
|
|
|
|
slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version))
|
|
@@ -1029,7 +1073,6 @@ func Serve(ln net.Listener) error {
|
|
|
if loaded.runner != nil {
|
|
|
loaded.runner.Close()
|
|
|
}
|
|
|
- os.RemoveAll(s.WorkDir)
|
|
|
os.Exit(0)
|
|
|
}()
|
|
|
|