Explorar o código

implement ProgressWriter

Michael Yang hai 1 ano
pai
achega
f0b398d17f
Modificáronse 1 ficheiros con 47 adicións e 38 borrados
  1. 47 38
      server/upload.go

+ 47 - 38
server/upload.go

@@ -57,6 +57,12 @@ func uploadBlobChunked(ctx context.Context, requestURL *url.URL, layer *Layer, r
 
 	// 95MB chunk size
 	chunkSize := 95 * 1024 * 1024
+	pw := ProgressWriter{
+		status: fmt.Sprintf("uploading %s", layer.Digest),
+		digest: layer.Digest,
+		total:  layer.Size,
+		fn:     fn,
+	}
 
 	for offset := int64(0); offset < int64(layer.Size); {
 		chunk := int64(layer.Size) - offset
@@ -65,48 +71,16 @@ func uploadBlobChunked(ctx context.Context, requestURL *url.URL, layer *Layer, r
 		}
 
 		sectionReader := io.NewSectionReader(f, int64(offset), chunk)
+
+		var errStatus error
 		for try := 0; try < MaxRetries; try++ {
-			ch := make(chan error, 1)
-
-			r, w := io.Pipe()
-			defer r.Close()
-			go func() {
-				defer w.Close()
-
-				for chunked := int64(0); chunked < chunk; {
-					select {
-					case err := <-ch:
-						log.Printf("chunk interrupted: %v", err)
-						return
-					default:
-						n, err := io.CopyN(w, sectionReader, 1024*1024)
-						if err != nil && !errors.Is(err, io.EOF) {
-							fn(api.ProgressResponse{
-								Status:    fmt.Sprintf("error reading chunk: %v", err),
-								Digest:    layer.Digest,
-								Total:     layer.Size,
-								Completed: int(offset),
-							})
-
-							return
-						}
-
-						chunked += n
-						fn(api.ProgressResponse{
-							Status:    fmt.Sprintf("uploading %s", layer.Digest),
-							Digest:    layer.Digest,
-							Total:     layer.Size,
-							Completed: int(offset) + int(chunked),
-						})
-					}
-				}
-			}()
+			errStatus = nil
 
 			headers := make(http.Header)
 			headers.Set("Content-Type", "application/octet-stream")
 			headers.Set("Content-Length", strconv.Itoa(int(chunk)))
 			headers.Set("Content-Range", fmt.Sprintf("%d-%d", offset, offset+sectionReader.Size()-1))
-			resp, err := makeRequest(ctx, "PATCH", requestURL, headers, r, regOpts)
+			resp, err := makeRequest(ctx, "PATCH", requestURL, headers, io.TeeReader(sectionReader, &pw), regOpts)
 			if err != nil && !errors.Is(err, io.EOF) {
 				fn(api.ProgressResponse{
 					Status:    fmt.Sprintf("error uploading chunk: %v", err),
@@ -121,7 +95,7 @@ func uploadBlobChunked(ctx context.Context, requestURL *url.URL, layer *Layer, r
 
 			switch {
 			case resp.StatusCode == http.StatusUnauthorized:
-				ch <- errors.New("unauthorized")
+				errStatus = errors.New("unauthorized")
 
 				auth := resp.Header.Get("www-authenticate")
 				authRedir := ParseAuthRedirectString(auth)
@@ -131,7 +105,9 @@ func uploadBlobChunked(ctx context.Context, requestURL *url.URL, layer *Layer, r
 				}
 
 				regOpts.Token = token
-				sectionReader = io.NewSectionReader(f, int64(offset), chunk)
+
+				pw.completed = int(offset)
+				sectionReader = io.NewSectionReader(f, offset, chunk)
 				continue
 			case resp.StatusCode >= http.StatusBadRequest:
 				body, _ := io.ReadAll(resp.Body)
@@ -146,6 +122,10 @@ func uploadBlobChunked(ctx context.Context, requestURL *url.URL, layer *Layer, r
 
 			break
 		}
+
+		if errStatus != nil {
+			return fmt.Errorf("max retries exceeded: %w", errStatus)
+		}
 	}
 
 	values := requestURL.Query()
@@ -170,3 +150,32 @@ func uploadBlobChunked(ctx context.Context, requestURL *url.URL, layer *Layer, r
 	}
 	return nil
 }
+
+type ProgressWriter struct {
+	status    string
+	digest    string
+	bucket    int
+	completed int
+	total     int
+	fn        func(api.ProgressResponse)
+}
+
+func (pw *ProgressWriter) Write(b []byte) (int, error) {
+	n := len(b)
+	pw.bucket += n
+	pw.completed += n
+
+	// throttle status updates to not spam the client
+	if pw.bucket >= 1024*1024 || pw.completed >= pw.total {
+		pw.fn(api.ProgressResponse{
+			Status:    pw.status,
+			Digest:    pw.digest,
+			Total:     pw.total,
+			Completed: pw.completed,
+		})
+
+		pw.bucket = 0
+	}
+
+	return n, nil
+}