
convert: only extract large files

Michael Yang 10 months ago
parent
commit
eafc607abb

+ 5 - 6
convert/convert.go

@@ -5,9 +5,8 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"io/fs"
 	"log/slog"
-	"os"
-	"path/filepath"
 
 	"github.com/ollama/ollama/llm"
 )
@@ -67,8 +66,8 @@ type Converter interface {
 // and files it finds in the input path.
 // Supported input model formats include safetensors.
 // Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
-func Convert(path string, ws io.WriteSeeker) error {
-	bts, err := os.ReadFile(filepath.Join(path, "config.json"))
+func Convert(fsys fs.FS, ws io.WriteSeeker) error {
+	bts, err := fs.ReadFile(fsys, "config.json")
 	if err != nil {
 		return err
 	}
@@ -98,7 +97,7 @@ func Convert(path string, ws io.WriteSeeker) error {
 		return err
 	}
 
-	t, err := parseTokenizer(path, conv.specialTokenTypes())
+	t, err := parseTokenizer(fsys, conv.specialTokenTypes())
 	if err != nil {
 		return err
 	}
@@ -114,7 +113,7 @@ func Convert(path string, ws io.WriteSeeker) error {
 		slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
 	}
 
-	ts, err := parseTensors(path)
+	ts, err := parseTensors(fsys)
 	if err != nil {
 		return err
 	}
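
Note: Convert now accepts any fs.FS instead of a directory path, so the caller decides where the model files come from. A minimal sketch of driving it against a directory on disk (paths are illustrative, imports elided):

	out, err := os.Create("model-f16.bin") // destination for the converted model; name is illustrative
	if err != nil {
		panic(err)
	}
	defer out.Close()

	// os.DirFS adapts a directory to fs.FS; a zip archive works too,
	// via the NewZipReader added in convert/fs.go below
	if err := convert.Convert(os.DirFS("/path/to/model"), out); err != nil {
		panic(err)
	}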

+ 4 - 3
convert/convert_test.go

@@ -6,6 +6,7 @@ import (
 	"flag"
 	"fmt"
 	"io"
+	"io/fs"
 	"log/slog"
 	"math"
 	"os"
@@ -17,7 +18,7 @@ import (
 	"golang.org/x/exp/maps"
 )
 
-func convertFull(t *testing.T, d string) (*os.File, llm.KV, llm.Tensors) {
+func convertFull(t *testing.T, fsys fs.FS) (*os.File, llm.KV, llm.Tensors) {
 	t.Helper()
 
 	f, err := os.CreateTemp(t.TempDir(), "f16")
@@ -26,7 +27,7 @@ func convertFull(t *testing.T, d string) (*os.File, llm.KV, llm.Tensors) {
 	}
 	defer f.Close()
 
-	if err := Convert(d, f); err != nil {
+	if err := Convert(fsys, f); err != nil {
 		t.Fatal(err)
 	}
 
@@ -76,7 +77,7 @@ func TestConvertFull(t *testing.T) {
 				t.Skipf("%s not found", p)
 			}
 
-			f, kv, tensors := convertFull(t, p)
+			f, kv, tensors := convertFull(t, os.DirFS(p))
 			actual := make(map[string]string)
 			for k, v := range kv {
 				if s, ok := v.(json.Marshaler); !ok {

+ 58 - 0
convert/fs.go

@@ -0,0 +1,58 @@
+package convert
+
+import (
+	"archive/zip"
+	"errors"
+	"io"
+	"io/fs"
+	"os"
+	"path/filepath"
+)
+
+type ZipReader struct {
+	r     *zip.Reader
+	p     string
+
+	// limit is the maximum size of a file that can be read directly
+	// from the zip archive. Files larger than this size will be extracted to disk
+	limit int64
+}
+
+func NewZipReader(r *zip.Reader, p string, limit int64) fs.FS {
+	return &ZipReader{r, p, limit}
+}
+
+func (z *ZipReader) Open(name string) (fs.File, error) {
+	r, err := z.r.Open(name)
+	if err != nil {
+		return nil, err
+	}
+	defer r.Close()
+
+	if fi, err := r.Stat(); err != nil {
+		return nil, err
+	} else if fi.Size() < z.limit {
+		return r, nil
+	}
+
+	if !filepath.IsLocal(name) {
+		return nil, zip.ErrInsecurePath
+	}
+
+	n := filepath.Join(z.p, name)
+	if _, err := os.Stat(n); errors.Is(err, os.ErrNotExist) {
+		w, err := os.Create(n)
+		if err != nil {
+			return nil, err
+		}
+		defer w.Close()
+
+		if _, err := io.Copy(w, r); err != nil {
+			return nil, err
+		}
+	} else if err != nil {
+		return nil, err
+	}
+
+	return os.Open(n)
+}
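
Note: ZipReader serves entries smaller than limit straight from the archive and extracts anything at or above it into the scratch directory p exactly once (the os.Stat check skips re-extraction), then reopens the copy from disk. A rough usage sketch from within the convert package, errors and imports elided for brevity:

	f, _ := os.Open("model.zip") // path is illustrative
	fi, _ := f.Stat()
	zr, _ := zip.NewReader(f, fi.Size())

	// small files (config.json, tokenizer.json) stream from the zip;
	// entries of 32 MiB or more are written under scratch/ and served from disk
	fsys := NewZipReader(zr, "scratch", 32<<20)
	bts, _ := fs.ReadFile(fsys, "config.json")
	_ = bts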

+ 5 - 5
convert/reader.go

@@ -3,7 +3,7 @@ package convert
 import (
 	"errors"
 	"io"
-	"path/filepath"
+	"io/fs"
 	"strings"
 )
 
@@ -55,8 +55,8 @@ func (t *tensorBase) SetRepacker(fn repacker) {
 
 type repacker func(string, []float32, []uint64) ([]float32, error)
 
-func parseTensors(d string) ([]Tensor, error) {
-	patterns := map[string]func(...string) ([]Tensor, error){
+func parseTensors(fsys fs.FS) ([]Tensor, error) {
+	patterns := map[string]func(fs.FS, ...string) ([]Tensor, error){
 		"model-*-of-*.safetensors": parseSafetensors,
 		"model.safetensors":        parseSafetensors,
 		"pytorch_model-*-of-*.bin": parseTorch,
@@ -65,13 +65,13 @@ func parseTensors(d string) ([]Tensor, error) {
 	}
 
 	for pattern, parseFn := range patterns {
-		matches, err := filepath.Glob(filepath.Join(d, pattern))
+		matches, err := fs.Glob(fsys, pattern)
 		if err != nil {
 			return nil, err
 		}
 
 		if len(matches) > 0 {
-			return parseFn(matches...)
+			return parseFn(fsys, matches...)
 		}
 	}
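
Note: fs.Glob matches slash-separated names relative to the fs.FS root, so the same patterns cover both os.DirFS directories and the zip-backed filesystem. A quick illustration using the standard library's in-memory testing/fstest filesystem (contents are dummies):

	fsys := fstest.MapFS{
		"model-00001-of-00002.safetensors": &fstest.MapFile{Data: []byte("...")},
		"model-00002-of-00002.safetensors": &fstest.MapFile{Data: []byte("...")},
	}

	matches, _ := fs.Glob(fsys, "model-*-of-*.safetensors")
	// matches: [model-00001-of-00002.safetensors model-00002-of-00002.safetensors]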
 

+ 14 - 6
convert/reader_safetensors.go

@@ -6,7 +6,7 @@ import (
 	"encoding/json"
 	"fmt"
 	"io"
-	"os"
+	"io/fs"
 	"slices"
 
 	"github.com/d4l3k/go-bfloat16"
@@ -20,10 +20,10 @@ type safetensorMetadata struct {
 	Offsets []int64  `json:"data_offsets"`
 }
 
-func parseSafetensors(ps ...string) ([]Tensor, error) {
+func parseSafetensors(fsys fs.FS, ps ...string) ([]Tensor, error) {
 	var ts []Tensor
 	for _, p := range ps {
-		f, err := os.Open(p)
+		f, err := fsys.Open(p)
 		if err != nil {
 			return nil, err
 		}
@@ -50,6 +50,7 @@ func parseSafetensors(ps ...string) ([]Tensor, error) {
 		for _, key := range keys {
 			if value := headers[key]; value.Type != "" {
 				ts = append(ts, safetensor{
+					fs:     fsys,
 					path:   p,
 					dtype:  value.Type,
 					offset: safetensorsPad(n, value.Offsets[0]),
@@ -72,6 +73,7 @@ func safetensorsPad(n, offset int64) int64 {
 }
 
 type safetensor struct {
+	fs     fs.FS
 	path   string
 	dtype  string
 	offset int64
@@ -80,14 +82,20 @@ type safetensor struct {
 }
 
 func (st safetensor) WriteTo(w io.Writer) (int64, error) {
-	f, err := os.Open(st.path)
+	f, err := st.fs.Open(st.path)
 	if err != nil {
 		return 0, err
 	}
 	defer f.Close()
 
-	if _, err = f.Seek(st.offset, io.SeekStart); err != nil {
-		return 0, err
+	if seeker, ok := f.(io.Seeker); ok {
+		if _, err := seeker.Seek(st.offset, io.SeekStart); err != nil {
+			return 0, err
+		}
+	} else {
+		if _, err := io.CopyN(io.Discard, f, st.offset); err != nil {
+			return 0, err
+		}
 	}
 
 	var f32s []float32
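
Note: fs.File only guarantees Read, Stat, and Close, so the seek is feature-detected here: an *os.File (or an entry ZipReader extracted to disk) seeks directly, while a streaming zip entry reads and discards st.offset bytes instead. The same pattern as a standalone helper, a sketch rather than code from this commit:

	// skipTo positions r at offset: it seeks when the reader supports it
	// and otherwise consumes offset bytes from the stream
	func skipTo(r io.Reader, offset int64) error {
		if s, ok := r.(io.Seeker); ok {
			_, err := s.Seek(offset, io.SeekStart)
			return err
		}
		_, err := io.CopyN(io.Discard, r, offset)
		return err
	}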

+ 2 - 1
convert/reader_torch.go

@@ -2,12 +2,13 @@ package convert
 
 import (
 	"io"
+	"io/fs"
 
 	"github.com/nlpodyssey/gopickle/pytorch"
 	"github.com/nlpodyssey/gopickle/types"
 )
 
-func parseTorch(ps ...string) ([]Tensor, error) {
+func parseTorch(fsys fs.FS, ps ...string) ([]Tensor, error) {
 	var ts []Tensor
 	for _, p := range ps {
 		pt, err := pytorch.Load(p)
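
Note: in this hunk the new fsys parameter is only threaded through to keep the parser signatures uniform with parseSafetensors; gopickle's pytorch.Load still takes an on-disk path, so torch checkpoints must exist as real files.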

+ 11 - 11
convert/tokenizer.go

@@ -7,9 +7,9 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"io/fs"
 	"log/slog"
 	"os"
-	"path/filepath"
 	"slices"
 )
 
@@ -32,8 +32,8 @@ type Tokenizer struct {
 	Template string
 }
 
-func parseTokenizer(d string, specialTokenTypes []string) (*Tokenizer, error) {
-	v, err := parseVocabulary(d)
+func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error) {
+	v, err := parseVocabulary(fsys)
 	if err != nil {
 		return nil, err
 	}
@@ -44,7 +44,7 @@ func parseTokenizer(d string, specialTokenTypes []string) (*Tokenizer, error) {
 	}
 
 	addedTokens := make(map[string]token)
-	if f, err := os.Open(filepath.Join(d, "tokenizer.json")); errors.Is(err, os.ErrNotExist) {
+	if f, err := fsys.Open("tokenizer.json"); errors.Is(err, os.ErrNotExist) {
 	} else if err != nil {
 		return nil, err
 	} else {
@@ -87,7 +87,7 @@ func parseTokenizer(d string, specialTokenTypes []string) (*Tokenizer, error) {
 		}
 	}
 
-	if f, err := os.Open(filepath.Join(d, "tokenizer_config.json")); errors.Is(err, os.ErrNotExist) {
+	if f, err := fsys.Open("tokenizer_config.json"); errors.Is(err, os.ErrNotExist) {
 	} else if err != nil {
 		return nil, err
 	} else {
@@ -172,8 +172,8 @@ type Vocabulary struct {
 	Types  []int32
 }
 
-func parseVocabularyFromTokenizer(p string) (*Vocabulary, error) {
-	f, err := os.Open(filepath.Join(p, "tokenizer.json"))
+func parseVocabularyFromTokenizer(fsys fs.FS) (*Vocabulary, error) {
+	f, err := fsys.Open("tokenizer.json")
 	if err != nil {
 		return nil, err
 	}
@@ -219,20 +219,20 @@ func parseVocabularyFromTokenizer(p string) (*Vocabulary, error) {
 	return &v, nil
 }
 
-func parseVocabulary(d string) (*Vocabulary, error) {
-	patterns := map[string]func(string) (*Vocabulary, error){
+func parseVocabulary(fsys fs.FS) (*Vocabulary, error) {
+	patterns := map[string]func(fs.FS) (*Vocabulary, error){
 		"tokenizer.model": parseSentencePiece,
 		"tokenizer.json":  parseVocabularyFromTokenizer,
 	}
 
 	for pattern, parseFn := range patterns {
-		if _, err := os.Stat(filepath.Join(d, pattern)); errors.Is(err, os.ErrNotExist) {
+		if _, err := fs.Stat(fsys, pattern); errors.Is(err, os.ErrNotExist) {
 			continue
 		} else if err != nil {
 			return nil, err
 		}
 
-		return parseFn(d)
+		return parseFn(fsys)
 	}
 
 	return nil, errors.New("unknown tensor format")
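
Note: os.ErrNotExist has been an alias for fs.ErrNotExist since Go 1.16, so the existing errors.Is(err, os.ErrNotExist) checks behave the same against errors returned by an arbitrary fs.FS, including the zip-backed one.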

+ 4 - 4
convert/tokenizer_spm.go

@@ -5,8 +5,8 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"io/fs"
 	"os"
-	"path/filepath"
 	"slices"
 
 	"google.golang.org/protobuf/proto"
@@ -14,8 +14,8 @@ import (
 	"github.com/ollama/ollama/convert/sentencepiece"
 )
 
-func parseSentencePiece(d string) (*Vocabulary, error) {
-	bts, err := os.ReadFile(filepath.Join(d, "tokenizer.model"))
+func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
+	bts, err := fs.ReadFile(fsys, "tokenizer.model")
 	if err != nil {
 		return nil, err
 	}
@@ -41,7 +41,7 @@ func parseSentencePiece(d string) (*Vocabulary, error) {
 		}
 	}
 
-	f, err := os.Open(filepath.Join(d, "added_tokens.json"))
+	f, err := fsys.Open("added_tokens.json")
 	if errors.Is(err, os.ErrNotExist) {
 		return &v, nil
 	} else if err != nil {

+ 14 - 59
server/model.go

@@ -81,88 +81,43 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
 	return layers, nil
 }
 
-func extractFromZipFile(p string, file *os.File, fn func(api.ProgressResponse)) error {
-	stat, err := file.Stat()
-	if err != nil {
-		return err
-	}
-
-	r, err := zip.NewReader(file, stat.Size())
+func parseFromZipFile(_ context.Context, f *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) {
+	fi, err := f.Stat()
 	if err != nil {
-		return err
-	}
-
-	fn(api.ProgressResponse{Status: "unpacking model metadata"})
-	for _, f := range r.File {
-		if !filepath.IsLocal(f.Name) {
-			return fmt.Errorf("%w: %s", zip.ErrInsecurePath, f.Name)
-		}
-
-		n := filepath.Join(p, f.Name)
-		if err := os.MkdirAll(filepath.Dir(n), 0o750); err != nil {
-			return err
-		}
-
-		// TODO(mxyng): this should not write out all files to disk
-		outfile, err := os.Create(n)
-		if err != nil {
-			return err
-		}
-		defer outfile.Close()
-
-		infile, err := f.Open()
-		if err != nil {
-			return err
-		}
-		defer infile.Close()
-
-		if _, err = io.Copy(outfile, infile); err != nil {
-			return err
-		}
-
-		if err := outfile.Close(); err != nil {
-			return err
-		}
-
-		if err := infile.Close(); err != nil {
-			return err
-		}
+		return nil, err
 	}
 
-	return nil
-}
-
-func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) {
-	tempDir, err := os.MkdirTemp(filepath.Dir(file.Name()), "")
+	r, err := zip.NewReader(f, fi.Size())
 	if err != nil {
 		return nil, err
 	}
-	defer os.RemoveAll(tempDir)
 
-	if err := extractFromZipFile(tempDir, file, fn); err != nil {
+	p, err := os.MkdirTemp(filepath.Dir(f.Name()), "")
+	if err != nil {
 		return nil, err
 	}
+	defer os.RemoveAll(p)
 
 	fn(api.ProgressResponse{Status: "converting model"})
-
 	// TODO(mxyng): this should write directly into a layer
 	// e.g. NewLayer(arch.Reader(), "application/vnd.ollama.image.model")
-	temp, err := os.CreateTemp(tempDir, "fp16")
+	t, err := os.CreateTemp(p, "fp16")
 	if err != nil {
 		return nil, err
 	}
-	defer temp.Close()
-	defer os.Remove(temp.Name())
+	defer t.Close()
+	defer os.Remove(t.Name())
 
-	if err := convert.Convert(tempDir, temp); err != nil {
+	fn(api.ProgressResponse{Status: "converting model"})
+	if err := convert.Convert(convert.NewZipReader(r, p, 32<<20), t); err != nil {
 		return nil, err
 	}
 
-	if _, err := temp.Seek(0, io.SeekStart); err != nil {
+	if _, err := t.Seek(0, io.SeekStart); err != nil {
 		return nil, err
 	}
 
-	layer, err := NewLayer(temp, "application/vnd.ollama.image.model")
+	layer, err := NewLayer(t, "application/vnd.ollama.image.model")
 	if err != nil {
 		return nil, err
 	}
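
Note: with the 32<<20 (32 MiB) threshold, small metadata such as config.json and the tokenizer files is read directly from the uploaded archive, while multi-gigabyte tensor shards are extracted once into the temporary directory and cleaned up by the deferred os.RemoveAll. This is what the commit title means by "only extract large files".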

+ 0 - 102
server/model_test.go

@@ -1,16 +1,11 @@
 package server
 
 import (
-	"archive/zip"
 	"bytes"
 	"encoding/json"
-	"errors"
 	"fmt"
-	"io"
 	"os"
 	"path/filepath"
-	"slices"
-	"strings"
 	"testing"
 
 	"github.com/google/go-cmp/cmp"
@@ -18,103 +13,6 @@ import (
 	"github.com/ollama/ollama/template"
 )
 
-func createZipFile(t *testing.T, name string) *os.File {
-	t.Helper()
-
-	f, err := os.CreateTemp(t.TempDir(), "")
-	if err != nil {
-		t.Fatal(err)
-	}
-
-	zf := zip.NewWriter(f)
-	defer zf.Close()
-
-	zh, err := zf.CreateHeader(&zip.FileHeader{Name: name})
-	if err != nil {
-		t.Fatal(err)
-	}
-
-	if _, err := io.Copy(zh, bytes.NewReader([]byte(""))); err != nil {
-		t.Fatal(err)
-	}
-
-	return f
-}
-
-func TestExtractFromZipFile(t *testing.T) {
-	cases := []struct {
-		name   string
-		expect []string
-		err    error
-	}{
-		{
-			name:   "good",
-			expect: []string{"good"},
-		},
-		{
-			name:   strings.Join([]string{"path", "..", "to", "good"}, string(os.PathSeparator)),
-			expect: []string{filepath.Join("to", "good")},
-		},
-		{
-			name:   strings.Join([]string{"path", "..", "to", "..", "good"}, string(os.PathSeparator)),
-			expect: []string{"good"},
-		},
-		{
-			name:   strings.Join([]string{"path", "to", "..", "..", "good"}, string(os.PathSeparator)),
-			expect: []string{"good"},
-		},
-		{
-			name: strings.Join([]string{"..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "bad"}, string(os.PathSeparator)),
-			err:  zip.ErrInsecurePath,
-		},
-		{
-			name: strings.Join([]string{"path", "..", "..", "to", "bad"}, string(os.PathSeparator)),
-			err:  zip.ErrInsecurePath,
-		},
-	}
-
-	for _, tt := range cases {
-		t.Run(tt.name, func(t *testing.T) {
-			f := createZipFile(t, tt.name)
-			defer f.Close()
-
-			tempDir := t.TempDir()
-			if err := extractFromZipFile(tempDir, f, func(api.ProgressResponse) {}); !errors.Is(err, tt.err) {
-				t.Fatal(err)
-			}
-
-			var matches []string
-			if err := filepath.Walk(tempDir, func(p string, fi os.FileInfo, err error) error {
-				if err != nil {
-					return err
-				}
-
-				if !fi.IsDir() {
-					matches = append(matches, p)
-				}
-
-				return nil
-			}); err != nil {
-				t.Fatal(err)
-			}
-
-			var actual []string
-			for _, match := range matches {
-				rel, err := filepath.Rel(tempDir, match)
-				if err != nil {
-					t.Error(err)
-				}
-
-				actual = append(actual, rel)
-			}
-
-			if !slices.Equal(actual, tt.expect) {
-				t.Fatalf("expected %d files, got %d", len(tt.expect), len(matches))
-			}
-		})
-	}
-}
-
 func readFile(t *testing.T, base, name string) *bytes.Buffer {
 	t.Helper()
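
Note: TestExtractFromZipFile is deleted along with the function it covered; extraction now happens lazily inside ZipReader.Open. A rough sketch of how the new path could be exercised from the convert package (test name, sizes, and limit are illustrative, not part of this commit):

	package convert_test

	import (
		"archive/zip"
		"bytes"
		"os"
		"path/filepath"
		"testing"

		"github.com/ollama/ollama/convert"
	)

	func TestZipReaderExtractsLargeFiles(t *testing.T) {
		// build a zip in memory with a single 16-byte entry
		var buf bytes.Buffer
		zw := zip.NewWriter(&buf)
		w, err := zw.Create("model.safetensors")
		if err != nil {
			t.Fatal(err)
		}
		if _, err := w.Write(bytes.Repeat([]byte("x"), 16)); err != nil {
			t.Fatal(err)
		}
		if err := zw.Close(); err != nil {
			t.Fatal(err)
		}

		zr, err := zip.NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
		if err != nil {
			t.Fatal(err)
		}

		// a limit of 8 bytes forces the 16-byte entry down the extract-to-disk path
		dir := t.TempDir()
		fsys := convert.NewZipReader(zr, dir, 8)

		f, err := fsys.Open("model.safetensors")
		if err != nil {
			t.Fatal(err)
		}
		defer f.Close()

		// the entry should now also exist on disk in the scratch directory
		if _, err := os.Stat(filepath.Join(dir, "model.safetensors")); err != nil {
			t.Fatal(err)
		}
	}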