Explorar o código

add modelpaths

Patrick Devine hai 1 ano
pai
achega
db961934dc
Modificáronse 3 ficheiros con 147 adicións e 80 borrados
  1. 7 1
      cmd/cmd.go
  2. 34 79
      server/images.go
  3. 106 0
      server/modelpath.go

+ 7 - 1
cmd/cmd.go

@@ -48,7 +48,13 @@ func create(cmd *cobra.Command, args []string) error {
 }
 
 func RunRun(cmd *cobra.Command, args []string) error {
-	_, err := os.Stat(args[0])
+	mp := server.ParseModelPath(args[0])
+	fp, err := mp.GetManifestPath(false)
+	if err != nil {
+		return err
+	}
+
+	_, err = os.Stat(fp)
 	switch {
 	case errors.Is(err, os.ErrNotExist):
 		if err := pull(args[0]); err != nil {

+ 34 - 79
server/images.go

@@ -22,8 +22,6 @@ import (
 	"github.com/jmorganca/ollama/parser"
 )
 
-var DefaultRegistry string = "https://registry.ollama.ai"
-
 type Model struct {
 	Name      string `json:"name"`
 	ModelPath string
@@ -61,27 +59,13 @@ type RootFS struct {
 	DiffIDs []string `json:"diff_ids"`
 }
 
-func modelsDir(part ...string) (string, error) {
-	home, err := os.UserHomeDir()
-	if err != nil {
-		return "", err
-	}
-
-	path := filepath.Join(home, ".ollama", "models", filepath.Join(part...))
-	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
-		return "", err
-	}
-
-	return path, nil
-}
-
-func GetManifest(name string) (*ManifestV2, error) {
-	fp, err := modelsDir("manifests", name)
+func GetManifest(mp ModelPath) (*ManifestV2, error) {
+	fp, err := mp.GetManifestPath(false)
 	if err != nil {
 		return nil, err
 	}
 	if _, err = os.Stat(fp); err != nil && !errors.Is(err, os.ErrNotExist) {
-		return nil, fmt.Errorf("couldn't find model '%s'", name)
+		return nil, fmt.Errorf("couldn't find model '%s'", mp.GetShortTagname())
 	}
 
 	var manifest *ManifestV2
@@ -101,17 +85,19 @@ func GetManifest(name string) (*ManifestV2, error) {
 }
 
 func GetModel(name string) (*Model, error) {
-	manifest, err := GetManifest(name)
+	mp := ParseModelPath(name)
+
+	manifest, err := GetManifest(mp)
 	if err != nil {
 		return nil, err
 	}
 
 	model := &Model{
-		Name: name,
+		Name: mp.GetFullTagname(),
 	}
 
 	for _, layer := range manifest.Layers {
-		filename, err := modelsDir("blobs", layer.Digest)
+		filename, err := GetBlobsPath(layer.Digest)
 		if err != nil {
 			return nil, err
 		}
@@ -174,7 +160,7 @@ func CreateModel(name string, mf io.Reader, fn func(status string)) error {
 		switch c.Name {
 		case "model":
 			fn("looking for model")
-			mf, err := GetManifest(c.Arg)
+			mf, err := GetManifest(ParseModelPath(c.Arg))
 			if err != nil {
 				// if we couldn't read the manifest, try getting the bin file
 				fp, err := getAbsPath(c.Arg)
@@ -293,7 +279,7 @@ func removeLayerFromLayers(layers []*LayerWithBuffer, mediaType string) []*Layer
 func SaveLayers(layers []*LayerWithBuffer, fn func(status string), force bool) error {
 	// Write each of the layers to disk
 	for _, layer := range layers {
-		fp, err := modelsDir("blobs", layer.Digest)
+		fp, err := GetBlobsPath(layer.Digest)
 		if err != nil {
 			return err
 		}
@@ -321,6 +307,8 @@ func SaveLayers(layers []*LayerWithBuffer, fn func(status string), force bool) e
 }
 
 func CreateManifest(name string, cfg *LayerWithBuffer, layers []*Layer) error {
+	mp := ParseModelPath(name)
+
 	manifest := ManifestV2{
 		SchemaVersion: 2,
 		MediaType:     "application/vnd.docker.distribution.manifest.v2+json",
@@ -337,7 +325,7 @@ func CreateManifest(name string, cfg *LayerWithBuffer, layers []*Layer) error {
 		return err
 	}
 
-	fp, err := modelsDir("manifests", name)
+	fp, err := mp.GetManifestPath(true)
 	if err != nil {
 		return err
 	}
@@ -345,7 +333,7 @@ func CreateManifest(name string, cfg *LayerWithBuffer, layers []*Layer) error {
 }
 
 func GetLayerWithBufferFromLayer(layer *Layer) (*LayerWithBuffer, error) {
-	fp, err := modelsDir("blobs", layer.Digest)
+	fp, err := GetBlobsPath(layer.Digest)
 	if err != nil {
 		return nil, err
 	}
@@ -456,28 +444,15 @@ func CreateLayer(f io.Reader) (*LayerWithBuffer, error) {
 }
 
 func PushModel(name, username, password string, fn func(status, digest string, Total, Completed int, Percent float64)) error {
+	mp := ParseModelPath(name)
+
 	fn("retrieving manifest", "", 0, 0, 0)
-	manifest, err := GetManifest(name)
+	manifest, err := GetManifest(mp)
 	if err != nil {
 		fn("couldn't retrieve manifest", "", 0, 0, 0)
 		return err
 	}
 
-	var repoName string
-	var tag string
-
-	comps := strings.Split(name, ":")
-	switch {
-	case len(comps) < 1 || len(comps) > 2:
-		return fmt.Errorf("repository name was invalid")
-	case len(comps) == 1:
-		repoName = comps[0]
-		tag = "latest"
-	case len(comps) == 2:
-		repoName = comps[0]
-		tag = comps[1]
-	}
-
 	var layers []*Layer
 	var total int
 	var completed int
@@ -489,7 +464,7 @@ func PushModel(name, username, password string, fn func(status, digest string, T
 	total += manifest.Config.Size
 
 	for _, layer := range layers {
-		exists, err := checkBlobExistence(DefaultRegistry, repoName, layer.Digest, username, password)
+		exists, err := checkBlobExistence(mp, layer.Digest, username, password)
 		if err != nil {
 			return err
 		}
@@ -502,7 +477,7 @@ func PushModel(name, username, password string, fn func(status, digest string, T
 
 		fn("starting upload", layer.Digest, total, completed, float64(completed)/float64(total))
 
-		location, err := startUpload(DefaultRegistry, repoName, username, password)
+		location, err := startUpload(mp, username, password)
 		if err != nil {
 			log.Printf("couldn't start upload: %v", err)
 			return err
@@ -518,7 +493,7 @@ func PushModel(name, username, password string, fn func(status, digest string, T
 	}
 
 	fn("pushing manifest", "", total, completed, float64(completed/total))
-	url := fmt.Sprintf("%s/v2/%s/manifests/%s", DefaultRegistry, repoName, tag)
+	url := fmt.Sprintf("%s://%s/v2/%s/manifests/%s", mp.ProtocolScheme, mp.Registry, mp.GetNamespaceRepository(), mp.Tag)
 	headers := map[string]string{
 		"Content-Type": "application/vnd.docker.distribution.manifest.v2+json",
 	}
@@ -546,30 +521,15 @@ func PushModel(name, username, password string, fn func(status, digest string, T
 }
 
 func PullModel(name, username, password string, fn func(status, digest string, Total, Completed int, Percent float64)) error {
-	var repoName string
-	var tag string
-
-	comps := strings.Split(name, ":")
-	switch {
-	case len(comps) < 1 || len(comps) > 2:
-		return fmt.Errorf("repository name was invalid")
-	case len(comps) == 1:
-		repoName = comps[0]
-		tag = "latest"
-	case len(comps) == 2:
-		repoName = comps[0]
-		tag = comps[1]
-	}
+	mp := ParseModelPath(name)
 
 	fn("pulling manifest", "", 0, 0, 0)
 
-	manifest, err := pullModelManifest(DefaultRegistry, repoName, tag, username, password)
+	manifest, err := pullModelManifest(mp, username, password)
 	if err != nil {
 		return fmt.Errorf("pull model manifest: %q", err)
 	}
 
-	log.Printf("manifest = %#v", manifest)
-
 	var layers []*Layer
 	var total int
 	var completed int
@@ -582,7 +542,7 @@ func PullModel(name, username, password string, fn func(status, digest string, T
 
 	for _, layer := range layers {
 		fn("starting download", layer.Digest, total, completed, float64(completed)/float64(total))
-		if err := downloadBlob(DefaultRegistry, repoName, layer.Digest, username, password, fn); err != nil {
+		if err := downloadBlob(mp, layer.Digest, username, password, fn); err != nil {
 			fn(fmt.Sprintf("error downloading: %v", err), layer.Digest, 0, 0, 0)
 			return err
 		}
@@ -597,16 +557,11 @@ func PullModel(name, username, password string, fn func(status, digest string, T
 		return err
 	}
 
-	fp, err := modelsDir("manifests", name)
+	fp, err := mp.GetManifestPath(true)
 	if err != nil {
 		return err
 	}
 
-	err = os.MkdirAll(path.Dir(fp), 0o700)
-	if err != nil {
-		return fmt.Errorf("make manifests directory: %w", err)
-	}
-
 	err = os.WriteFile(fp, manifestJSON, 0644)
 	if err != nil {
 		log.Printf("couldn't write to %s", fp)
@@ -618,8 +573,8 @@ func PullModel(name, username, password string, fn func(status, digest string, T
 	return nil
 }
 
-func pullModelManifest(registryURL, repoName, tag, username, password string) (*ManifestV2, error) {
-	url := fmt.Sprintf("%s/v2/%s/manifests/%s", registryURL, repoName, tag)
+func pullModelManifest(mp ModelPath, username, password string) (*ManifestV2, error) {
+	url := fmt.Sprintf("%s://%s/v2/%s/manifests/%s", mp.ProtocolScheme, mp.Registry, mp.GetNamespaceRepository(), mp.Tag)
 	headers := map[string]string{
 		"Accept": "application/vnd.docker.distribution.manifest.v2+json",
 	}
@@ -682,8 +637,8 @@ func GetSHA256Digest(data *bytes.Buffer) (string, int) {
 	return "sha256:" + hex.EncodeToString(hash[:]), len(layerBytes)
 }
 
-func startUpload(registryURL string, repositoryName string, username string, password string) (string, error) {
-	url := fmt.Sprintf("%s/v2/%s/blobs/uploads/", registryURL, repositoryName)
+func startUpload(mp ModelPath, username string, password string) (string, error) {
+	url := fmt.Sprintf("%s://%s/v2/%s/blobs/uploads/", mp.ProtocolScheme, mp.Registry, mp.GetNamespaceRepository())
 
 	resp, err := makeRequest("POST", url, nil, nil, username, password)
 	if err != nil {
@@ -708,8 +663,8 @@ func startUpload(registryURL string, repositoryName string, username string, pas
 }
 
 // Function to check if a blob already exists in the Docker registry
-func checkBlobExistence(registryURL string, repositoryName string, digest string, username string, password string) (bool, error) {
-	url := fmt.Sprintf("%s/v2/%s/blobs/%s", registryURL, repositoryName, digest)
+func checkBlobExistence(mp ModelPath, digest string, username string, password string) (bool, error) {
+	url := fmt.Sprintf("%s://%s/v2/%s/blobs/%s", mp.ProtocolScheme, mp.Registry, mp.GetNamespaceRepository(), digest)
 
 	resp, err := makeRequest("HEAD", url, nil, nil, username, password)
 	if err != nil {
@@ -735,7 +690,7 @@ func uploadBlob(location string, layer *Layer, username string, password string)
 	// TODO allow canceling uploads via DELETE
 	// TODO allow cross repo blob mount
 
-	fp, err := modelsDir("blobs", layer.Digest)
+	fp, err := GetBlobsPath(layer.Digest)
 	if err != nil {
 		return err
 	}
@@ -761,8 +716,8 @@ func uploadBlob(location string, layer *Layer, username string, password string)
 	return nil
 }
 
-func downloadBlob(registryURL, repoName, digest string, username, password string, fn func(status, digest string, Total, Completed int, Percent float64)) error {
-	fp, err := modelsDir("blobs", digest)
+func downloadBlob(mp ModelPath, digest string, username, password string, fn func(status, digest string, Total, Completed int, Percent float64)) error {
+	fp, err := GetBlobsPath(digest)
 	if err != nil {
 		return err
 	}
@@ -786,7 +741,7 @@ func downloadBlob(registryURL, repoName, digest string, username, password strin
 		size = fi.Size()
 	}
 
-	url := fmt.Sprintf("%s/v2/%s/blobs/%s", registryURL, repoName, digest)
+	url := fmt.Sprintf("%s://%s/v2/%s/blobs/%s", mp.ProtocolScheme, mp.Registry, mp.GetNamespaceRepository(), digest)
 	headers := map[string]string{
 		"Range": fmt.Sprintf("bytes=%d-", size),
 	}

+ 106 - 0
server/modelpath.go

@@ -0,0 +1,106 @@
+package server
+
+import (
+	"fmt"
+	"os"
+	"path/filepath"
+	"strings"
+)
+
+type ModelPath struct {
+	ProtocolScheme string
+	Registry       string
+	Namespace      string
+	Repository     string
+	Tag            string
+}
+
+const (
+	DefaultRegistry       = "registry.ollama.ai"
+	DefaultNamespace      = "library"
+	DefaultTag            = "latest"
+	DefaultProtocolScheme = "https"
+)
+
+func ParseModelPath(name string) ModelPath {
+	slashParts := strings.Split(name, "/")
+	var registry, namespace, repository, tag string
+
+	switch len(slashParts) {
+	case 3:
+		registry = slashParts[0]
+		namespace = slashParts[1]
+		repository = strings.Split(slashParts[2], ":")[0]
+	case 2:
+		registry = DefaultRegistry
+		namespace = slashParts[0]
+		repository = strings.Split(slashParts[1], ":")[0]
+	case 1:
+		registry = DefaultRegistry
+		namespace = DefaultNamespace
+		repository = strings.Split(slashParts[0], ":")[0]
+	default:
+		fmt.Println("Invalid image format.")
+		return ModelPath{}
+	}
+
+	colonParts := strings.Split(name, ":")
+	if len(colonParts) == 2 {
+		tag = colonParts[1]
+	} else {
+		tag = DefaultTag
+	}
+
+	return ModelPath{
+		ProtocolScheme: DefaultProtocolScheme,
+		Registry:       registry,
+		Namespace:      namespace,
+		Repository:     repository,
+		Tag:            tag,
+	}
+}
+
+func (mp ModelPath) GetNamespaceRepository() string {
+	return fmt.Sprintf("%s/%s", mp.Namespace, mp.Repository)
+}
+
+func (mp ModelPath) GetFullTagname() string {
+	return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
+}
+
+func (mp ModelPath) GetShortTagname() string {
+	if mp.Registry == DefaultRegistry && mp.Namespace == DefaultNamespace {
+		return fmt.Sprintf("%s:%s", mp.Repository, mp.Tag)
+	}
+	return fmt.Sprintf("%s/%s:%s", mp.Namespace, mp.Repository, mp.Tag)
+}
+
+func (mp ModelPath) GetManifestPath(createDir bool) (string, error) {
+	home, err := os.UserHomeDir()
+	if err != nil {
+		return "", err
+	}
+
+	path := filepath.Join(home, ".ollama", "models", "manifests", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
+	if createDir {
+		if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
+			return "", err
+		}
+	}
+
+	return path, nil
+}
+
+func GetBlobsPath(digest string) (string, error) {
+	home, err := os.UserHomeDir()
+	if err != nil {
+		return "", err
+	}
+
+	path := filepath.Join(home, ".ollama", "models", "blobs")
+	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
+		return "", err
+	}
+
+	return filepath.Join(path, digest), nil
+}