Browse Source

Merge pull request #486 from jmorganca/mxyng/fix-push

fix: retry push on expired token
Michael Yang 1 year ago
parent
commit
738fe9c4aa
3 changed files with 32 additions and 23 deletions
  1. 1 1
      server/auth.go
  2. 6 4
      server/images.go
  3. 25 18
      server/upload.go

+ 1 - 1
server/auth.go

@@ -103,7 +103,7 @@ func getAuthToken(ctx context.Context, redirData AuthRedirect, regOpts *Registry
 
 	headers := make(http.Header)
 	headers.Set("Authorization", sig)
-	resp, err := makeRequest(ctx, "GET", redirectURL, headers, nil, regOpts)
+	resp, err := makeRequest(ctx, "GET", redirectURL, headers, nil, nil)
 	if err != nil {
 		log.Printf("couldn't get token: %q", err)
 	}

+ 6 - 4
server/images.go

@@ -1313,10 +1313,12 @@ func makeRequest(ctx context.Context, method string, requestURL *url.URL, header
 		req.Header = headers
 	}
 
-	if regOpts.Token != "" {
-		req.Header.Set("Authorization", "Bearer "+regOpts.Token)
-	} else if regOpts.Username != "" && regOpts.Password != "" {
-		req.SetBasicAuth(regOpts.Username, regOpts.Password)
+	if regOpts != nil {
+		if regOpts.Token != "" {
+			req.Header.Set("Authorization", "Bearer "+regOpts.Token)
+		} else if regOpts.Username != "" && regOpts.Password != "" {
+			req.SetBasicAuth(regOpts.Username, regOpts.Password)
+		}
 	}
 
 	req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))

+ 25 - 18
server/upload.go

@@ -66,31 +66,39 @@ func uploadBlobChunked(ctx context.Context, requestURL *url.URL, layer *Layer, r
 
 		sectionReader := io.NewSectionReader(f, int64(offset), chunk)
 		for try := 0; try < MaxRetries; try++ {
+			ch := make(chan error, 1)
+
 			r, w := io.Pipe()
 			defer r.Close()
 			go func() {
 				defer w.Close()
 
 				for chunked := int64(0); chunked < chunk; {
-					n, err := io.CopyN(w, sectionReader, 1024*1024)
-					if err != nil && !errors.Is(err, io.EOF) {
+					select {
+					case err := <-ch:
+						log.Printf("chunk interrupted: %v", err)
+						return
+					default:
+						n, err := io.CopyN(w, sectionReader, 1024*1024)
+						if err != nil && !errors.Is(err, io.EOF) {
+							fn(api.ProgressResponse{
+								Status:    fmt.Sprintf("error reading chunk: %v", err),
+								Digest:    layer.Digest,
+								Total:     layer.Size,
+								Completed: int(offset),
+							})
+
+							return
+						}
+
+						chunked += n
 						fn(api.ProgressResponse{
-							Status:    fmt.Sprintf("error reading chunk: %v", err),
+							Status:    fmt.Sprintf("uploading %s", layer.Digest),
 							Digest:    layer.Digest,
 							Total:     layer.Size,
-							Completed: int(offset),
+							Completed: int(offset) + int(chunked),
 						})
-
-						return
 					}
-
-					chunked += n
-					fn(api.ProgressResponse{
-						Status:    fmt.Sprintf("uploading %s", layer.Digest),
-						Digest:    layer.Digest,
-						Total:     layer.Size,
-						Completed: int(offset) + int(chunked),
-					})
 				}
 			}()
 
@@ -113,6 +121,8 @@ func uploadBlobChunked(ctx context.Context, requestURL *url.URL, layer *Layer, r
 
 			switch {
 			case resp.StatusCode == http.StatusUnauthorized:
+				ch <- errors.New("unauthorized")
+
 				auth := resp.Header.Get("www-authenticate")
 				authRedir := ParseAuthRedirectString(auth)
 				token, err := getAuthToken(ctx, authRedir, regOpts)
@@ -121,10 +131,7 @@ func uploadBlobChunked(ctx context.Context, requestURL *url.URL, layer *Layer, r
 				}
 
 				regOpts.Token = token
-				if _, err := sectionReader.Seek(0, io.SeekStart); err != nil {
-					return err
-				}
-
+				sectionReader = io.NewSectionReader(f, int64(offset), chunk)
 				continue
 			case resp.StatusCode >= http.StatusBadRequest:
 				body, _ := io.ReadAll(resp.Body)