Explorar o código

Merge pull request #975 from jmorganca/mxyng/downloads

update downloads to use retry wrapper
Michael Yang hai 1 ano
pai
achega
1fd511e661
Modificáronse 5 ficheiros con 24 adicións e 39 borrados
  1. 1 1
      api/client.go
  2. 2 2
      server/auth.go
  3. 2 7
      server/download.go
  4. 11 17
      server/images.go
  5. 8 12
      server/upload.go

+ 1 - 1
api/client.go

@@ -72,7 +72,7 @@ func ClientFromEnvironment() (*Client, error) {
 		},
 	}
 
-	mockRequest, err := http.NewRequest("HEAD", client.base.String(), nil)
+	mockRequest, err := http.NewRequest(http.MethodHead, client.base.String(), nil)
 	if err != nil {
 		return nil, err
 	}

+ 2 - 2
server/auth.go

@@ -91,7 +91,7 @@ func getAuthToken(ctx context.Context, redirData AuthRedirect) (string, error) {
 	}
 
 	s := SignatureData{
-		Method: "GET",
+		Method: http.MethodGet,
 		Path:   redirectURL.String(),
 		Data:   nil,
 	}
@@ -103,7 +103,7 @@ func getAuthToken(ctx context.Context, redirData AuthRedirect) (string, error) {
 
 	headers := make(http.Header)
 	headers.Set("Authorization", sig)
-	resp, err := makeRequest(ctx, "GET", redirectURL, headers, nil, nil)
+	resp, err := makeRequest(ctx, http.MethodGet, redirectURL, headers, nil, nil)
 	if err != nil {
 		log.Printf("couldn't get token: %q", err)
 		return "", err

+ 2 - 7
server/download.go

@@ -89,17 +89,12 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *R
 	}
 
 	if len(b.Parts) == 0 {
-		resp, err := makeRequest(ctx, "HEAD", requestURL, nil, nil, opts)
+		resp, err := makeRequestWithRetry(ctx, http.MethodHead, requestURL, nil, nil, opts)
 		if err != nil {
 			return err
 		}
 		defer resp.Body.Close()
 
-		if resp.StatusCode >= http.StatusBadRequest {
-			body, _ := io.ReadAll(resp.Body)
-			return fmt.Errorf("registry responded with code %d: %v", resp.StatusCode, string(body))
-		}
-
 		b.Total, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
 
 		var size = b.Total / numDownloadParts
@@ -199,7 +194,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis
 func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *RegistryOptions) error {
 	headers := make(http.Header)
 	headers.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
-	resp, err := makeRequest(ctx, "GET", requestURL, headers, nil, opts)
+	resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts)
 	if err != nil {
 		return err
 	}

+ 11 - 17
server/images.go

@@ -1002,7 +1002,7 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
 
 	headers := make(http.Header)
 	headers.Set("Content-Type", "application/vnd.docker.distribution.manifest.v2+json")
-	resp, err := makeRequestWithRetry(ctx, "PUT", requestURL, headers, bytes.NewReader(manifestJSON), regOpts)
+	resp, err := makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, bytes.NewReader(manifestJSON), regOpts)
 	if err != nil {
 		return err
 	}
@@ -1124,22 +1124,12 @@ func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *RegistryOptio
 
 	headers := make(http.Header)
 	headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
-	resp, err := makeRequest(ctx, "GET", requestURL, headers, nil, regOpts)
+	resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, regOpts)
 	if err != nil {
-		log.Printf("couldn't get manifest: %v", err)
 		return nil, err
 	}
 	defer resp.Body.Close()
 
-	if resp.StatusCode >= http.StatusBadRequest {
-		if resp.StatusCode == http.StatusNotFound {
-			return nil, fmt.Errorf("model not found")
-		}
-
-		body, _ := io.ReadAll(resp.Body)
-		return nil, fmt.Errorf("on pull registry responded with code %d: %s", resp.StatusCode, body)
-	}
-
 	var m *ManifestV2
 	if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
 		return nil, err
@@ -1202,15 +1192,19 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
 
 			regOpts.Token = token
 			if body != nil {
-				if _, err := body.Seek(0, io.SeekStart); err != nil {
-					return nil, err
-				}
+				body.Seek(0, io.SeekStart)
 			}
 
 			continue
+		case resp.StatusCode == http.StatusNotFound:
+			return nil, os.ErrNotExist
 		case resp.StatusCode >= http.StatusBadRequest:
-			body, _ := io.ReadAll(resp.Body)
-			return nil, fmt.Errorf("on upload registry responded with code %d: %s", resp.StatusCode, body)
+			body, err := io.ReadAll(resp.Body)
+			if err != nil {
+				return nil, fmt.Errorf("%d: %s", resp.StatusCode, err)
+			}
+
+			return nil, fmt.Errorf("%d: %s", resp.StatusCode, body)
 		default:
 			return resp, nil
 		}

+ 8 - 12
server/upload.go

@@ -67,7 +67,7 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *Reg
 		requestURL.RawQuery = values.Encode()
 	}
 
-	resp, err := makeRequestWithRetry(ctx, "POST", requestURL, nil, nil, opts)
+	resp, err := makeRequestWithRetry(ctx, http.MethodPost, requestURL, nil, nil, opts)
 	if err != nil {
 		return err
 	}
@@ -187,7 +187,7 @@ func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
 	headers.Set("Content-Type", "application/octet-stream")
 	headers.Set("Content-Length", "0")
 
-	resp, err := makeRequestWithRetry(ctx, "PUT", requestURL, headers, nil, opts)
+	resp, err := makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, nil, opts)
 	if err != nil {
 		b.err = err
 		return
@@ -334,15 +334,13 @@ func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *RegistryO
 	requestURL := mp.BaseURL()
 	requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)
 
-	resp, err := makeRequest(ctx, "HEAD", requestURL, nil, nil, opts)
-	if err != nil {
+	resp, err := makeRequestWithRetry(ctx, http.MethodHead, requestURL, nil, nil, opts)
+	switch {
+	case errors.Is(err, os.ErrNotExist):
+	case err != nil:
 		return err
-	}
-	defer resp.Body.Close()
-
-	switch resp.StatusCode {
-	case http.StatusNotFound:
-	case http.StatusOK:
+	default:
+		defer resp.Body.Close()
 		fn(api.ProgressResponse{
 			Status:    fmt.Sprintf("uploading %s", layer.Digest),
 			Digest:    layer.Digest,
@@ -351,8 +349,6 @@ func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *RegistryO
 		})
 
 		return nil
-	default:
-		return fmt.Errorf("unexpected status code %d", resp.StatusCode)
 	}
 
 	data, ok := blobUploadManager.LoadOrStore(layer.Digest, &blobUpload{Layer: layer})