Przeglądaj źródła

refactor part reset

Michael Yang 1 rok temu
rodzic
commit
84725ec7e3
1 zmienionych plików z 28 dodań i 27 usunięć
  1. 28 27
      server/upload.go

+ 28 - 27
server/upload.go

@@ -40,14 +40,6 @@ type blobUpload struct {
 	references atomic.Int32
 	references atomic.Int32
 }
 }
 
 
-type blobUploadPart struct {
-	// N is the part number
-	N      int
-	Offset int64
-	Size   int64
-	hash.Hash
-}
-
 const (
 const (
 	numUploadParts          = 64
 	numUploadParts          = 64
 	minUploadPartSize int64 = 95 * 1000 * 1000
 	minUploadPartSize int64 = 95 * 1000 * 1000
@@ -100,7 +92,7 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *Reg
 		}
 		}
 
 
 		// set part.N to the current number of parts
 		// set part.N to the current number of parts
-		b.Parts = append(b.Parts, blobUploadPart{N: len(b.Parts), Offset: offset, Size: size, Hash: md5.New()})
+		b.Parts = append(b.Parts, blobUploadPart{blobUpload: b, N: len(b.Parts), Offset: offset, Size: size})
 		offset += size
 		offset += size
 	}
 	}
 
 
@@ -144,8 +136,8 @@ func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
 		case requestURL := <-b.nextURL:
 		case requestURL := <-b.nextURL:
 			g.Go(func() error {
 			g.Go(func() error {
 				for try := 0; try < maxRetries; try++ {
 				for try := 0; try < maxRetries; try++ {
-					r := io.NewSectionReader(f, part.Offset, part.Size)
-					err := b.uploadChunk(inner, http.MethodPatch, requestURL, r, part, opts)
+					part.ReadSeeker = io.NewSectionReader(f, part.Offset, part.Size)
+					err := b.uploadChunk(inner, http.MethodPatch, requestURL, part, opts)
 					switch {
 					switch {
 					case errors.Is(err, context.Canceled):
 					case errors.Is(err, context.Canceled):
 						return err
 						return err
@@ -197,7 +189,9 @@ func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
 	b.done = true
 	b.done = true
 }
 }
 
 
-func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL *url.URL, rs io.ReadSeeker, part *blobUploadPart, opts *RegistryOptions) error {
+func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *RegistryOptions) error {
+	part.Reset()
+
 	headers := make(http.Header)
 	headers := make(http.Header)
 	headers.Set("Content-Type", "application/octet-stream")
 	headers.Set("Content-Type", "application/octet-stream")
 	headers.Set("Content-Length", fmt.Sprintf("%d", part.Size))
 	headers.Set("Content-Length", fmt.Sprintf("%d", part.Size))
@@ -207,8 +201,7 @@ func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL
 		headers.Set("Content-Range", fmt.Sprintf("%d-%d", part.Offset, part.Offset+part.Size-1))
 		headers.Set("Content-Range", fmt.Sprintf("%d-%d", part.Offset, part.Offset+part.Size-1))
 	}
 	}
 
 
-	buw := blobUploadWriter{blobUpload: b}
-	resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(rs, io.MultiWriter(&buw, part.Hash)), opts)
+	resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(part.ReadSeeker, io.MultiWriter(part, part.Hash)), opts)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -234,11 +227,7 @@ func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL
 		}
 		}
 
 
 		for try := 0; try < maxRetries; try++ {
 		for try := 0; try < maxRetries; try++ {
-			rs.Seek(0, io.SeekStart)
-			b.Completed.Add(-buw.written)
-			buw.written = 0
-			part.Hash = md5.New()
-			err := b.uploadChunk(ctx, http.MethodPut, redirectURL, rs, part, nil)
+			err := b.uploadChunk(ctx, http.MethodPut, redirectURL, part, nil)
 			switch {
 			switch {
 			case errors.Is(err, context.Canceled):
 			case errors.Is(err, context.Canceled):
 				return err
 				return err
@@ -270,9 +259,6 @@ func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL
 			return err
 			return err
 		}
 		}
 
 
-		rs.Seek(0, io.SeekStart)
-		b.Completed.Add(-buw.written)
-		buw.written = 0
 		return fmt.Errorf("http status %d %s: %s", resp.StatusCode, resp.Status, body)
 		return fmt.Errorf("http status %d %s: %s", resp.StatusCode, resp.Status, body)
 	}
 	}
 
 
@@ -318,18 +304,33 @@ func (b *blobUpload) Wait(ctx context.Context, fn func(api.ProgressResponse)) er
 	}
 	}
 }
 }
 
 
-type blobUploadWriter struct {
+type blobUploadPart struct {
+	// N is the part number
+	N      int
+	Offset int64
+	Size   int64
+	hash.Hash
+
 	written int64
 	written int64
+
+	io.ReadSeeker
 	*blobUpload
 	*blobUpload
 }
 }
 
 
-func (b *blobUploadWriter) Write(p []byte) (n int, err error) {
-	n = len(p)
-	b.written += int64(n)
-	b.Completed.Add(int64(n))
+func (p *blobUploadPart) Write(b []byte) (n int, err error) {
+	n = len(b)
+	p.written += int64(n)
+	p.Completed.Add(int64(n))
 	return n, nil
 	return n, nil
 }
 }
 
 
+func (p *blobUploadPart) Reset() {
+	p.Seek(0, io.SeekStart)
+	p.Completed.Add(-int64(p.written))
+	p.written = 0
+	p.Hash = md5.New()
+}
+
 func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *RegistryOptions, fn func(api.ProgressResponse)) error {
 func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *RegistryOptions, fn func(api.ProgressResponse)) error {
 	requestURL := mp.BaseURL()
 	requestURL := mp.BaseURL()
 	requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)
 	requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)