瀏覽代碼

Merge pull request #3833 from ollama/mxyng/fix-from

fix: from blob
Michael Yang 1 年之前
父節點
當前提交
2010cbc5fa
共有 1 個文件被更改,包括 96 次插入87 次删除
  1. 96 87
      cmd/cmd.go

+ 96 - 87
cmd/cmd.go

@@ -17,6 +17,7 @@ import (
 	"os"
 	"os/signal"
 	"path/filepath"
+	"regexp"
 	"runtime"
 	"strings"
 	"syscall"
@@ -53,8 +54,6 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 	p := progress.NewProgress(os.Stderr)
 	defer p.Stop()
 
-	bars := make(map[string]*progress.Bar)
-
 	modelfile, err := os.ReadFile(filename)
 	if err != nil {
 		return err
@@ -95,95 +94,16 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 				return err
 			}
 
-			// TODO make this work w/ adapters
 			if fi.IsDir() {
-				tf, err := os.CreateTemp("", "ollama-tf")
-				if err != nil {
-					return err
-				}
-				defer os.RemoveAll(tf.Name())
-
-				zf := zip.NewWriter(tf)
-
-				files := []string{}
-
-				tfiles, err := filepath.Glob(filepath.Join(path, "pytorch_model-*.bin"))
+				// this is likely a safetensors or pytorch directory
+				// TODO make this work w/ adapters
+				tempfile, err := tempZipFiles(path)
 				if err != nil {
 					return err
-				} else if len(tfiles) == 0 {
-					tfiles, err = filepath.Glob(filepath.Join(path, "model-*.safetensors"))
-					if err != nil {
-						return err
-					}
 				}
+				defer os.RemoveAll(tempfile)
 
-				files = append(files, tfiles...)
-
-				if len(files) == 0 {
-					return fmt.Errorf("no models were found in '%s'", path)
-				}
-
-				// add the safetensor/torch config file + tokenizer
-				files = append(files, filepath.Join(path, "config.json"))
-				files = append(files, filepath.Join(path, "params.json"))
-				files = append(files, filepath.Join(path, "added_tokens.json"))
-				files = append(files, filepath.Join(path, "tokenizer.model"))
-
-				for _, fn := range files {
-					f, err := os.Open(fn)
-
-					// just skip whatever files aren't there
-					if os.IsNotExist(err) {
-						if strings.HasSuffix(fn, "tokenizer.model") {
-							// try the parent dir before giving up
-							parentDir := filepath.Dir(path)
-							newFn := filepath.Join(parentDir, "tokenizer.model")
-							f, err = os.Open(newFn)
-							if os.IsNotExist(err) {
-								continue
-							} else if err != nil {
-								return err
-							}
-						} else {
-							continue
-						}
-					} else if err != nil {
-						return err
-					}
-
-					fi, err := f.Stat()
-					if err != nil {
-						return err
-					}
-
-					h, err := zip.FileInfoHeader(fi)
-					if err != nil {
-						return err
-					}
-
-					h.Name = filepath.Base(fn)
-					h.Method = zip.Store
-
-					w, err := zf.CreateHeader(h)
-					if err != nil {
-						return err
-					}
-
-					_, err = io.Copy(w, f)
-					if err != nil {
-						return err
-					}
-
-				}
-
-				if err := zf.Close(); err != nil {
-					return err
-				}
-
-				if err := tf.Close(); err != nil {
-					return err
-				}
-				path = tf.Name()
+				path = tempfile
 			}
 
 			digest, err := createBlob(cmd, client, path)
@@ -191,10 +111,17 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 				return err
 			}
 
-			modelfile = bytes.ReplaceAll(modelfile, []byte(c.Args), []byte("@"+digest))
+			name := c.Name
+			if c.Name == "model" {
+				name = "from"
+			}
+
+			re := regexp.MustCompile(fmt.Sprintf(`(?im)^(%s)\s+%s\s*$`, name, c.Args))
+			modelfile = re.ReplaceAll(modelfile, []byte("$1 @"+digest))
 		}
 	}
 
+	bars := make(map[string]*progress.Bar)
 	fn := func(resp api.ProgressResponse) error {
 		if resp.Digest != "" {
 			spinner.Stop()
@@ -228,6 +155,88 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 	return nil
 }
 
+func tempZipFiles(path string) (string, error) {
+	tempfile, err := os.CreateTemp("", "ollama-tf")
+	if err != nil {
+		return "", err
+	}
+	defer tempfile.Close()
+
+	zipfile := zip.NewWriter(tempfile)
+	defer zipfile.Close()
+
+	tfiles, err := filepath.Glob(filepath.Join(path, "pytorch_model-*.bin"))
+	if err != nil {
+		return "", err
+	} else if len(tfiles) == 0 {
+		tfiles, err = filepath.Glob(filepath.Join(path, "model-*.safetensors"))
+		if err != nil {
+			return "", err
+		}
+	}
+
+	files := []string{}
+	files = append(files, tfiles...)
+
+	if len(files) == 0 {
+		return "", fmt.Errorf("no models were found in '%s'", path)
+	}
+
+	// add the safetensor/torch config file + tokenizer
+	files = append(files, filepath.Join(path, "config.json"))
+	files = append(files, filepath.Join(path, "params.json"))
+	files = append(files, filepath.Join(path, "added_tokens.json"))
+	files = append(files, filepath.Join(path, "tokenizer.model"))
+
+	for _, fn := range files {
+		f, err := os.Open(fn)
+
+		// just skip whatever files aren't there
+		if os.IsNotExist(err) {
+			if strings.HasSuffix(fn, "tokenizer.model") {
+				// try the parent dir before giving up
+				parentDir := filepath.Dir(path)
+				newFn := filepath.Join(parentDir, "tokenizer.model")
+				f, err = os.Open(newFn)
+				if os.IsNotExist(err) {
+					continue
+				} else if err != nil {
+					return "", err
+				}
+			} else {
+				continue
+			}
+		} else if err != nil {
+			return "", err
+		}
+
+		fi, err := f.Stat()
+		if err != nil {
+			return "", err
+		}
+
+		h, err := zip.FileInfoHeader(fi)
+		if err != nil {
+			return "", err
+		}
+
+		h.Name = filepath.Base(fn)
+		h.Method = zip.Store
+
+		w, err := zipfile.CreateHeader(h)
+		if err != nil {
+			return "", err
+		}
+
+		_, err = io.Copy(w, f)
+		if err != nil {
+			return "", err
+		}
+	}
+
+	return tempfile.Name(), nil
+}
+
 func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) {
 	bin, err := os.Open(path)
 	if err != nil {