Browse Source

Merge pull request #1229 from jmorganca/mxyng/calculate-as-you-go

revert checksum calculation to calculate-as-you-go
Michael Yang 1 year ago
parent
commit
b56e92470a
1 changed files with 37 additions and 33 deletions
  1. 37 33
      server/upload.go

+ 37 - 33
server/upload.go

@@ -5,6 +5,7 @@ import (
 	"crypto/md5"
 	"errors"
 	"fmt"
+	"hash"
 	"io"
 	"log"
 	"math"
@@ -102,7 +103,7 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *Reg
 		}
 
 		// set part.N to the current number of parts
-		b.Parts = append(b.Parts, blobUploadPart{blobUpload: b, N: len(b.Parts), Offset: offset, Size: size})
+		b.Parts = append(b.Parts, blobUploadPart{N: len(b.Parts), Offset: offset, Size: size})
 		offset += size
 	}
 
@@ -147,14 +148,13 @@ func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
 			g.Go(func() error {
 				var err error
 				for try := 0; try < maxRetries; try++ {
-					err = b.uploadChunk(inner, http.MethodPatch, requestURL, part, opts)
+					err = b.uploadPart(inner, http.MethodPatch, requestURL, part, opts)
 					switch {
 					case errors.Is(err, context.Canceled):
 						return err
 					case errors.Is(err, errMaxRetriesExceeded):
 						return err
 					case err != nil:
-						part.Reset()
 						sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
 						log.Printf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep)
 						time.Sleep(sleep)
@@ -176,17 +176,10 @@ func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
 
 	requestURL := <-b.nextURL
 
-	var sb strings.Builder
-
 	// calculate md5 checksum and add it to the commit request
+	var sb strings.Builder
 	for _, part := range b.Parts {
-		hash := md5.New()
-		if _, err := io.Copy(hash, io.NewSectionReader(b.file, part.Offset, part.Size)); err != nil {
-			b.err = err
-			return
-		}
-
-		sb.Write(hash.Sum(nil))
+		sb.Write(part.Sum(nil))
 	}
 
 	md5sum := md5.Sum([]byte(sb.String()))
@@ -201,27 +194,25 @@ func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
 	headers.Set("Content-Length", "0")
 
 	for try := 0; try < maxRetries; try++ {
-		resp, err := makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, nil, opts)
-		if err != nil {
-			b.err = err
-			if errors.Is(err, context.Canceled) {
-				return
-			}
-
+		var resp *http.Response
+		resp, err = makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, nil, opts)
+		if errors.Is(err, context.Canceled) {
+			break
+		} else if err != nil {
 			sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
 			log.Printf("%s complete upload attempt %d failed: %v, retrying in %s", b.Digest[7:19], try, err, sleep)
 			time.Sleep(sleep)
 			continue
 		}
 		defer resp.Body.Close()
-
-		b.err = nil
-		b.done = true
-		return
+		break
 	}
+
+	b.err = err
+	b.done = true
 }
 
-func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *RegistryOptions) error {
+func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *RegistryOptions) error {
 	headers := make(http.Header)
 	headers.Set("Content-Type", "application/octet-stream")
 	headers.Set("Content-Length", fmt.Sprintf("%d", part.Size))
@@ -232,8 +223,13 @@ func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL
 	}
 
 	sr := io.NewSectionReader(b.file, part.Offset, part.Size)
-	resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(sr, part), opts)
+
+	md5sum := md5.New()
+	w := &progressWriter{blobUpload: b}
+
+	resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(sr, io.MultiWriter(w, md5sum)), opts)
 	if err != nil {
+		w.Rollback()
 		return err
 	}
 	defer resp.Body.Close()
@@ -245,11 +241,13 @@ func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL
 
 	nextURL, err := url.Parse(location)
 	if err != nil {
+		w.Rollback()
 		return err
 	}
 
 	switch {
 	case resp.StatusCode == http.StatusTemporaryRedirect:
+		w.Rollback()
 		b.nextURL <- nextURL
 
 		redirectURL, err := resp.Location()
@@ -259,14 +257,13 @@ func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL
 
 		// retry uploading to the redirect URL
 		for try := 0; try < maxRetries; try++ {
-			err = b.uploadChunk(ctx, http.MethodPut, redirectURL, part, nil)
+			err = b.uploadPart(ctx, http.MethodPut, redirectURL, part, nil)
 			switch {
 			case errors.Is(err, context.Canceled):
 				return err
 			case errors.Is(err, errMaxRetriesExceeded):
 				return err
 			case err != nil:
-				part.Reset()
 				sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
 				log.Printf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep)
 				time.Sleep(sleep)
@@ -279,6 +276,7 @@ func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL
 		return fmt.Errorf("%w: %w", errMaxRetriesExceeded, err)
 
 	case resp.StatusCode == http.StatusUnauthorized:
+		w.Rollback()
 		auth := resp.Header.Get("www-authenticate")
 		authRedir := ParseAuthRedirectString(auth)
 		token, err := getAuthToken(ctx, authRedir)
@@ -289,6 +287,7 @@ func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL
 		opts.Token = token
 		fallthrough
 	case resp.StatusCode >= http.StatusBadRequest:
+		w.Rollback()
 		body, err := io.ReadAll(resp.Body)
 		if err != nil {
 			return err
@@ -301,6 +300,7 @@ func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL
 		b.nextURL <- nextURL
 	}
 
+	part.Hash = md5sum
 	return nil
 }
 
@@ -341,22 +341,26 @@ func (b *blobUpload) Wait(ctx context.Context, fn func(api.ProgressResponse)) er
 
 type blobUploadPart struct {
 	// N is the part number
-	N       int
-	Offset  int64
-	Size    int64
+	N      int
+	Offset int64
+	Size   int64
+	hash.Hash
+}
+
+type progressWriter struct {
 	written int64
 	*blobUpload
 }
 
-func (p *blobUploadPart) Write(b []byte) (n int, err error) {
+func (p *progressWriter) Write(b []byte) (n int, err error) {
 	n = len(b)
 	p.written += int64(n)
 	p.Completed.Add(int64(n))
 	return n, nil
 }
 
-func (p *blobUploadPart) Reset() {
-	p.Completed.Add(-int64(p.written))
+func (p *progressWriter) Rollback() {
+	p.Completed.Add(-p.written)
 	p.written = 0
 }