Michael Yang 10 місяців тому
батько
коміт
e401a23d62
1 змінених файлів з 70 додано та 57 видалено
  1. 70 57
      cmd/cmd.go

+ 70 - 57
cmd/cmd.go

@@ -3,6 +3,7 @@ package cmd
 import (
 import (
 	"archive/zip"
 	"archive/zip"
 	"bytes"
 	"bytes"
+	"cmp"
 	"context"
 	"context"
 	"crypto/ed25519"
 	"crypto/ed25519"
 	"crypto/rand"
 	"crypto/rand"
@@ -11,6 +12,7 @@ import (
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
+	"io/fs"
 	"log"
 	"log"
 	"math"
 	"math"
 	"net"
 	"net"
@@ -70,30 +72,24 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 		return err
 		return err
 	}
 	}
 
 
-	home, err := os.UserHomeDir()
+	status := "transferring model data"
+	spinner := progress.NewSpinner(status)
+	p.Add(status, spinner)
+
+	createCtx, err := cmd.Flags().GetString("context")
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
 
 
-	status := "transferring model data"
-	spinner := progress.NewSpinner(status)
-	p.Add(status, spinner)
+	createCtx = cmp.Or(createCtx, filepath.Dir(filename))
+	fsys := os.DirFS(createCtx)
 
 
 	for i := range modelfile.Commands {
 	for i := range modelfile.Commands {
 		switch modelfile.Commands[i].Name {
 		switch modelfile.Commands[i].Name {
 		case "model", "adapter":
 		case "model", "adapter":
-			path := modelfile.Commands[i].Args
-			if path == "~" {
-				path = home
-			} else if strings.HasPrefix(path, "~/") {
-				path = filepath.Join(home, path[2:])
-			}
-
-			if !filepath.IsAbs(path) {
-				path = filepath.Join(filepath.Dir(filename), path)
-			}
+			p := filepath.Clean(modelfile.Commands[i].Args)
 
 
-			fi, err := os.Stat(path)
+			fi, err := fs.Stat(fsys, p)
 			if errors.Is(err, os.ErrNotExist) && modelfile.Commands[i].Name == "model" {
 			if errors.Is(err, os.ErrNotExist) && modelfile.Commands[i].Name == "model" {
 				continue
 				continue
 			} else if err != nil {
 			} else if err != nil {
@@ -103,16 +99,29 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 			if fi.IsDir() {
 			if fi.IsDir() {
 				// this is likely a safetensors or pytorch directory
 				// this is likely a safetensors or pytorch directory
 				// TODO make this work w/ adapters
 				// TODO make this work w/ adapters
-				tempfile, err := tempZipFiles(path)
+				sub, err := fs.Sub(fsys, p)
 				if err != nil {
 				if err != nil {
 					return err
 					return err
 				}
 				}
-				defer os.RemoveAll(tempfile)
 
 
-				path = tempfile
+				temp, err := os.CreateTemp(createCtx, "*.zip")
+				if err != nil {
+					return err
+				}
+				defer temp.Close()
+				defer os.RemoveAll(temp.Name())
+
+				if err := zipFiles(sub, temp); err != nil {
+					return err
+				}
+
+				p, err = filepath.Rel(createCtx, temp.Name())
+				if err != nil {
+					return err
+				}
 			}
 			}
 
 
-			digest, err := createBlob(cmd, client, path)
+			digest, err := createBlob(cmd, client, fsys, p)
 			if err != nil {
 			if err != nil {
 				return err
 				return err
 			}
 			}
@@ -155,42 +164,34 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 	return nil
 	return nil
 }
 }
 
 
-func tempZipFiles(path string) (string, error) {
-	tempfile, err := os.CreateTemp("", "ollama-tf")
-	if err != nil {
-		return "", err
-	}
-	defer tempfile.Close()
-
-	detectContentType := func(path string) (string, error) {
-		f, err := os.Open(path)
+func zipFiles(fsys fs.FS, w io.Writer) error {
+	detectContentType := func(name string) (string, error) {
+		f, err := fsys.Open(name)
 		if err != nil {
 		if err != nil {
 			return "", err
 			return "", err
 		}
 		}
 		defer f.Close()
 		defer f.Close()
 
 
-		var b bytes.Buffer
-		b.Grow(512)
-
-		if _, err := io.CopyN(&b, f, 512); err != nil && !errors.Is(err, io.EOF) {
+		bts, err := io.ReadAll(io.LimitReader(f, 512))
+		if err != nil {
 			return "", err
 			return "", err
 		}
 		}
 
 
-		contentType, _, _ := strings.Cut(http.DetectContentType(b.Bytes()), ";")
+		contentType, _, _ := strings.Cut(http.DetectContentType(bts), ";")
 		return contentType, nil
 		return contentType, nil
 	}
 	}
 
 
 	glob := func(pattern, contentType string) ([]string, error) {
 	glob := func(pattern, contentType string) ([]string, error) {
-		matches, err := filepath.Glob(pattern)
+		matches, err := fs.Glob(fsys, pattern)
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
 
 
-		for _, safetensor := range matches {
-			if ct, err := detectContentType(safetensor); err != nil {
+		for _, match := range matches {
+			if ct, err := detectContentType(match); err != nil {
 				return nil, err
 				return nil, err
 			} else if ct != contentType {
 			} else if ct != contentType {
-				return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, safetensor)
+				return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, match)
 			}
 			}
 		}
 		}
 
 
@@ -198,73 +199,73 @@ func tempZipFiles(path string) (string, error) {
 	}
 	}
 
 
 	var files []string
 	var files []string
-	if st, _ := glob(filepath.Join(path, "model*.safetensors"), "application/octet-stream"); len(st) > 0 {
+	if st, _ := glob("model*.safetensors", "application/octet-stream"); len(st) > 0 {
 		// safetensors files might be unresolved git lfs references; skip if they are
 		// safetensors files might be unresolved git lfs references; skip if they are
 		// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
 		// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
 		files = append(files, st...)
 		files = append(files, st...)
-	} else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 {
+	} else if pt, _ := glob("pytorch_model*.bin", "application/zip"); len(pt) > 0 {
 		// pytorch files might also be unresolved git lfs references; skip if they are
 		// pytorch files might also be unresolved git lfs references; skip if they are
 		// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin
 		// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin
 		files = append(files, pt...)
 		files = append(files, pt...)
-	} else if pt, _ := glob(filepath.Join(path, "consolidated*.pth"), "application/zip"); len(pt) > 0 {
+	} else if pt, _ := glob("consolidated*.pth", "application/zip"); len(pt) > 0 {
 		// pytorch files might also be unresolved git lfs references; skip if they are
 		// pytorch files might also be unresolved git lfs references; skip if they are
 		// covers consolidated.x.pth, consolidated.pth
 		// covers consolidated.x.pth, consolidated.pth
 		files = append(files, pt...)
 		files = append(files, pt...)
 	} else {
 	} else {
-		return "", errors.New("no safetensors or torch files found")
+		return errors.New("no safetensors or torch files found")
 	}
 	}
 
 
 	// add configuration files, json files are detected as text/plain
 	// add configuration files, json files are detected as text/plain
-	js, err := glob(filepath.Join(path, "*.json"), "text/plain")
+	js, err := glob("*.json", "text/plain")
 	if err != nil {
 	if err != nil {
-		return "", err
+		return err
 	}
 	}
 	files = append(files, js...)
 	files = append(files, js...)
 
 
-	if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 {
+	if tks, _ := glob("tokenizer.model", "application/octet-stream"); len(tks) > 0 {
 		// add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob
 		// add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob
 		// tokenizer.model might be a unresolved git lfs reference; error if it is
 		// tokenizer.model might be a unresolved git lfs reference; error if it is
 		files = append(files, tks...)
 		files = append(files, tks...)
-	} else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 {
+	} else if tks, _ := glob("**/tokenizer.model", "text/plain"); len(tks) > 0 {
 		// some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
 		// some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
 		files = append(files, tks...)
 		files = append(files, tks...)
 	}
 	}
 
 
-	zipfile := zip.NewWriter(tempfile)
+	zipfile := zip.NewWriter(w)
 	defer zipfile.Close()
 	defer zipfile.Close()
 
 
 	for _, file := range files {
 	for _, file := range files {
-		f, err := os.Open(file)
+		f, err := fsys.Open(file)
 		if err != nil {
 		if err != nil {
-			return "", err
+			return err
 		}
 		}
 		defer f.Close()
 		defer f.Close()
 
 
 		fi, err := f.Stat()
 		fi, err := f.Stat()
 		if err != nil {
 		if err != nil {
-			return "", err
+			return err
 		}
 		}
 
 
 		zfi, err := zip.FileInfoHeader(fi)
 		zfi, err := zip.FileInfoHeader(fi)
 		if err != nil {
 		if err != nil {
-			return "", err
+			return err
 		}
 		}
 
 
 		zf, err := zipfile.CreateHeader(zfi)
 		zf, err := zipfile.CreateHeader(zfi)
 		if err != nil {
 		if err != nil {
-			return "", err
+			return err
 		}
 		}
 
 
 		if _, err := io.Copy(zf, f); err != nil {
 		if _, err := io.Copy(zf, f); err != nil {
-			return "", err
+			return err
 		}
 		}
 	}
 	}
 
 
-	return tempfile.Name(), nil
+	return nil
 }
 }
 
 
-func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) {
-	bin, err := os.Open(path)
+func sha256sum(fsys fs.FS, name string) (string, error) {
+	bin, err := fsys.Open(name)
 	if err != nil {
 	if err != nil {
 		return "", err
 		return "", err
 	}
 	}
@@ -275,14 +276,25 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
 		return "", err
 		return "", err
 	}
 	}
 
 
-	if _, err := bin.Seek(0, io.SeekStart); err != nil {
+	return fmt.Sprintf("sha256:%x", hash.Sum(nil)), nil
+}
+
+func createBlob(cmd *cobra.Command, client *api.Client, fsys fs.FS, name string) (string, error) {
+	bin, err := fsys.Open(name)
+	if err != nil {
+		return "", err
+	}
+	defer bin.Close()
+
+	digest, err := sha256sum(fsys, name)
+	if err != nil {
 		return "", err
 		return "", err
 	}
 	}
 
 
-	digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
 	if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil {
 	if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil {
 		return "", err
 		return "", err
 	}
 	}
+
 	return digest, nil
 	return digest, nil
 }
 }
 
 
@@ -1226,6 +1238,7 @@ func NewCLI() *cobra.Command {
 
 
 	createCmd.Flags().StringP("file", "f", "Modelfile", "Name of the Modelfile")
 	createCmd.Flags().StringP("file", "f", "Modelfile", "Name of the Modelfile")
 	createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_0)")
 	createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_0)")
+	createCmd.Flags().StringP("context", "C", "", "Context for the model")
 
 
 	showCmd := &cobra.Command{
 	showCmd := &cobra.Command{
 		Use:     "show MODEL",
 		Use:     "show MODEL",