瀏覽代碼

check file type before zip

Michael Yang 1 年之前
父節點
當前提交
41e03ede95
共有 1 個文件被更改,包括 65 次插入39 次删除
  1. 65 39
      cmd/cmd.go

+ 65 - 39
cmd/cmd.go

@@ -165,71 +165,97 @@ func tempZipFiles(path string) (string, error) {
 	zipfile := zip.NewWriter(tempfile)
 	defer zipfile.Close()
 
-	tfiles, err := filepath.Glob(filepath.Join(path, "pytorch_model-*.bin"))
-	if err != nil {
-		return "", err
-	} else if len(tfiles) == 0 {
-		tfiles, err = filepath.Glob(filepath.Join(path, "model-*.safetensors"))
+	detectContentType := func(path string) (string, error) {
+		f, err := os.Open(path)
 		if err != nil {
 			return "", err
 		}
+		defer f.Close()
+
+		var b bytes.Buffer
+		b.Grow(512)
+
+		if _, err := io.CopyN(&b, f, 512); err != nil && !errors.Is(err, io.EOF) {
+			return "", err
+		}
+
+		contentType, _, _ := strings.Cut(http.DetectContentType(b.Bytes()), ";")
+		return contentType, nil
 	}
 
-	files := []string{}
-	files = append(files, tfiles...)
+	glob := func(pattern, contentType string) ([]string, error) {
+		matches, err := filepath.Glob(pattern)
+		if err != nil {
+			return nil, err
+		}
+
+		for _, safetensor := range matches {
+			if ct, err := detectContentType(safetensor); err != nil {
+				return nil, err
+			} else if ct != contentType {
+				return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, safetensor)
+			}
+		}
 
-	if len(files) == 0 {
-		return "", fmt.Errorf("no models were found in '%s'", path)
+		return matches, nil
+	}
+
+	var files []string
+	if st, _ := glob(filepath.Join(path, "model*.safetensors"), "application/octet-stream"); len(st) > 0 {
+		// safetensors files might be unresolved git lfs references; skip if they are
+		// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
+		files = append(files, st...)
+	} else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 {
+		// pytorch files might also be unresolved git lfs references; skip if they are
+		// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin
+		files = append(files, pt...)
+	} else if pt, _ := glob(filepath.Join(path, "consolidated*.pth"), "application/octet-stream"); len(pt) > 0 {
+		// pytorch files might also be unresolved git lfs references; skip if they are
+		// covers consolidated.x.pth, consolidated.pth
+		files = append(files, pt...)
+	} else {
+		return "", errors.New("no safetensors or torch files found")
 	}
 
-	// add the safetensor/torch config file + tokenizer
-	files = append(files, filepath.Join(path, "config.json"))
-	files = append(files, filepath.Join(path, "params.json"))
-	files = append(files, filepath.Join(path, "added_tokens.json"))
-	files = append(files, filepath.Join(path, "tokenizer.model"))
+	// add configuration files, json files are detected as text/plain
+	js, err := glob(filepath.Join(path, "*.json"), "text/plain")
+	if err != nil {
+		return "", err
+	}
+	files = append(files, js...)
 
-	for _, fn := range files {
-		f, err := os.Open(fn)
+	if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 {
+		// add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob
+		// tokenizer.model might be a unresolved git lfs reference; error if it is
+		files = append(files, tks...)
+	} else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 {
+		// some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
+		files = append(files, tks...)
+	}
 
-		// just skip whatever files aren't there
-		if os.IsNotExist(err) {
-			if strings.HasSuffix(fn, "tokenizer.model") {
-				// try the parent dir before giving up
-				parentDir := filepath.Dir(path)
-				newFn := filepath.Join(parentDir, "tokenizer.model")
-				f, err = os.Open(newFn)
-				if os.IsNotExist(err) {
-					continue
-				} else if err != nil {
-					return "", err
-				}
-			} else {
-				continue
-			}
-		} else if err != nil {
+	for _, file := range files {
+		f, err := os.Open(file)
+		if err != nil {
 			return "", err
 		}
+		defer f.Close()
 
 		fi, err := f.Stat()
 		if err != nil {
 			return "", err
 		}
 
-		h, err := zip.FileInfoHeader(fi)
+		zfi, err := zip.FileInfoHeader(fi)
 		if err != nil {
 			return "", err
 		}
 
-		h.Name = filepath.Base(fn)
-		h.Method = zip.Store
-
-		w, err := zipfile.CreateHeader(h)
+		zf, err := zipfile.CreateHeader(zfi)
 		if err != nil {
 			return "", err
 		}
 
-		_, err = io.Copy(w, f)
-		if err != nil {
+		if _, err := io.Copy(zf, f); err != nil {
 			return "", err
 		}
 	}