Ver Fonte

allow pushing/pulling to insecure registries

Patrick Devine há 1 ano atrás
pai
commit
92e30a62ee
5 ficheiros alterados com 86 adições e 39 exclusões
  1. 2 0
      api/types.go
  2. 22 8
      cmd/cmd.go
  3. 42 26
      server/images.go
  4. 6 3
      server/modelpath.go
  5. 14 2
      server/routes.go

+ 2 - 0
api/types.go

@@ -43,6 +43,7 @@ type DeleteRequest struct {
 
 
 type PullRequest struct {
 type PullRequest struct {
 	Name     string `json:"name"`
 	Name     string `json:"name"`
+	Insecure bool   `json:"insecure,omitempty"`
 	Username string `json:"username"`
 	Username string `json:"username"`
 	Password string `json:"password"`
 	Password string `json:"password"`
 }
 }
@@ -56,6 +57,7 @@ type ProgressResponse struct {
 
 
 type PushRequest struct {
 type PushRequest struct {
 	Name     string `json:"name"`
 	Name     string `json:"name"`
+	Insecure bool   `json:"insecure,omitempty"`
 	Username string `json:"username"`
 	Username string `json:"username"`
 	Password string `json:"password"`
 	Password string `json:"password"`
 }
 }

+ 22 - 8
cmd/cmd.go

@@ -69,7 +69,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 	_, err = os.Stat(fp)
 	_, err = os.Stat(fp)
 	switch {
 	switch {
 	case errors.Is(err, os.ErrNotExist):
 	case errors.Is(err, os.ErrNotExist):
-		if err := pull(args[0]); err != nil {
+		if err := pull(args[0], false); err != nil {
 			var apiStatusError api.StatusError
 			var apiStatusError api.StatusError
 			if !errors.As(err, &apiStatusError) {
 			if !errors.As(err, &apiStatusError) {
 				return err
 				return err
@@ -89,7 +89,12 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 func PushHandler(cmd *cobra.Command, args []string) error {
 func PushHandler(cmd *cobra.Command, args []string) error {
 	client := api.NewClient()
 	client := api.NewClient()
 
 
-	request := api.PushRequest{Name: args[0]}
+	insecure, err := cmd.Flags().GetBool("insecure")
+	if err != nil {
+		return err
+	}
+
+	request := api.PushRequest{Name: args[0], Insecure: insecure}
 	fn := func(resp api.ProgressResponse) error {
 	fn := func(resp api.ProgressResponse) error {
 		fmt.Println(resp.Status)
 		fmt.Println(resp.Status)
 		return nil
 		return nil
@@ -147,16 +152,21 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
 }
 }
 
 
 func PullHandler(cmd *cobra.Command, args []string) error {
 func PullHandler(cmd *cobra.Command, args []string) error {
-	return pull(args[0])
+	insecure, err := cmd.Flags().GetBool("insecure")
+	if err != nil {
+		return err
+	}
+
+	return pull(args[0], insecure)
 }
 }
 
 
-func pull(model string) error {
+func pull(model string, insecure bool) error {
 	client := api.NewClient()
 	client := api.NewClient()
 
 
 	var currentDigest string
 	var currentDigest string
 	var bar *progressbar.ProgressBar
 	var bar *progressbar.ProgressBar
 
 
-	request := api.PullRequest{Name: model}
+	request := api.PullRequest{Name: model, Insecure: insecure}
 	fn := func(resp api.ProgressResponse) error {
 	fn := func(resp api.ProgressResponse) error {
 		if resp.Digest != currentDigest && resp.Digest != "" {
 		if resp.Digest != currentDigest && resp.Digest != "" {
 			currentDigest = resp.Digest
 			currentDigest = resp.Digest
@@ -430,6 +440,8 @@ func NewCLI() *cobra.Command {
 		RunE:  PullHandler,
 		RunE:  PullHandler,
 	}
 	}
 
 
+	pullCmd.Flags().Bool("insecure", false, "Use an insecure registry")
+
 	pushCmd := &cobra.Command{
 	pushCmd := &cobra.Command{
 		Use:   "push MODEL",
 		Use:   "push MODEL",
 		Short: "Push a model to a registry",
 		Short: "Push a model to a registry",
@@ -437,11 +449,13 @@ func NewCLI() *cobra.Command {
 		RunE:  PushHandler,
 		RunE:  PushHandler,
 	}
 	}
 
 
+	pushCmd.Flags().Bool("insecure", false, "Use an insecure registry")
+
 	listCmd := &cobra.Command{
 	listCmd := &cobra.Command{
-		Use:   "list",
+		Use:     "list",
 		Aliases: []string{"ls"},
 		Aliases: []string{"ls"},
-		Short: "List models",
-		RunE:  ListHandler,
+		Short:   "List models",
+		RunE:    ListHandler,
 	}
 	}
 
 
 	deleteCmd := &cobra.Command{
 	deleteCmd := &cobra.Command{

+ 42 - 26
server/images.go

@@ -22,6 +22,12 @@ import (
 	"github.com/jmorganca/ollama/parser"
 	"github.com/jmorganca/ollama/parser"
 )
 )
 
 
+type RegistryOptions struct {
+	Insecure bool
+	Username string
+	Password string
+}
+
 type Model struct {
 type Model struct {
 	Name      string `json:"name"`
 	Name      string `json:"name"`
 	ModelPath string
 	ModelPath string
@@ -564,7 +570,7 @@ func DeleteModel(name string, fn func(api.ProgressResponse)) error {
 	return nil
 	return nil
 }
 }
 
 
-func PushModel(name, username, password string, fn func(api.ProgressResponse)) error {
+func PushModel(name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
 	mp := ParseModelPath(name)
 	mp := ParseModelPath(name)
 
 
 	fn(api.ProgressResponse{Status: "retrieving manifest"})
 	fn(api.ProgressResponse{Status: "retrieving manifest"})
@@ -586,7 +592,7 @@ func PushModel(name, username, password string, fn func(api.ProgressResponse)) e
 	total += manifest.Config.Size
 	total += manifest.Config.Size
 
 
 	for _, layer := range layers {
 	for _, layer := range layers {
-		exists, err := checkBlobExistence(mp, layer.Digest, username, password)
+		exists, err := checkBlobExistence(mp, layer.Digest, regOpts)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
@@ -609,13 +615,13 @@ func PushModel(name, username, password string, fn func(api.ProgressResponse)) e
 			Completed: completed,
 			Completed: completed,
 		})
 		})
 
 
-		location, err := startUpload(mp, username, password)
+		location, err := startUpload(mp, regOpts)
 		if err != nil {
 		if err != nil {
 			log.Printf("couldn't start upload: %v", err)
 			log.Printf("couldn't start upload: %v", err)
 			return err
 			return err
 		}
 		}
 
 
-		err = uploadBlob(location, layer, username, password)
+		err = uploadBlob(location, layer, regOpts)
 		if err != nil {
 		if err != nil {
 			log.Printf("error uploading blob: %v", err)
 			log.Printf("error uploading blob: %v", err)
 			return err
 			return err
@@ -634,7 +640,7 @@ func PushModel(name, username, password string, fn func(api.ProgressResponse)) e
 		Total:     total,
 		Total:     total,
 		Completed: completed,
 		Completed: completed,
 	})
 	})
-	url := fmt.Sprintf("%s://%s/v2/%s/manifests/%s", mp.ProtocolScheme, mp.Registry, mp.GetNamespaceRepository(), mp.Tag)
+	url := fmt.Sprintf("%s/v2/%s/manifests/%s", mp.Registry, mp.GetNamespaceRepository(), mp.Tag)
 	headers := map[string]string{
 	headers := map[string]string{
 		"Content-Type": "application/vnd.docker.distribution.manifest.v2+json",
 		"Content-Type": "application/vnd.docker.distribution.manifest.v2+json",
 	}
 	}
@@ -644,7 +650,7 @@ func PushModel(name, username, password string, fn func(api.ProgressResponse)) e
 		return err
 		return err
 	}
 	}
 
 
-	resp, err := makeRequest("PUT", url, headers, bytes.NewReader(manifestJSON), username, password)
+	resp, err := makeRequest("PUT", url, headers, bytes.NewReader(manifestJSON), regOpts)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -665,12 +671,12 @@ func PushModel(name, username, password string, fn func(api.ProgressResponse)) e
 	return nil
 	return nil
 }
 }
 
 
-func PullModel(name, username, password string, fn func(api.ProgressResponse)) error {
+func PullModel(name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
 	mp := ParseModelPath(name)
 	mp := ParseModelPath(name)
 
 
 	fn(api.ProgressResponse{Status: "pulling manifest"})
 	fn(api.ProgressResponse{Status: "pulling manifest"})
 
 
-	manifest, err := pullModelManifest(mp, username, password)
+	manifest, err := pullModelManifest(mp, regOpts)
 	if err != nil {
 	if err != nil {
 		return fmt.Errorf("pull model manifest: %q", err)
 		return fmt.Errorf("pull model manifest: %q", err)
 	}
 	}
@@ -680,7 +686,7 @@ func PullModel(name, username, password string, fn func(api.ProgressResponse)) e
 	layers = append(layers, &manifest.Config)
 	layers = append(layers, &manifest.Config)
 
 
 	for _, layer := range layers {
 	for _, layer := range layers {
-		if err := downloadBlob(mp, layer.Digest, username, password, fn); err != nil {
+		if err := downloadBlob(mp, layer.Digest, regOpts, fn); err != nil {
 			return err
 			return err
 		}
 		}
 	}
 	}
@@ -715,13 +721,13 @@ func PullModel(name, username, password string, fn func(api.ProgressResponse)) e
 	return nil
 	return nil
 }
 }
 
 
-func pullModelManifest(mp ModelPath, username, password string) (*ManifestV2, error) {
-	url := fmt.Sprintf("%s://%s/v2/%s/manifests/%s", mp.ProtocolScheme, mp.Registry, mp.GetNamespaceRepository(), mp.Tag)
+func pullModelManifest(mp ModelPath, regOpts *RegistryOptions) (*ManifestV2, error) {
+	url := fmt.Sprintf("%s/v2/%s/manifests/%s", mp.Registry, mp.GetNamespaceRepository(), mp.Tag)
 	headers := map[string]string{
 	headers := map[string]string{
 		"Accept": "application/vnd.docker.distribution.manifest.v2+json",
 		"Accept": "application/vnd.docker.distribution.manifest.v2+json",
 	}
 	}
 
 
-	resp, err := makeRequest("GET", url, headers, nil, username, password)
+	resp, err := makeRequest("GET", url, headers, nil, regOpts)
 	if err != nil {
 	if err != nil {
 		log.Printf("couldn't get manifest: %v", err)
 		log.Printf("couldn't get manifest: %v", err)
 		return nil, err
 		return nil, err
@@ -782,10 +788,10 @@ func GetSHA256Digest(r io.Reader) (string, int) {
 	return fmt.Sprintf("sha256:%x", h.Sum(nil)), int(n)
 	return fmt.Sprintf("sha256:%x", h.Sum(nil)), int(n)
 }
 }
 
 
-func startUpload(mp ModelPath, username string, password string) (string, error) {
-	url := fmt.Sprintf("%s://%s/v2/%s/blobs/uploads/", mp.ProtocolScheme, mp.Registry, mp.GetNamespaceRepository())
+func startUpload(mp ModelPath, regOpts *RegistryOptions) (string, error) {
+	url := fmt.Sprintf("%s/v2/%s/blobs/uploads/", mp.Registry, mp.GetNamespaceRepository())
 
 
-	resp, err := makeRequest("POST", url, nil, nil, username, password)
+	resp, err := makeRequest("POST", url, nil, nil, regOpts)
 	if err != nil {
 	if err != nil {
 		log.Printf("couldn't start upload: %v", err)
 		log.Printf("couldn't start upload: %v", err)
 		return "", err
 		return "", err
@@ -808,10 +814,10 @@ func startUpload(mp ModelPath, username string, password string) (string, error)
 }
 }
 
 
 // Function to check if a blob already exists in the Docker registry
 // Function to check if a blob already exists in the Docker registry
-func checkBlobExistence(mp ModelPath, digest string, username string, password string) (bool, error) {
-	url := fmt.Sprintf("%s://%s/v2/%s/blobs/%s", mp.ProtocolScheme, mp.Registry, mp.GetNamespaceRepository(), digest)
+func checkBlobExistence(mp ModelPath, digest string, regOpts *RegistryOptions) (bool, error) {
+	url := fmt.Sprintf("%s/v2/%s/blobs/%s", mp.Registry, mp.GetNamespaceRepository(), digest)
 
 
-	resp, err := makeRequest("HEAD", url, nil, nil, username, password)
+	resp, err := makeRequest("HEAD", url, nil, nil, regOpts)
 	if err != nil {
 	if err != nil {
 		log.Printf("couldn't check for blob: %v", err)
 		log.Printf("couldn't check for blob: %v", err)
 		return false, err
 		return false, err
@@ -822,7 +828,7 @@ func checkBlobExistence(mp ModelPath, digest string, username string, password s
 	return resp.StatusCode == http.StatusOK, nil
 	return resp.StatusCode == http.StatusOK, nil
 }
 }
 
 
-func uploadBlob(location string, layer *Layer, username string, password string) error {
+func uploadBlob(location string, layer *Layer, regOpts *RegistryOptions) error {
 	// Create URL
 	// Create URL
 	url := fmt.Sprintf("%s&digest=%s", location, layer.Digest)
 	url := fmt.Sprintf("%s&digest=%s", location, layer.Digest)
 
 
@@ -845,7 +851,7 @@ func uploadBlob(location string, layer *Layer, username string, password string)
 		return err
 		return err
 	}
 	}
 
 
-	resp, err := makeRequest("PUT", url, headers, f, username, password)
+	resp, err := makeRequest("PUT", url, headers, f, regOpts)
 	if err != nil {
 	if err != nil {
 		log.Printf("couldn't upload blob: %v", err)
 		log.Printf("couldn't upload blob: %v", err)
 		return err
 		return err
@@ -861,7 +867,7 @@ func uploadBlob(location string, layer *Layer, username string, password string)
 	return nil
 	return nil
 }
 }
 
 
-func downloadBlob(mp ModelPath, digest string, username, password string, fn func(api.ProgressResponse)) error {
+func downloadBlob(mp ModelPath, digest string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
 	fp, err := GetBlobsPath(digest)
 	fp, err := GetBlobsPath(digest)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
@@ -890,12 +896,12 @@ func downloadBlob(mp ModelPath, digest string, username, password string, fn fun
 		size = fi.Size()
 		size = fi.Size()
 	}
 	}
 
 
-	url := fmt.Sprintf("%s://%s/v2/%s/blobs/%s", mp.ProtocolScheme, mp.Registry, mp.GetNamespaceRepository(), digest)
+	url := fmt.Sprintf("%s/v2/%s/blobs/%s", mp.Registry, mp.GetNamespaceRepository(), digest)
 	headers := map[string]string{
 	headers := map[string]string{
 		"Range": fmt.Sprintf("bytes=%d-", size),
 		"Range": fmt.Sprintf("bytes=%d-", size),
 	}
 	}
 
 
-	resp, err := makeRequest("GET", url, headers, nil, username, password)
+	resp, err := makeRequest("GET", url, headers, nil, regOpts)
 	if err != nil {
 	if err != nil {
 		log.Printf("couldn't download blob: %v", err)
 		log.Printf("couldn't download blob: %v", err)
 		return err
 		return err
@@ -959,7 +965,17 @@ func downloadBlob(mp ModelPath, digest string, username, password string, fn fun
 	return nil
 	return nil
 }
 }
 
 
-func makeRequest(method, url string, headers map[string]string, body io.Reader, username, password string) (*http.Response, error) {
+func makeRequest(method, url string, headers map[string]string, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {
+	if !strings.HasPrefix(url, "http") {
+		if regOpts.Insecure {
+			url = "http://" + url
+		} else {
+			url = "https://" + url
+		}
+	}
+
+	log.Printf("url = %s", url)
+
 	req, err := http.NewRequest(method, url, body)
 	req, err := http.NewRequest(method, url, body)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
@@ -970,8 +986,8 @@ func makeRequest(method, url string, headers map[string]string, body io.Reader,
 	}
 	}
 
 
 	// TODO: better auth
 	// TODO: better auth
-	if username != "" && password != "" {
-		req.SetBasicAuth(username, password)
+	if regOpts.Username != "" && regOpts.Password != "" {
+		req.SetBasicAuth(regOpts.Username, regOpts.Password)
 	}
 	}
 
 
 	client := &http.Client{
 	client := &http.Client{

+ 6 - 3
server/modelpath.go

@@ -70,10 +70,13 @@ func (mp ModelPath) GetFullTagname() string {
 }
 }
 
 
 func (mp ModelPath) GetShortTagname() string {
 func (mp ModelPath) GetShortTagname() string {
-	if mp.Registry == DefaultRegistry && mp.Namespace == DefaultNamespace {
-		return fmt.Sprintf("%s:%s", mp.Repository, mp.Tag)
+	if mp.Registry == DefaultRegistry {
+		if mp.Namespace == DefaultNamespace {
+			return fmt.Sprintf("%s:%s", mp.Repository, mp.Tag)
+		}
+		return fmt.Sprintf("%s/%s:%s", mp.Namespace, mp.Repository, mp.Tag)
 	}
 	}
-	return fmt.Sprintf("%s/%s:%s", mp.Namespace, mp.Repository, mp.Tag)
+	return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
 }
 }
 
 
 func (mp ModelPath) GetManifestPath(createDir bool) (string, error) {
 func (mp ModelPath) GetManifestPath(createDir bool) (string, error) {

+ 14 - 2
server/routes.go

@@ -92,7 +92,13 @@ func PullModelHandler(c *gin.Context) {
 			ch <- r
 			ch <- r
 		}
 		}
 
 
-		if err := PullModel(req.Name, req.Username, req.Password, fn); err != nil {
+		regOpts := &RegistryOptions{
+			Insecure: req.Insecure,
+			Username: req.Username,
+			Password: req.Password,
+		}
+
+		if err := PullModel(req.Name, regOpts, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
 			ch <- gin.H{"error": err.Error()}
 		}
 		}
 	}()
 	}()
@@ -114,7 +120,13 @@ func PushModelHandler(c *gin.Context) {
 			ch <- r
 			ch <- r
 		}
 		}
 
 
-		if err := PushModel(req.Name, req.Username, req.Password, fn); err != nil {
+		regOpts := &RegistryOptions{
+			Insecure: req.Insecure,
+			Username: req.Username,
+			Password: req.Password,
+		}
+
+		if err := PushModel(req.Name, regOpts, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
 			ch <- gin.H{"error": err.Error()}
 		}
 		}
 	}()
 	}()