Michael Yang vor 1 Jahr
Ursprung
Commit
2cc634689b
4 geänderte Dateien mit 80 neuen und 55 gelöschten Zeilen
  1. 25 18
      server/auth.go
  2. 3 2
      server/download.go
  3. 39 30
      server/images.go
  4. 13 5
      server/modelpath.go

+ 25 - 18
server/auth.go

@@ -12,8 +12,10 @@ import (
 	"io"
 	"log"
 	"net/http"
+	"net/url"
 	"os"
 	"path"
+	"strconv"
 	"strings"
 	"time"
 
@@ -43,21 +45,34 @@ func generateNonce(length int) (string, error) {
 	return base64.RawURLEncoding.EncodeToString(nonce), nil
 }
 
-func (r AuthRedirect) URL() (string, error) {
-	nonce, err := generateNonce(16)
+func (r AuthRedirect) URL() (*url.URL, error) {
+	redirectURL, err := url.Parse(r.Realm)
 	if err != nil {
-		return "", err
+		return nil, err
 	}
-	scopes := []string{}
+
+	values := redirectURL.Query()
+
+	values.Add("service", r.Service)
+
 	for _, s := range strings.Split(r.Scope, " ") {
-		scopes = append(scopes, fmt.Sprintf("scope=%s", s))
+		values.Add("scope", s)
 	}
-	scopeStr := strings.Join(scopes, "&")
-	return fmt.Sprintf("%s?service=%s&%s&ts=%d&nonce=%s", r.Realm, r.Service, scopeStr, time.Now().Unix(), nonce), nil
+
+	values.Add("ts", strconv.FormatInt(time.Now().Unix(), 10))
+
+	nonce, err := generateNonce(16)
+	if err != nil {
+		return nil, err
+	}
+	values.Add("nonce", nonce)
+
+	redirectURL.RawQuery = values.Encode()
+	return redirectURL, nil
 }
 
 func getAuthToken(ctx context.Context, redirData AuthRedirect, regOpts *RegistryOptions) (string, error) {
-	url, err := redirData.URL()
+	redirectURL, err := redirData.URL()
 	if err != nil {
 		return "", err
 	}
@@ -77,18 +92,10 @@ func getAuthToken(ctx context.Context, redirData AuthRedirect, regOpts *Registry
 
 	s := SignatureData{
 		Method: "GET",
-		Path:   url,
+		Path:   redirectURL.String(),
 		Data:   nil,
 	}
 
-	if !strings.HasPrefix(s.Path, "http") {
-		if regOpts.Insecure {
-			s.Path = "http://" + url
-		} else {
-			s.Path = "https://" + url
-		}
-	}
-
 	sig, err := s.Sign(rawKey)
 	if err != nil {
 		return "", err
@@ -96,7 +103,7 @@ func getAuthToken(ctx context.Context, redirData AuthRedirect, regOpts *Registry
 
 	headers := make(http.Header)
 	headers.Set("Authorization", sig)
-	resp, err := makeRequest(ctx, "GET", url, headers, nil, regOpts)
+	resp, err := makeRequest(ctx, "GET", redirectURL, headers, nil, regOpts)
 	if err != nil {
 		log.Printf("couldn't get token: %q", err)
 	}

+ 3 - 2
server/download.go

@@ -155,12 +155,13 @@ func doDownload(ctx context.Context, opts downloadOpts, f *FileDownload) error {
 		}
 	}
 
-	url := fmt.Sprintf("%s/v2/%s/blobs/%s", opts.mp.Registry, opts.mp.GetNamespaceRepository(), f.Digest)
+	requestURL := opts.mp.BaseURL()
+	requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", f.Digest)
 
 	headers := make(http.Header)
 	headers.Set("Range", fmt.Sprintf("bytes=%d-", size))
 
-	resp, err := makeRequest(ctx, "GET", url, headers, nil, opts.regOpts)
+	resp, err := makeRequest(ctx, "GET", requestURL, headers, nil, opts.regOpts)
 	if err != nil {
 		log.Printf("couldn't download blob: %v", err)
 		return fmt.Errorf("%w: %w", errDownload, err)

+ 39 - 30
server/images.go

@@ -12,6 +12,7 @@ import (
 	"io"
 	"log"
 	"net/http"
+	"net/url"
 	"os"
 	"path"
 	"path/filepath"
@@ -961,8 +962,8 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
 			return err
 		}
 
-		if strings.HasPrefix(path.Base(location), "sha256:") {
-			layer.Digest = path.Base(location)
+		if strings.HasPrefix(path.Base(location.Path), "sha256:") {
+			layer.Digest = path.Base(location.Path)
 			fn(api.ProgressResponse{
 				Status:    "using existing layer",
 				Digest:    layer.Digest,
@@ -979,7 +980,8 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
 	}
 
 	fn(api.ProgressResponse{Status: "pushing manifest"})
-	url := fmt.Sprintf("%s/v2/%s/manifests/%s", mp.Registry, mp.GetNamespaceRepository(), mp.Tag)
+	requestURL := mp.BaseURL()
+	requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
 
 	manifestJSON, err := json.Marshal(manifest)
 	if err != nil {
@@ -988,7 +990,7 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
 
 	headers := make(http.Header)
 	headers.Set("Content-Type", "application/vnd.docker.distribution.manifest.v2+json")
-	resp, err := makeRequestWithRetry(ctx, "PUT", url, headers, bytes.NewReader(manifestJSON), regOpts)
+	resp, err := makeRequestWithRetry(ctx, "PUT", requestURL, headers, bytes.NewReader(manifestJSON), regOpts)
 	if err != nil {
 		return err
 	}
@@ -1072,11 +1074,11 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
 }
 
 func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *RegistryOptions) (*ManifestV2, error) {
-	url := fmt.Sprintf("%s/v2/%s/manifests/%s", mp.Registry, mp.GetNamespaceRepository(), mp.Tag)
+	requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
 
 	headers := make(http.Header)
 	headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
-	resp, err := makeRequest(ctx, "GET", url, headers, nil, regOpts)
+	resp, err := makeRequest(ctx, "GET", requestURL, headers, nil, regOpts)
 	if err != nil {
 		log.Printf("couldn't get manifest: %v", err)
 		return nil, err
@@ -1137,33 +1139,38 @@ func GetSHA256Digest(r io.Reader) (string, int) {
 
 type requestContextKey string
 
-func startUpload(ctx context.Context, mp ModelPath, layer *Layer, regOpts *RegistryOptions) (string, error) {
-	url := fmt.Sprintf("%s/v2/%s/blobs/uploads/", mp.Registry, mp.GetNamespaceRepository())
+func startUpload(ctx context.Context, mp ModelPath, layer *Layer, regOpts *RegistryOptions) (*url.URL, error) {
+	requestURL := mp.BaseURL()
+	requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs/uploads/")
 	if layer.From != "" {
-		url = fmt.Sprintf("%s/v2/%s/blobs/uploads/?mount=%s&from=%s", mp.Registry, mp.GetNamespaceRepository(), layer.Digest, layer.From)
+		values := requestURL.Query()
+		values.Add("mount", layer.Digest)
+		values.Add("from", layer.From)
+		requestURL.RawQuery = values.Encode()
 	}
 
-	resp, err := makeRequestWithRetry(ctx, "POST", url, nil, nil, regOpts)
+	resp, err := makeRequestWithRetry(ctx, "POST", requestURL, nil, nil, regOpts)
 	if err != nil {
 		log.Printf("couldn't start upload: %v", err)
-		return "", err
+		return nil, err
 	}
 	defer resp.Body.Close()
 
 	// Extract UUID location from header
 	location := resp.Header.Get("Location")
 	if location == "" {
-		return "", fmt.Errorf("location header is missing in response")
+		return nil, fmt.Errorf("location header is missing in response")
 	}
 
-	return location, nil
+	return url.Parse(location)
 }
 
 // Function to check if a blob already exists in the Docker registry
 func checkBlobExistence(ctx context.Context, mp ModelPath, digest string, regOpts *RegistryOptions) (bool, error) {
-	url := fmt.Sprintf("%s/v2/%s/blobs/%s", mp.Registry, mp.GetNamespaceRepository(), digest)
+	requestURL := mp.BaseURL()
+	requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", digest)
 
-	resp, err := makeRequest(ctx, "HEAD", url, nil, nil, regOpts)
+	resp, err := makeRequest(ctx, "HEAD", requestURL, nil, nil, regOpts)
 	if err != nil {
 		log.Printf("couldn't check for blob: %v", err)
 		return false, err
@@ -1174,7 +1181,7 @@ func checkBlobExistence(ctx context.Context, mp ModelPath, digest string, regOpt
 	return resp.StatusCode == http.StatusOK, nil
 }
 
-func uploadBlobChunked(ctx context.Context, mp ModelPath, url string, layer *Layer, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
+func uploadBlobChunked(ctx context.Context, mp ModelPath, requestURL *url.URL, layer *Layer, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
 	// TODO allow resumability
 	// TODO allow canceling uploads via DELETE
 
@@ -1204,7 +1211,7 @@ func uploadBlobChunked(ctx context.Context, mp ModelPath, url string, layer *Lay
 		headers.Set("Content-Type", "application/octet-stream")
 		headers.Set("Content-Length", strconv.Itoa(int(chunk)))
 		headers.Set("Content-Range", fmt.Sprintf("%d-%d", completed, completed+sectionReader.Size()-1))
-		resp, err := makeRequestWithRetry(ctx, "PATCH", url, headers, sectionReader, regOpts)
+		resp, err := makeRequestWithRetry(ctx, "PATCH", requestURL, headers, sectionReader, regOpts)
 		if err != nil && !errors.Is(err, io.EOF) {
 			fn(api.ProgressResponse{
 				Status:    fmt.Sprintf("error uploading chunk: %v", err),
@@ -1225,20 +1232,26 @@ func uploadBlobChunked(ctx context.Context, mp ModelPath, url string, layer *Lay
 			Completed: int(completed),
 		})
 
-		url = resp.Header.Get("Location")
+		requestURL, err = url.Parse(resp.Header.Get("Location"))
+		if err != nil {
+			return err
+		}
+
 		if completed >= int64(layer.Size) {
 			break
 		}
 	}
 
-	url = fmt.Sprintf("%s&digest=%s", url, layer.Digest)
+	values := requestURL.Query()
+	values.Add("digest", layer.Digest)
+	requestURL.RawQuery = values.Encode()
 
 	headers := make(http.Header)
 	headers.Set("Content-Type", "application/octet-stream")
 	headers.Set("Content-Length", "0")
 
 	// finish the upload
-	resp, err := makeRequest(ctx, "PUT", url, headers, nil, regOpts)
+	resp, err := makeRequest(ctx, "PUT", requestURL, headers, nil, regOpts)
 	if err != nil {
 		log.Printf("couldn't finish upload: %v", err)
 		return err
@@ -1252,10 +1265,10 @@ func uploadBlobChunked(ctx context.Context, mp ModelPath, url string, layer *Lay
 	return nil
 }
 
-func makeRequestWithRetry(ctx context.Context, method, url string, headers http.Header, body io.ReadSeeker, regOpts *RegistryOptions) (*http.Response, error) {
+func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *RegistryOptions) (*http.Response, error) {
 	var status string
 	for try := 0; try < MaxRetries; try++ {
-		resp, err := makeRequest(ctx, method, url, headers, body, regOpts)
+		resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
 		if err != nil {
 			log.Printf("couldn't start upload: %v", err)
 			return nil, err
@@ -1291,16 +1304,12 @@ func makeRequestWithRetry(ctx context.Context, method, url string, headers http.
 	return nil, fmt.Errorf("max retry exceeded: %v", status)
 }
 
-func makeRequest(ctx context.Context, method, url string, headers http.Header, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {
-	if !strings.HasPrefix(url, "http") {
-		if regOpts.Insecure {
-			url = "http://" + url
-		} else {
-			url = "https://" + url
-		}
+func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {
+	if requestURL.Scheme != "http" && regOpts.Insecure {
+		requestURL.Scheme = "http"
 	}
 
-	req, err := http.NewRequestWithContext(ctx, method, url, body)
+	req, err := http.NewRequestWithContext(ctx, method, requestURL.String(), body)
 	if err != nil {
 		return nil, err
 	}

+ 13 - 5
server/modelpath.go

@@ -3,6 +3,7 @@ package server
 import (
 	"errors"
 	"fmt"
+	"net/url"
 	"os"
 	"path/filepath"
 	"runtime"
@@ -39,13 +40,13 @@ func ParseModelPath(name string) ModelPath {
 		Tag:            DefaultTag,
 	}
 
-	parts := strings.Split(name, "://")
-	if len(parts) > 1 {
-		mp.ProtocolScheme = parts[0]
-		name = parts[1]
+	before, after, found := strings.Cut(name, "://")
+	if found {
+		mp.ProtocolScheme = before
+		name = after
 	}
 
-	parts = strings.Split(name, "/")
+	parts := strings.Split(name, "/")
 	switch len(parts) {
 	case 3:
 		mp.Registry = parts[0]
@@ -100,6 +101,13 @@ func (mp ModelPath) GetManifestPath(createDir bool) (string, error) {
 	return path, nil
 }
 
+func (mp ModelPath) BaseURL() *url.URL {
+	return &url.URL{
+		Scheme: mp.ProtocolScheme,
+		Host:   mp.Registry,
+	}
+}
+
 func GetManifestPath() (string, error) {
 	home, err := os.UserHomeDir()
 	if err != nil {