瀏覽代碼

revert checksum calculation to calculate-as-you-go

Michael Yang 1 年之前
父節點
當前提交
2799784ac8
共有 1 個文件被更改,包括 15 次插入16 次删除
  1. 15 16
      server/upload.go

+ 15 - 16
server/upload.go

@@ -5,6 +5,7 @@ import (
 	"crypto/md5"
 	"crypto/md5"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
+	"hash"
 	"io"
 	"io"
 	"log"
 	"log"
 	"math"
 	"math"
@@ -102,7 +103,7 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *Reg
 		}
 		}
 
 
 		// set part.N to the current number of parts
 		// set part.N to the current number of parts
-		b.Parts = append(b.Parts, blobUploadPart{blobUpload: b, N: len(b.Parts), Offset: offset, Size: size})
+		b.Parts = append(b.Parts, blobUploadPart{blobUpload: b, N: len(b.Parts), Offset: offset, Size: size, Hash: md5.New()})
 		offset += size
 		offset += size
 	}
 	}
 
 
@@ -147,14 +148,13 @@ func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
 			g.Go(func() error {
 			g.Go(func() error {
 				var err error
 				var err error
 				for try := 0; try < maxRetries; try++ {
 				for try := 0; try < maxRetries; try++ {
-					err = b.uploadChunk(inner, http.MethodPatch, requestURL, part, opts)
+					err = b.uploadPart(inner, http.MethodPatch, requestURL, part, opts)
 					switch {
 					switch {
 					case errors.Is(err, context.Canceled):
 					case errors.Is(err, context.Canceled):
 						return err
 						return err
 					case errors.Is(err, errMaxRetriesExceeded):
 					case errors.Is(err, errMaxRetriesExceeded):
 						return err
 						return err
 					case err != nil:
 					case err != nil:
-						part.Reset()
 						sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
 						sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
 						log.Printf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep)
 						log.Printf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep)
 						time.Sleep(sleep)
 						time.Sleep(sleep)
@@ -176,17 +176,10 @@ func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
 
 
 	requestURL := <-b.nextURL
 	requestURL := <-b.nextURL
 
 
-	var sb strings.Builder
-
 	// calculate md5 checksum and add it to the commit request
 	// calculate md5 checksum and add it to the commit request
+	var sb strings.Builder
 	for _, part := range b.Parts {
 	for _, part := range b.Parts {
-		hash := md5.New()
-		if _, err := io.Copy(hash, io.NewSectionReader(b.file, part.Offset, part.Size)); err != nil {
-			b.err = err
-			return
-		}
-
-		sb.Write(hash.Sum(nil))
+		sb.Write(part.Sum(nil))
 	}
 	}
 
 
 	md5sum := md5.Sum([]byte(sb.String()))
 	md5sum := md5.Sum([]byte(sb.String()))
@@ -221,7 +214,10 @@ func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
 	}
 	}
 }
 }
 
 
-func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *RegistryOptions) error {
+func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *RegistryOptions) error {
+	// reset the part here to ensure alignment
+	part.Reset()
+
 	headers := make(http.Header)
 	headers := make(http.Header)
 	headers.Set("Content-Type", "application/octet-stream")
 	headers.Set("Content-Type", "application/octet-stream")
 	headers.Set("Content-Length", fmt.Sprintf("%d", part.Size))
 	headers.Set("Content-Length", fmt.Sprintf("%d", part.Size))
@@ -232,7 +228,7 @@ func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL
 	}
 	}
 
 
 	sr := io.NewSectionReader(b.file, part.Offset, part.Size)
 	sr := io.NewSectionReader(b.file, part.Offset, part.Size)
-	resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(sr, part), opts)
+	resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(sr, io.MultiWriter(part, part.Hash)), opts)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -259,14 +255,13 @@ func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL
 
 
 		// retry uploading to the redirect URL
 		// retry uploading to the redirect URL
 		for try := 0; try < maxRetries; try++ {
 		for try := 0; try < maxRetries; try++ {
-			err = b.uploadChunk(ctx, http.MethodPut, redirectURL, part, nil)
+			err = b.uploadPart(ctx, http.MethodPut, redirectURL, part, nil)
 			switch {
 			switch {
 			case errors.Is(err, context.Canceled):
 			case errors.Is(err, context.Canceled):
 				return err
 				return err
 			case errors.Is(err, errMaxRetriesExceeded):
 			case errors.Is(err, errMaxRetriesExceeded):
 				return err
 				return err
 			case err != nil:
 			case err != nil:
-				part.Reset()
 				sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
 				sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
 				log.Printf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep)
 				log.Printf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep)
 				time.Sleep(sleep)
 				time.Sleep(sleep)
@@ -345,7 +340,10 @@ type blobUploadPart struct {
 	Offset  int64
 	Offset  int64
 	Size    int64
 	Size    int64
 	written int64
 	written int64
+
 	*blobUpload
 	*blobUpload
+
+	hash.Hash
 }
 }
 
 
 func (p *blobUploadPart) Write(b []byte) (n int, err error) {
 func (p *blobUploadPart) Write(b []byte) (n int, err error) {
@@ -356,6 +354,7 @@ func (p *blobUploadPart) Write(b []byte) (n int, err error) {
 }
 }
 
 
 func (p *blobUploadPart) Reset() {
 func (p *blobUploadPart) Reset() {
+	p.Hash.Reset()
 	p.Completed.Add(-int64(p.written))
 	p.Completed.Add(-int64(p.written))
 	p.written = 0
 	p.written = 0
 }
 }