浏览代码

zip: prevent extracting files into parent dirs (#5314)

Michael Yang 10 月之前
父节点
当前提交
123a722a6f
共有 3 个文件被更改,包括 133 次插入22 次删除
  1. 3 3
      cmd/cmd.go
  2. 38 19
      server/model.go
  3. 92 0
      server/model_test.go

+ 3 - 3
cmd/cmd.go

@@ -162,9 +162,6 @@ func tempZipFiles(path string) (string, error) {
 	}
 	defer tempfile.Close()
 
-	zipfile := zip.NewWriter(tempfile)
-	defer zipfile.Close()
-
 	detectContentType := func(path string) (string, error) {
 		f, err := os.Open(path)
 		if err != nil {
@@ -233,6 +230,9 @@ func tempZipFiles(path string) (string, error) {
 		files = append(files, tks...)
 	}
 
+	zipfile := zip.NewWriter(tempfile)
+	defer zipfile.Close()
+
 	for _, file := range files {
 		f, err := os.Open(file)
 		if err != nil {

+ 38 - 19
server/model.go

@@ -11,6 +11,7 @@ import (
 	"net/http"
 	"os"
 	"path/filepath"
+	"strings"
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/convert"
@@ -77,62 +78,80 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
 	return layers, nil
 }
 
-func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) {
+func extractFromZipFile(p string, file *os.File, fn func(api.ProgressResponse)) error {
 	stat, err := file.Stat()
 	if err != nil {
-		return nil, err
+		return err
 	}
 
 	r, err := zip.NewReader(file, stat.Size())
 	if err != nil {
-		return nil, err
-	}
-
-	tempdir, err := os.MkdirTemp(filepath.Dir(file.Name()), "")
-	if err != nil {
-		return nil, err
+		return err
 	}
-	defer os.RemoveAll(tempdir)
 
 	fn(api.ProgressResponse{Status: "unpacking model metadata"})
 	for _, f := range r.File {
+		n := filepath.Join(p, f.Name)
+		if !strings.HasPrefix(n, p) {
+			slog.Warn("skipped extracting file outside of context", "name", f.Name)
+			continue
+		}
+
+		if err := os.MkdirAll(filepath.Dir(n), 0o750); err != nil {
+			return err
+		}
+
 		// TODO(mxyng): this should not write out all files to disk
-		outfile, err := os.Create(filepath.Join(tempdir, f.Name))
+		outfile, err := os.Create(n)
 		if err != nil {
-			return nil, err
+			return err
 		}
 		defer outfile.Close()
 
 		infile, err := f.Open()
 		if err != nil {
-			return nil, err
+			return err
 		}
 		defer infile.Close()
 
 		if _, err = io.Copy(outfile, infile); err != nil {
-			return nil, err
+			return err
 		}
 
 		if err := outfile.Close(); err != nil {
-			return nil, err
+			return err
 		}
 
 		if err := infile.Close(); err != nil {
-			return nil, err
+			return err
 		}
 	}
 
-	mf, err := convert.GetModelFormat(tempdir)
+	return nil
+}
+
+func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) {
+	tempDir, err := os.MkdirTemp(filepath.Dir(file.Name()), "")
+	if err != nil {
+		return nil, err
+	}
+	defer os.RemoveAll(tempDir)
+
+	if err := extractFromZipFile(tempDir, file, fn); err != nil {
+		return nil, err
+	}
+
+	mf, err := convert.GetModelFormat(tempDir)
 	if err != nil {
 		return nil, err
 	}
 
-	params, err := mf.GetParams(tempdir)
+	params, err := mf.GetParams(tempDir)
 	if err != nil {
 		return nil, err
 	}
 
-	mArch, err := mf.GetModelArch("", tempdir, params)
+	mArch, err := mf.GetModelArch("", tempDir, params)
 	if err != nil {
 		return nil, err
 	}
@@ -150,7 +169,7 @@ func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(a
 
 	// TODO(mxyng): this should write directly into a layer
 	// e.g. NewLayer(arch.Reader(), "application/vnd.ollama.image.model")
-	temp, err := os.CreateTemp(tempdir, "fp16")
+	temp, err := os.CreateTemp(tempDir, "fp16")
 	if err != nil {
 		return nil, err
 	}

+ 92 - 0
server/model_test.go

@@ -0,0 +1,92 @@
+package server
+
+import (
+	"archive/zip"
+	"bytes"
+	"io"
+	"os"
+	"path/filepath"
+	"slices"
+	"testing"
+
+	"github.com/ollama/ollama/api"
+)
+
+func createZipFile(t *testing.T, name string) *os.File {
+	t.Helper()
+
+	f, err := os.CreateTemp(t.TempDir(), "")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	zf := zip.NewWriter(f)
+	defer zf.Close()
+
+	zh, err := zf.CreateHeader(&zip.FileHeader{Name: name})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if _, err := io.Copy(zh, bytes.NewReader([]byte(""))); err != nil {
+		t.Fatal(err)
+	}
+
+	return f
+}
+
+func TestExtractFromZipFile(t *testing.T) {
+	cases := []struct {
+		name   string
+		expect []string
+	}{
+		{
+			name:   "good",
+			expect: []string{"good"},
+		},
+		{
+			name: filepath.Join("..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "bad"),
+		},
+	}
+
+	for _, tt := range cases {
+		t.Run(tt.name, func(t *testing.T) {
+			f := createZipFile(t, tt.name)
+			defer f.Close()
+
+			tempDir := t.TempDir()
+			if err := extractFromZipFile(tempDir, f, func(api.ProgressResponse) {}); err != nil {
+				t.Fatal(err)
+			}
+
+			var matches []string
+			if err := filepath.Walk(tempDir, func(p string, fi os.FileInfo, err error) error {
+				if err != nil {
+					return err
+				}
+
+				if !fi.IsDir() {
+					matches = append(matches, p)
+				}
+
+				return nil
+			}); err != nil {
+				t.Fatal(err)
+			}
+
+			var actual []string
+			for _, match := range matches {
+				rel, err := filepath.Rel(tempDir, match)
+				if err != nil {
+					t.Error(err)
+				}
+
+				actual = append(actual, rel)
+			}
+
+			if !slices.Equal(actual, tt.expect) {
+				t.Fatalf("expected %d files, got %d", len(tt.expect), len(matches))
+			}
+		})
+	}
+}