Quellcode durchsuchen

Merge pull request #375 from jmorganca/mxyng/fix-push

fix push manifest
Michael Yang vor 1 Jahr
Ursprung
Commit
cbf725a9ba
1 geänderte Dateien mit 44 neuen und 36 gelöschten Zeilen
  1. 44 36
      server/images.go

+ 44 - 36
server/images.go

@@ -105,9 +105,9 @@ type LayerReader struct {
 
 type ConfigV2 struct {
 	ModelFamily llm.ModelFamily `json:"model_family"`
-	ModelType   string      `json:"model_type"`
-	FileType    string      `json:"file_type"`
-	RootFS      RootFS      `json:"rootfs"`
+	ModelType   string          `json:"model_type"`
+	FileType    string          `json:"file_type"`
+	RootFS      RootFS          `json:"rootfs"`
 
 	// required by spec
 	Architecture string `json:"architecture"`
@@ -963,18 +963,12 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
 		return err
 	}
 
-	resp, err := makeRequest(ctx, "PUT", url, headers, bytes.NewReader(manifestJSON), regOpts)
+	resp, err := makeRequestWithRetry(ctx, "PUT", url, headers, bytes.NewReader(manifestJSON), regOpts)
 	if err != nil {
 		return err
 	}
 	defer resp.Body.Close()
 
-	// Check for success: For a successful upload, the Docker registry will respond with a 201 Created
-	if resp.StatusCode != http.StatusCreated {
-		body, _ := io.ReadAll(resp.Body)
-		return fmt.Errorf("on push registry responded with code %d: %v", resp.StatusCode, string(body))
-	}
-
 	fn(api.ProgressResponse{Status: "success"})
 
 	return nil
@@ -1116,43 +1110,18 @@ func GetSHA256Digest(r io.Reader) (string, int) {
 type requestContextKey string
 
 func startUpload(ctx context.Context, mp ModelPath, layer *Layer, regOpts *RegistryOptions) (string, error) {
-	retry, _ := ctx.Value(requestContextKey("retry")).(int)
-
 	url := fmt.Sprintf("%s/v2/%s/blobs/uploads/", mp.Registry, mp.GetNamespaceRepository())
 	if layer.From != "" {
 		url = fmt.Sprintf("%s/v2/%s/blobs/uploads/?mount=%s&from=%s", mp.Registry, mp.GetNamespaceRepository(), layer.Digest, layer.From)
 	}
 
-	resp, err := makeRequest(ctx, "POST", url, nil, nil, regOpts)
+	resp, err := makeRequestWithRetry(ctx, "POST", url, nil, nil, regOpts)
 	if err != nil {
 		log.Printf("couldn't start upload: %v", err)
 		return "", err
 	}
 	defer resp.Body.Close()
 
-	switch resp.StatusCode {
-	case http.StatusAccepted, http.StatusCreated:
-		// noop
-	case http.StatusUnauthorized:
-		if retry > MaxRetries {
-			return "", fmt.Errorf("max retries exceeded: %s", resp.Status)
-		}
-
-		auth := resp.Header.Get("www-authenticate")
-		authRedir := ParseAuthRedirectString(auth)
-		token, err := getAuthToken(ctx, authRedir, regOpts)
-		if err != nil {
-			return "", err
-		}
-
-		regOpts.Token = token
-		ctx = context.WithValue(ctx, requestContextKey("retry"), retry+1)
-		return startUpload(ctx, mp, layer, regOpts)
-	default:
-		body, _ := io.ReadAll(resp.Body)
-		return "", fmt.Errorf("on upload registry responded with code %d: %s", resp.StatusCode, body)
-	}
-
 	// Extract UUID location from header
 	location := resp.Header.Get("Location")
 	if location == "" {
@@ -1277,6 +1246,45 @@ func uploadBlobChunked(ctx context.Context, mp ModelPath, url string, layer *Lay
 	return nil
 }
 
+func makeRequestWithRetry(ctx context.Context, method, url string, headers map[string]string, body io.ReadSeeker, regOpts *RegistryOptions) (*http.Response, error) {
+	var status string
+	for try := 0; try < MaxRetries; try++ {
+		resp, err := makeRequest(ctx, method, url, headers, body, regOpts)
+		if err != nil {
+			log.Printf("couldn't start upload: %v", err)
+			return nil, err
+		}
+
+		status = resp.Status
+
+		switch resp.StatusCode {
+		case http.StatusAccepted, http.StatusCreated:
+			return resp, nil
+		case http.StatusUnauthorized:
+			auth := resp.Header.Get("www-authenticate")
+			authRedir := ParseAuthRedirectString(auth)
+			token, err := getAuthToken(ctx, authRedir, regOpts)
+			if err != nil {
+				return nil, err
+			}
+
+			regOpts.Token = token
+			if body != nil {
+				if _, err := body.Seek(0, io.SeekStart); err != nil {
+					return nil, err
+				}
+			}
+
+			continue
+		default:
+			body, _ := io.ReadAll(resp.Body)
+			return nil, fmt.Errorf("on upload registry responded with code %d: %s", resp.StatusCode, body)
+		}
+	}
+
+	return nil, fmt.Errorf("max retry exceeded: %v", status)
+}
+
 func makeRequest(ctx context.Context, method, url string, headers map[string]string, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {
 	if !strings.HasPrefix(url, "http") {
 		if regOpts.Insecure {