|
@@ -165,71 +165,97 @@ func tempZipFiles(path string) (string, error) {
|
|
|
zipfile := zip.NewWriter(tempfile)
|
|
|
defer zipfile.Close()
|
|
|
|
|
|
- tfiles, err := filepath.Glob(filepath.Join(path, "pytorch_model-*.bin"))
|
|
|
- if err != nil {
|
|
|
- return "", err
|
|
|
- } else if len(tfiles) == 0 {
|
|
|
- tfiles, err = filepath.Glob(filepath.Join(path, "model-*.safetensors"))
|
|
|
+ detectContentType := func(path string) (string, error) {
|
|
|
+ f, err := os.Open(path)
|
|
|
if err != nil {
|
|
|
return "", err
|
|
|
}
|
|
|
+ defer f.Close()
|
|
|
+
|
|
|
+ var b bytes.Buffer
|
|
|
+ b.Grow(512)
|
|
|
+
|
|
|
+ if _, err := io.CopyN(&b, f, 512); err != nil && !errors.Is(err, io.EOF) {
|
|
|
+ return "", err
|
|
|
+ }
|
|
|
+
|
|
|
+ contentType, _, _ := strings.Cut(http.DetectContentType(b.Bytes()), ";")
|
|
|
+ return contentType, nil
|
|
|
}
|
|
|
|
|
|
- files := []string{}
|
|
|
- files = append(files, tfiles...)
|
|
|
+ glob := func(pattern, contentType string) ([]string, error) {
|
|
|
+ matches, err := filepath.Glob(pattern)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ for _, safetensor := range matches {
|
|
|
+ if ct, err := detectContentType(safetensor); err != nil {
|
|
|
+ return nil, err
|
|
|
+ } else if ct != contentType {
|
|
|
+ return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, safetensor)
|
|
|
+ }
|
|
|
+ }
|
|
|
|
|
|
- if len(files) == 0 {
|
|
|
- return "", fmt.Errorf("no models were found in '%s'", path)
|
|
|
+ return matches, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ var files []string
|
|
|
+ if st, _ := glob(filepath.Join(path, "model*.safetensors"), "application/octet-stream"); len(st) > 0 {
|
|
|
+ // safetensors files might be unresolved git lfs references; skip if they are
|
|
|
+ // covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
|
|
|
+ files = append(files, st...)
|
|
|
+ } else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 {
|
|
|
+ // pytorch files might also be unresolved git lfs references; skip if they are
|
|
|
+ // covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin
|
|
|
+ files = append(files, pt...)
|
|
|
+ } else if pt, _ := glob(filepath.Join(path, "consolidated*.pth"), "application/octet-stream"); len(pt) > 0 {
|
|
|
+ // pytorch files might also be unresolved git lfs references; skip if they are
|
|
|
+ // covers consolidated.x.pth, consolidated.pth
|
|
|
+ files = append(files, pt...)
|
|
|
+ } else {
|
|
|
+ return "", errors.New("no safetensors or torch files found")
|
|
|
}
|
|
|
|
|
|
- // add the safetensor/torch config file + tokenizer
|
|
|
- files = append(files, filepath.Join(path, "config.json"))
|
|
|
- files = append(files, filepath.Join(path, "params.json"))
|
|
|
- files = append(files, filepath.Join(path, "added_tokens.json"))
|
|
|
- files = append(files, filepath.Join(path, "tokenizer.model"))
|
|
|
+ // add configuration files, json files are detected as text/plain
|
|
|
+ js, err := glob(filepath.Join(path, "*.json"), "text/plain")
|
|
|
+ if err != nil {
|
|
|
+ return "", err
|
|
|
+ }
|
|
|
+ files = append(files, js...)
|
|
|
|
|
|
- for _, fn := range files {
|
|
|
- f, err := os.Open(fn)
|
|
|
+ if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 {
|
|
|
+ // add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob
|
|
|
+ // tokenizer.model might be an unresolved git lfs reference; error if it is
|
|
|
+ files = append(files, tks...)
|
|
|
+ } else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 {
|
|
|
+ // sometimes tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
|
|
|
+ files = append(files, tks...)
|
|
|
+ }
|
|
|
|
|
|
- // just skip whatever files aren't there
|
|
|
- if os.IsNotExist(err) {
|
|
|
- if strings.HasSuffix(fn, "tokenizer.model") {
|
|
|
- // try the parent dir before giving up
|
|
|
- parentDir := filepath.Dir(path)
|
|
|
- newFn := filepath.Join(parentDir, "tokenizer.model")
|
|
|
- f, err = os.Open(newFn)
|
|
|
- if os.IsNotExist(err) {
|
|
|
- continue
|
|
|
- } else if err != nil {
|
|
|
- return "", err
|
|
|
- }
|
|
|
- } else {
|
|
|
- continue
|
|
|
- }
|
|
|
- } else if err != nil {
|
|
|
+ for _, file := range files {
|
|
|
+ f, err := os.Open(file)
|
|
|
+ if err != nil {
|
|
|
return "", err
|
|
|
}
|
|
|
+ defer f.Close()
|
|
|
|
|
|
fi, err := f.Stat()
|
|
|
if err != nil {
|
|
|
return "", err
|
|
|
}
|
|
|
|
|
|
- h, err := zip.FileInfoHeader(fi)
|
|
|
+ zfi, err := zip.FileInfoHeader(fi)
|
|
|
if err != nil {
|
|
|
return "", err
|
|
|
}
|
|
|
|
|
|
- h.Name = filepath.Base(fn)
|
|
|
- h.Method = zip.Store
|
|
|
-
|
|
|
- w, err := zipfile.CreateHeader(h)
|
|
|
+ zf, err := zipfile.CreateHeader(zfi)
|
|
|
if err != nil {
|
|
|
return "", err
|
|
|
}
|
|
|
|
|
|
- _, err = io.Copy(w, f)
|
|
|
- if err != nil {
|
|
|
+ if _, err := io.Copy(zf, f); err != nil {
|
|
|
return "", err
|
|
|
}
|
|
|
}
|