瀏覽代碼

replace done channel with file check

Michael Yang 1 年之前
父節點
當前提交
10199c5987
共有 1 個文件被更改,包括 24 次插入29 次删除
  1. 24 29
      server/download.go

+ 24 - 29
server/download.go

@@ -31,10 +31,8 @@ type blobDownload struct {
 	Total     int64
 	Completed atomic.Int64
 
-	*os.File
 	Parts []*blobDownloadPart
 
-	done chan struct{}
 	context.CancelFunc
 	references atomic.Int32
 }
@@ -54,6 +52,14 @@ func (p *blobDownloadPart) Name() string {
 	}, "-")
 }
 
+func (p *blobDownloadPart) StartsAt() int64 {
+	return p.Offset + p.Completed
+}
+
+func (p *blobDownloadPart) StopsAt() int64 {
+	return p.Offset + p.Size
+}
+
 func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
 	partFilePaths, err := filepath.Glob(b.Name + "-partial-*")
 	if err != nil {
@@ -110,18 +116,16 @@ func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *Regis
 
 	ctx, b.CancelFunc = context.WithCancel(ctx)
 
-	b.File, err = os.OpenFile(b.Name+"-partial", os.O_CREATE|os.O_RDWR, 0644)
+	file, err := os.OpenFile(b.Name+"-partial", os.O_CREATE|os.O_RDWR, 0644)
 	if err != nil {
 		return err
 	}
-	defer b.Close()
+	defer file.Close()
 
-	b.Truncate(b.Total)
-
-	b.done = make(chan struct{}, 1)
-	defer close(b.done)
+	file.Truncate(b.Total)
 
 	g, ctx := errgroup.WithContext(ctx)
+	// TODO(mxyng): download concurrency should be configurable
 	g.SetLimit(64)
 	for i := range b.Parts {
 		part := b.Parts[i]
@@ -132,7 +136,8 @@ func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *Regis
 		i := i
 		g.Go(func() error {
 			for try := 0; try < maxRetries; try++ {
-				err := b.downloadChunk(ctx, requestURL, i, opts)
+				w := io.NewOffsetWriter(file, part.StartsAt())
+				err := b.downloadChunk(ctx, requestURL, w, part, opts)
 				switch {
 				case errors.Is(err, context.Canceled):
 					return err
@@ -152,31 +157,23 @@ func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *Regis
 		return err
 	}
 
-	if err := b.Close(); err != nil {
+	// explicitly close the file so we can rename it
+	if err := file.Close(); err != nil {
 		return err
 	}
 
 	for i := range b.Parts {
-		if err := os.Remove(b.File.Name() + "-" + strconv.Itoa(i)); err != nil {
+		if err := os.Remove(file.Name() + "-" + strconv.Itoa(i)); err != nil {
 			return err
 		}
 	}
 
-	if err := os.Rename(b.File.Name(), b.Name); err != nil {
-		return err
-	}
-
-	return nil
+	return os.Rename(file.Name(), b.Name)
 }
 
-func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, i int, opts *RegistryOptions) error {
-	part := b.Parts[i]
-
-	offset := part.Offset + part.Completed
-	w := io.NewOffsetWriter(b.File, offset)
-
+func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *RegistryOptions) error {
 	headers := make(http.Header)
-	headers.Set("Range", fmt.Sprintf("bytes=%d-%d", offset, part.Offset+part.Size-1))
+	headers.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
 	resp, err := makeRequest(ctx, "GET", requestURL, headers, nil, opts)
 	if err != nil {
 		return err
@@ -258,10 +255,6 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
 	ticker := time.NewTicker(60 * time.Millisecond)
 	for {
 		select {
-		case <-b.done:
-			if b.Completed.Load() != b.Total {
-				return io.ErrUnexpectedEOF
-			}
 		case <-ticker.C:
 		case <-ctx.Done():
 			return ctx.Err()
@@ -275,8 +268,10 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
 		})
 
 		if b.Completed.Load() >= b.Total {
-			<-b.done
-			return nil
+			// wait for the file to get renamed
+			if _, err := os.Stat(b.Name); err == nil {
+				return nil
+			}
 		}
 	}
 }