瀏覽代碼

upload: separate progress tracking

Michael Yang 1 年之前
父節點
當前提交
c4bdfffd96
共有 1 個文件被更改,包括 21 次插入15 次删除
  1. 21 15
      server/upload.go

+ 21 - 15
server/upload.go

@@ -103,7 +103,7 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *Reg
 		}
 		}
 
 
 		// set part.N to the current number of parts
 		// set part.N to the current number of parts
-		b.Parts = append(b.Parts, blobUploadPart{blobUpload: b, N: len(b.Parts), Offset: offset, Size: size})
+		b.Parts = append(b.Parts, blobUploadPart{N: len(b.Parts), Offset: offset, Size: size})
 		offset += size
 		offset += size
 	}
 	}
 
 
@@ -215,9 +215,6 @@ func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
 }
 }
 
 
 func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *RegistryOptions) error {
 func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *RegistryOptions) error {
-	// reset the part here to ensure alignment
-	part.Reset()
-
 	headers := make(http.Header)
 	headers := make(http.Header)
 	headers.Set("Content-Type", "application/octet-stream")
 	headers.Set("Content-Type", "application/octet-stream")
 	headers.Set("Content-Length", fmt.Sprintf("%d", part.Size))
 	headers.Set("Content-Length", fmt.Sprintf("%d", part.Size))
@@ -227,10 +224,14 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
 		headers.Set("Content-Range", fmt.Sprintf("%d-%d", part.Offset, part.Offset+part.Size-1))
 		headers.Set("Content-Range", fmt.Sprintf("%d-%d", part.Offset, part.Offset+part.Size-1))
 	}
 	}
 
 
-	md5sum := md5.New()
 	sr := io.NewSectionReader(b.file, part.Offset, part.Size)
 	sr := io.NewSectionReader(b.file, part.Offset, part.Size)
-	resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(sr, io.MultiWriter(part, md5sum)), opts)
+
+	md5sum := md5.New()
+	w := &progressWriter{blobUpload: b}
+
+	resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(sr, io.MultiWriter(w, md5sum)), opts)
 	if err != nil {
 	if err != nil {
+		w.Rollback()
 		return err
 		return err
 	}
 	}
 	defer resp.Body.Close()
 	defer resp.Body.Close()
@@ -242,11 +243,13 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
 
 
 	nextURL, err := url.Parse(location)
 	nextURL, err := url.Parse(location)
 	if err != nil {
 	if err != nil {
+		w.Rollback()
 		return err
 		return err
 	}
 	}
 
 
 	switch {
 	switch {
 	case resp.StatusCode == http.StatusTemporaryRedirect:
 	case resp.StatusCode == http.StatusTemporaryRedirect:
+		w.Rollback()
 		b.nextURL <- nextURL
 		b.nextURL <- nextURL
 
 
 		redirectURL, err := resp.Location()
 		redirectURL, err := resp.Location()
@@ -275,6 +278,7 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
 		return fmt.Errorf("%w: %w", errMaxRetriesExceeded, err)
 		return fmt.Errorf("%w: %w", errMaxRetriesExceeded, err)
 
 
 	case resp.StatusCode == http.StatusUnauthorized:
 	case resp.StatusCode == http.StatusUnauthorized:
+		w.Rollback()
 		auth := resp.Header.Get("www-authenticate")
 		auth := resp.Header.Get("www-authenticate")
 		authRedir := ParseAuthRedirectString(auth)
 		authRedir := ParseAuthRedirectString(auth)
 		token, err := getAuthToken(ctx, authRedir)
 		token, err := getAuthToken(ctx, authRedir)
@@ -285,6 +289,7 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
 		opts.Token = token
 		opts.Token = token
 		fallthrough
 		fallthrough
 	case resp.StatusCode >= http.StatusBadRequest:
 	case resp.StatusCode >= http.StatusBadRequest:
+		w.Rollback()
 		body, err := io.ReadAll(resp.Body)
 		body, err := io.ReadAll(resp.Body)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
@@ -338,25 +343,26 @@ func (b *blobUpload) Wait(ctx context.Context, fn func(api.ProgressResponse)) er
 
 
 type blobUploadPart struct {
 type blobUploadPart struct {
 	// N is the part number
 	// N is the part number
-	N       int
-	Offset  int64
-	Size    int64
-	written int64
+	N      int
+	Offset int64
+	Size   int64
+	hash.Hash
+}
 
 
+type progressWriter struct {
+	written int64
 	*blobUpload
 	*blobUpload
-
-	hash.Hash
 }
 }
 
 
-func (p *blobUploadPart) Write(b []byte) (n int, err error) {
+func (p *progressWriter) Write(b []byte) (n int, err error) {
 	n = len(b)
 	n = len(b)
 	p.written += int64(n)
 	p.written += int64(n)
 	p.Completed.Add(int64(n))
 	p.Completed.Add(int64(n))
 	return n, nil
 	return n, nil
 }
 }
 
 
-func (p *blobUploadPart) Reset() {
-	p.Completed.Add(-int64(p.written))
+func (p *progressWriter) Rollback() {
+	p.Completed.Add(-p.written)
 	p.written = 0
 	p.written = 0
 }
 }