浏览代码

add maximum retries when pushing (#334)

Patrick Devine 1 年之前
父节点
当前提交
d9cf18e28d
共有 4 个文件被更改,包括 35 次插入21 次删除
  1. 3 2
      server/auth.go
  2. 1 1
      server/download.go
  3. 29 17
      server/images.go
  4. 2 1
      server/routes.go

+ 3 - 2
server/auth.go

@@ -2,6 +2,7 @@ package server
 
 import (
 	"bytes"
+	"context"
 	"crypto/rand"
 	"crypto/sha256"
 	"encoding/base64"
@@ -50,7 +51,7 @@ func (r AuthRedirect) URL() (string, error) {
 	return fmt.Sprintf("%s?service=%s&scope=%s&ts=%d&nonce=%s", r.Realm, r.Service, r.Scope, time.Now().Unix(), nonce), nil
 }
 
-func getAuthToken(redirData AuthRedirect, regOpts *RegistryOptions) (string, error) {
+func getAuthToken(ctx context.Context, redirData AuthRedirect, regOpts *RegistryOptions) (string, error) {
 	url, err := redirData.URL()
 	if err != nil {
 		return "", err
@@ -92,7 +93,7 @@ func getAuthToken(redirData AuthRedirect, regOpts *RegistryOptions) (string, err
 		"Authorization": sig,
 	}
 
-	resp, err := makeRequest("GET", url, headers, nil, regOpts)
+	resp, err := makeRequest(ctx, "GET", url, headers, nil, regOpts)
 	if err != nil {
 		log.Printf("couldn't get token: %q", err)
 	}

+ 1 - 1
server/download.go

@@ -137,7 +137,7 @@ func doDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions, f *
 		"Range": fmt.Sprintf("bytes=%d-", size),
 	}
 
-	resp, err := makeRequest("GET", url, headers, nil, regOpts)
+	resp, err := makeRequest(ctx, "GET", url, headers, nil, regOpts)
 	if err != nil {
 		log.Printf("couldn't download blob: %v", err)
 		return err

+ 29 - 17
server/images.go

@@ -24,6 +24,8 @@ import (
 	"github.com/jmorganca/ollama/vector"
 )
 
+const MaxRetries = 3
+
 type RegistryOptions struct {
 	Insecure bool
 	Username string
@@ -856,7 +858,7 @@ func DeleteModel(name string) error {
 	return nil
 }
 
-func PushModel(name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
+func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
 	mp := ParseModelPath(name)
 
 	fn(api.ProgressResponse{Status: "retrieving manifest"})
@@ -872,7 +874,7 @@ func PushModel(name string, regOpts *RegistryOptions, fn func(api.ProgressRespon
 	layers = append(layers, &manifest.Config)
 
 	for _, layer := range layers {
-		exists, err := checkBlobExistence(mp, layer.Digest, regOpts)
+		exists, err := checkBlobExistence(ctx, mp, layer.Digest, regOpts)
 		if err != nil {
 			return err
 		}
@@ -894,13 +896,13 @@ func PushModel(name string, regOpts *RegistryOptions, fn func(api.ProgressRespon
 			Total:  layer.Size,
 		})
 
-		location, err := startUpload(mp, regOpts)
+		location, err := startUpload(ctx, mp, regOpts)
 		if err != nil {
 			log.Printf("couldn't start upload: %v", err)
 			return err
 		}
 
-		err = uploadBlobChunked(mp, location, layer, regOpts, fn)
+		err = uploadBlobChunked(ctx, mp, location, layer, regOpts, fn)
 		if err != nil {
 			log.Printf("error uploading blob: %v", err)
 			return err
@@ -918,7 +920,7 @@ func PushModel(name string, regOpts *RegistryOptions, fn func(api.ProgressRespon
 		return err
 	}
 
-	resp, err := makeRequest("PUT", url, headers, bytes.NewReader(manifestJSON), regOpts)
+	resp, err := makeRequest(ctx, "PUT", url, headers, bytes.NewReader(manifestJSON), regOpts)
 	if err != nil {
 		return err
 	}
@@ -940,7 +942,7 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
 
 	fn(api.ProgressResponse{Status: "pulling manifest"})
 
-	manifest, err := pullModelManifest(mp, regOpts)
+	manifest, err := pullModelManifest(ctx, mp, regOpts)
 	if err != nil {
 		return fmt.Errorf("pull model manifest: %s", err)
 	}
@@ -996,13 +998,13 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
 	return nil
 }
 
-func pullModelManifest(mp ModelPath, regOpts *RegistryOptions) (*ManifestV2, error) {
+func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *RegistryOptions) (*ManifestV2, error) {
 	url := fmt.Sprintf("%s/v2/%s/manifests/%s", mp.Registry, mp.GetNamespaceRepository(), mp.Tag)
 	headers := map[string]string{
 		"Accept": "application/vnd.docker.distribution.manifest.v2+json",
 	}
 
-	resp, err := makeRequest("GET", url, headers, nil, regOpts)
+	resp, err := makeRequest(ctx, "GET", url, headers, nil, regOpts)
 	if err != nil {
 		log.Printf("couldn't get manifest: %v", err)
 		return nil, err
@@ -1061,10 +1063,10 @@ func GetSHA256Digest(r io.Reader) (string, int) {
 	return fmt.Sprintf("sha256:%x", h.Sum(nil)), int(n)
 }
 
-func startUpload(mp ModelPath, regOpts *RegistryOptions) (string, error) {
+func startUpload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions) (string, error) {
 	url := fmt.Sprintf("%s/v2/%s/blobs/uploads/", mp.Registry, mp.GetNamespaceRepository())
 
-	resp, err := makeRequest("POST", url, nil, nil, regOpts)
+	resp, err := makeRequest(ctx, "POST", url, nil, nil, regOpts)
 	if err != nil {
 		log.Printf("couldn't start upload: %v", err)
 		return "", err
@@ -1087,10 +1089,10 @@ func startUpload(mp ModelPath, regOpts *RegistryOptions) (string, error) {
 }
 
 // Function to check if a blob already exists in the Docker registry
-func checkBlobExistence(mp ModelPath, digest string, regOpts *RegistryOptions) (bool, error) {
+func checkBlobExistence(ctx context.Context, mp ModelPath, digest string, regOpts *RegistryOptions) (bool, error) {
 	url := fmt.Sprintf("%s/v2/%s/blobs/%s", mp.Registry, mp.GetNamespaceRepository(), digest)
 
-	resp, err := makeRequest("HEAD", url, nil, nil, regOpts)
+	resp, err := makeRequest(ctx, "HEAD", url, nil, nil, regOpts)
 	if err != nil {
 		log.Printf("couldn't check for blob: %v", err)
 		return false, err
@@ -1101,7 +1103,7 @@ func checkBlobExistence(mp ModelPath, digest string, regOpts *RegistryOptions) (
 	return resp.StatusCode == http.StatusOK, nil
 }
 
-func uploadBlobChunked(mp ModelPath, url string, layer *Layer, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
+func uploadBlobChunked(ctx context.Context, mp ModelPath, url string, layer *Layer, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
 	// TODO allow resumability
 	// TODO allow canceling uploads via DELETE
 	// TODO allow cross repo blob mount
@@ -1158,7 +1160,7 @@ func uploadBlobChunked(mp ModelPath, url string, layer *Layer, regOpts *Registry
 	headers["Content-Length"] = strconv.Itoa(int(layer.Size))
 
 	// finish the upload
-	resp, err := makeRequest("PUT", url, headers, r, regOpts)
+	resp, err := makeRequest(ctx, "PUT", url, headers, r, regOpts)
 	if err != nil {
 		log.Printf("couldn't finish upload: %v", err)
 		return err
@@ -1172,7 +1174,16 @@ func uploadBlobChunked(mp ModelPath, url string, layer *Layer, regOpts *Registry
 	return nil
 }
 
-func makeRequest(method, url string, headers map[string]string, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {
+func makeRequest(ctx context.Context, method, url string, headers map[string]string, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {
+	retryCtx := ctx.Value("retries")
+	var retries int
+	var ok bool
+	if retries, ok = retryCtx.(int); ok {
+		if retries > MaxRetries {
+			return nil, fmt.Errorf("Maximum retries hit; are you sure you have access to this resource?")
+		}
+	}
+
 	if !strings.HasPrefix(url, "http") {
 		if regOpts.Insecure {
 			url = "http://" + url
@@ -1225,13 +1236,14 @@ func makeRequest(method, url string, headers map[string]string, body io.Reader,
 	if resp.StatusCode == http.StatusUnauthorized {
 		auth := resp.Header.Get("Www-Authenticate")
 		authRedir := ParseAuthRedirectString(string(auth))
-		token, err := getAuthToken(authRedir, regOpts)
+		token, err := getAuthToken(ctx, authRedir, regOpts)
 		if err != nil {
 			return nil, err
 		}
 		regOpts.Token = token
 		bodyCopy = bytes.NewReader(buf.Bytes())
-		return makeRequest(method, url, headers, bodyCopy, regOpts)
+		ctx = context.WithValue(ctx, "retries", retries+1)
+		return makeRequest(ctx, method, url, headers, bodyCopy, regOpts)
 	}
 
 	return resp, nil

+ 2 - 1
server/routes.go

@@ -277,7 +277,8 @@ func PushModelHandler(c *gin.Context) {
 			Password: req.Password,
 		}
 
-		if err := PushModel(req.Name, regOpts, fn); err != nil {
+		ctx := context.Background()
+		if err := PushModel(ctx, req.Name, regOpts, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()