Selaa lähdekoodia

Merge pull request #750 from jmorganca/mxyng/concurrent-uploads

concurrent uploads
Michael Yang 1 vuosi sitten
vanhempi
commit
2c6189f4fe
4 muutettua tiedostoa jossa 287 lisäystä ja 189 poistoa
  1. 3 2
      server/download.go
  2. 2 60
      server/images.go
  3. 3 1
      server/routes.go
  4. 279 126
      server/upload.go

+ 3 - 2
server/download.go

@@ -134,7 +134,6 @@ func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *Regis
 
 func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
 	defer blobDownloadManager.Delete(b.Digest)
-
 	ctx, b.CancelFunc = context.WithCancel(ctx)
 
 	file, err := os.OpenFile(b.Name+"-partial", os.O_CREATE|os.O_RDWR, 0644)
@@ -170,7 +169,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis
 				}
 			}
 
-			return errors.New("max retries exceeded")
+			return errMaxRetriesExceeded
 		})
 	}
 
@@ -308,6 +307,8 @@ type downloadOpts struct {
 
 const maxRetries = 3
 
+var errMaxRetriesExceeded = errors.New("max retries exceeded")
+
 // downloadBlob downloads a blob from the registry and stores it in the blobs directory
 func downloadBlob(ctx context.Context, opts downloadOpts) error {
 	fp, err := GetBlobsPath(opts.digest)

+ 2 - 60
server/images.go

@@ -981,46 +981,7 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
 	layers = append(layers, &manifest.Config)
 
 	for _, layer := range layers {
-		exists, err := checkBlobExistence(ctx, mp, layer.Digest, regOpts)
-		if err != nil {
-			return err
-		}
-
-		if exists {
-			fn(api.ProgressResponse{
-				Status:    "using existing layer",
-				Digest:    layer.Digest,
-				Total:     layer.Size,
-				Completed: layer.Size,
-			})
-			log.Printf("Layer %s already exists", layer.Digest)
-			continue
-		}
-
-		fn(api.ProgressResponse{
-			Status: "starting upload",
-			Digest: layer.Digest,
-			Total:  layer.Size,
-		})
-
-		location, chunkSize, err := startUpload(ctx, mp, layer, regOpts)
-		if err != nil {
-			log.Printf("couldn't start upload: %v", err)
-			return err
-		}
-
-		if strings.HasPrefix(filepath.Base(location.Path), "sha256:") {
-			layer.Digest = filepath.Base(location.Path)
-			fn(api.ProgressResponse{
-				Status:    "using existing layer",
-				Digest:    layer.Digest,
-				Total:     layer.Size,
-				Completed: layer.Size,
-			})
-			continue
-		}
-
-		if err := uploadBlob(ctx, location, layer, chunkSize, regOpts, fn); err != nil {
+		if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
 			log.Printf("error uploading blob: %v", err)
 			return err
 		}
@@ -1218,24 +1179,7 @@ func GetSHA256Digest(r io.Reader) (string, int64) {
 	return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
 }
 
-// Function to check if a blob already exists in the Docker registry
-func checkBlobExistence(ctx context.Context, mp ModelPath, digest string, regOpts *RegistryOptions) (bool, error) {
-	requestURL := mp.BaseURL()
-	requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", digest)
-
-	resp, err := makeRequest(ctx, "HEAD", requestURL, nil, nil, regOpts)
-	if err != nil {
-		log.Printf("couldn't check for blob: %v", err)
-		return false, err
-	}
-	defer resp.Body.Close()
-
-	// Check for success: If the blob exists, the Docker registry will respond with a 200 OK
-	return resp.StatusCode < http.StatusBadRequest, nil
-}
-
 func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *RegistryOptions) (*http.Response, error) {
-	var status string
 	for try := 0; try < maxRetries; try++ {
 		resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
 		if err != nil {
@@ -1243,8 +1187,6 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
 			return nil, err
 		}
 
-		status = resp.Status
-
 		switch {
 		case resp.StatusCode == http.StatusUnauthorized:
 			auth := resp.Header.Get("www-authenticate")
@@ -1270,7 +1212,7 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
 		}
 	}
 
-	return nil, fmt.Errorf("max retry exceeded: %v", status)
+	return nil, errMaxRetriesExceeded
 }
 
 func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {

+ 3 - 1
server/routes.go

@@ -365,7 +365,9 @@ func PushModelHandler(c *gin.Context) {
 			Insecure: req.Insecure,
 		}
 
-		ctx := context.Background()
+		ctx, cancel := context.WithCancel(c.Request.Context())
+		defer cancel()
+
 		if err := PushModel(ctx, req.Name, regOpts, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}

+ 279 - 126
server/upload.go

@@ -2,218 +2,371 @@ package server
 
 import (
 	"context"
+	"crypto/md5"
 	"errors"
 	"fmt"
+	"hash"
 	"io"
 	"log"
 	"net/http"
 	"net/url"
 	"os"
-	"strconv"
+	"strings"
 	"sync"
+	"sync/atomic"
+	"time"
 
 	"github.com/jmorganca/ollama/api"
+	"github.com/jmorganca/ollama/format"
+	"golang.org/x/sync/errgroup"
 )
 
+var blobUploadManager sync.Map
+
+type blobUpload struct {
+	*Layer
+
+	Total     int64
+	Completed atomic.Int64
+
+	Parts []blobUploadPart
+
+	nextURL chan *url.URL
+
+	context.CancelFunc
+
+	done       bool
+	err        error
+	references atomic.Int32
+}
+
+type blobUploadPart struct {
+	// N is the part number
+	N      int
+	Offset int64
+	Size   int64
+	hash.Hash
+}
+
 const (
-	redirectChunkSize int64 = 1024 * 1024 * 1024
-	regularChunkSize  int64 = 95 * 1024 * 1024
+	numUploadParts          = 64
+	minUploadPartSize int64 = 95 * 1000 * 1000
+	maxUploadPartSize int64 = 1000 * 1000 * 1000
 )
 
-func startUpload(ctx context.Context, mp ModelPath, layer *Layer, regOpts *RegistryOptions) (*url.URL, int64, error) {
-	requestURL := mp.BaseURL()
-	requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs/uploads/")
-	if layer.From != "" {
+func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
+	p, err := GetBlobsPath(b.Digest)
+	if err != nil {
+		return err
+	}
+
+	if b.From != "" {
 		values := requestURL.Query()
-		values.Add("mount", layer.Digest)
-		values.Add("from", layer.From)
+		values.Add("mount", b.Digest)
+		values.Add("from", b.From)
 		requestURL.RawQuery = values.Encode()
 	}
 
-	resp, err := makeRequestWithRetry(ctx, "POST", requestURL, nil, nil, regOpts)
+	resp, err := makeRequestWithRetry(ctx, "POST", requestURL, nil, nil, opts)
 	if err != nil {
-		log.Printf("couldn't start upload: %v", err)
-		return nil, 0, err
+		return err
 	}
 	defer resp.Body.Close()
 
 	location := resp.Header.Get("Docker-Upload-Location")
-	chunkSize := redirectChunkSize
 	if location == "" {
 		location = resp.Header.Get("Location")
-		chunkSize = regularChunkSize
 	}
 
-	locationURL, err := url.Parse(location)
+	fi, err := os.Stat(p)
 	if err != nil {
-		return nil, 0, err
+		return err
 	}
 
-	return locationURL, chunkSize, nil
-}
+	b.Total = fi.Size()
 
-func uploadBlob(ctx context.Context, requestURL *url.URL, layer *Layer, chunkSize int64, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
-	// TODO allow resumability
-	// TODO allow canceling uploads via DELETE
+	var size = b.Total / numUploadParts
+	switch {
+	case size < minUploadPartSize:
+		size = minUploadPartSize
+	case size > maxUploadPartSize:
+		size = maxUploadPartSize
+	}
 
-	fp, err := GetBlobsPath(layer.Digest)
-	if err != nil {
-		return err
+	var offset int64
+	for offset < fi.Size() {
+		if offset+size > fi.Size() {
+			size = fi.Size() - offset
+		}
+
+		// set part.N to the current number of parts
+		b.Parts = append(b.Parts, blobUploadPart{N: len(b.Parts), Offset: offset, Size: size, Hash: md5.New()})
+		offset += size
 	}
 
-	f, err := os.Open(fp)
+	log.Printf("uploading %s in %d %s part(s)", b.Digest[7:19], len(b.Parts), format.HumanBytes(size))
+
+	requestURL, err = url.Parse(location)
 	if err != nil {
 		return err
 	}
-	defer f.Close()
 
-	pw := ProgressWriter{
-		status: fmt.Sprintf("uploading %s", layer.Digest),
-		digest: layer.Digest,
-		total:  layer.Size,
-		fn:     fn,
+	b.nextURL = make(chan *url.URL, 1)
+	b.nextURL <- requestURL
+	return nil
+}
+
+// Run uploads blob parts to the upstream. If the upstream supports redirection, parts will be uploaded
+// in parallel as defined by Prepare. Otherwise, parts will be uploaded serially. Run sets b.err on error.
+func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
+	defer blobUploadManager.Delete(b.Digest)
+	ctx, b.CancelFunc = context.WithCancel(ctx)
+
+	p, err := GetBlobsPath(b.Digest)
+	if err != nil {
+		b.err = err
+		return
 	}
 
-	for offset := int64(0); offset < layer.Size; {
-		chunk := layer.Size - offset
-		if chunk > chunkSize {
-			chunk = chunkSize
-		}
+	f, err := os.Open(p)
+	if err != nil {
+		b.err = err
+		return
+	}
+	defer f.Close()
 
-		resp, err := uploadBlobChunk(ctx, http.MethodPatch, requestURL, f, offset, chunk, regOpts, &pw)
-		if err != nil {
-			fn(api.ProgressResponse{
-				Status:    fmt.Sprintf("error uploading chunk: %v", err),
-				Digest:    layer.Digest,
-				Total:     layer.Size,
-				Completed: offset,
+	g, inner := errgroup.WithContext(ctx)
+	g.SetLimit(numUploadParts)
+	for i := range b.Parts {
+		part := &b.Parts[i]
+		select {
+		case <-inner.Done():
+		case requestURL := <-b.nextURL:
+			g.Go(func() error {
+				for try := 0; try < maxRetries; try++ {
+					r := io.NewSectionReader(f, part.Offset, part.Size)
+					err := b.uploadChunk(inner, http.MethodPatch, requestURL, r, part, opts)
+					switch {
+					case errors.Is(err, context.Canceled):
+						return err
+					case errors.Is(err, errMaxRetriesExceeded):
+						return err
+					case err != nil:
+						log.Printf("%s part %d attempt %d failed: %v, retrying", b.Digest[7:19], part.N, try, err)
+						continue
+					}
+
+					return nil
+				}
+
+				return errMaxRetriesExceeded
 			})
-
-			return err
 		}
+	}
 
-		offset += chunk
-		location := resp.Header.Get("Docker-Upload-Location")
-		if location == "" {
-			location = resp.Header.Get("Location")
-		}
+	if err := g.Wait(); err != nil {
+		b.err = err
+		return
+	}
 
-		requestURL, err = url.Parse(location)
-		if err != nil {
-			return err
-		}
+	requestURL := <-b.nextURL
+
+	var sb strings.Builder
+	for _, part := range b.Parts {
+		sb.Write(part.Sum(nil))
 	}
 
+	md5sum := md5.Sum([]byte(sb.String()))
+
 	values := requestURL.Query()
-	values.Add("digest", layer.Digest)
+	values.Add("digest", b.Digest)
+	values.Add("etag", fmt.Sprintf("%x-%d", md5sum, len(b.Parts)))
 	requestURL.RawQuery = values.Encode()
 
 	headers := make(http.Header)
 	headers.Set("Content-Type", "application/octet-stream")
 	headers.Set("Content-Length", "0")
 
-	// finish the upload
-	resp, err := makeRequest(ctx, "PUT", requestURL, headers, nil, regOpts)
+	resp, err := makeRequestWithRetry(ctx, "PUT", requestURL, headers, nil, opts)
 	if err != nil {
-		log.Printf("couldn't finish upload: %v", err)
-		return err
+		b.err = err
+		return
 	}
 	defer resp.Body.Close()
 
-	if resp.StatusCode >= http.StatusBadRequest {
-		body, _ := io.ReadAll(resp.Body)
-		return fmt.Errorf("on finish upload registry responded with code %d: %v", resp.StatusCode, string(body))
-	}
-	return nil
+	b.done = true
 }
 
-func uploadBlobChunk(ctx context.Context, method string, requestURL *url.URL, r io.ReaderAt, offset, limit int64, opts *RegistryOptions, pw *ProgressWriter) (*http.Response, error) {
-	sectionReader := io.NewSectionReader(r, offset, limit)
-
+func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL *url.URL, rs io.ReadSeeker, part *blobUploadPart, opts *RegistryOptions) error {
 	headers := make(http.Header)
 	headers.Set("Content-Type", "application/octet-stream")
-	headers.Set("Content-Length", strconv.Itoa(int(limit)))
+	headers.Set("Content-Length", fmt.Sprintf("%d", part.Size))
 	headers.Set("X-Redirect-Uploads", "1")
 
 	if method == http.MethodPatch {
-		headers.Set("Content-Range", fmt.Sprintf("%d-%d", offset, offset+sectionReader.Size()-1))
+		headers.Set("Content-Range", fmt.Sprintf("%d-%d", part.Offset, part.Offset+part.Size-1))
 	}
 
-	for try := 0; try < maxRetries; try++ {
-		resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(sectionReader, pw), opts)
-		if err != nil && !errors.Is(err, io.EOF) {
-			return nil, err
-		}
-		defer resp.Body.Close()
+	buw := blobUploadWriter{blobUpload: b}
+	resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(rs, io.MultiWriter(&buw, part.Hash)), opts)
+	if err != nil {
+		return err
+	}
+	defer resp.Body.Close()
 
-		switch {
-		case resp.StatusCode == http.StatusTemporaryRedirect:
-			location, err := resp.Location()
-			if err != nil {
-				return nil, err
-			}
+	location := resp.Header.Get("Docker-Upload-Location")
+	if location == "" {
+		location = resp.Header.Get("Location")
+	}
+
+	nextURL, err := url.Parse(location)
+	if err != nil {
+		return err
+	}
+
+	switch {
+	case resp.StatusCode == http.StatusTemporaryRedirect:
+		b.nextURL <- nextURL
+
+		redirectURL, err := resp.Location()
+		if err != nil {
+			return err
+		}
 
-			pw.completed = offset
-			if _, err := uploadBlobChunk(ctx, http.MethodPut, location, r, offset, limit, nil, pw); err != nil {
-				// retry
-				log.Printf("retrying redirected upload: %v", err)
+		for try := 0; try < maxRetries; try++ {
+			rs.Seek(0, io.SeekStart)
+			b.Completed.Add(-buw.written)
+			buw.written = 0
+			part.Hash = md5.New()
+			err := b.uploadChunk(ctx, http.MethodPut, redirectURL, rs, part, nil)
+			switch {
+			case errors.Is(err, context.Canceled):
+				return err
+			case errors.Is(err, errMaxRetriesExceeded):
+				return err
+			case err != nil:
+				log.Printf("%s part %d attempt %d failed: %v, retrying", b.Digest[7:19], part.N, try, err)
 				continue
 			}
 
-			return resp, nil
-		case resp.StatusCode == http.StatusUnauthorized:
-			auth := resp.Header.Get("www-authenticate")
-			authRedir := ParseAuthRedirectString(auth)
-			token, err := getAuthToken(ctx, authRedir)
-			if err != nil {
-				return nil, err
-			}
+			return nil
+		}
 
-			opts.Token = token
+		return errMaxRetriesExceeded
 
-			pw.completed = offset
-			sectionReader = io.NewSectionReader(r, offset, limit)
-			continue
-		case resp.StatusCode >= http.StatusBadRequest:
-			body, _ := io.ReadAll(resp.Body)
-			return nil, fmt.Errorf("on upload registry responded with code %d: %s", resp.StatusCode, body)
+	case resp.StatusCode == http.StatusUnauthorized:
+		auth := resp.Header.Get("www-authenticate")
+		authRedir := ParseAuthRedirectString(auth)
+		token, err := getAuthToken(ctx, authRedir)
+		if err != nil {
+			return err
 		}
 
-		return resp, nil
+		opts.Token = token
+		fallthrough
+	case resp.StatusCode >= http.StatusBadRequest:
+		body, err := io.ReadAll(resp.Body)
+		if err != nil {
+			return err
+		}
+
+		rs.Seek(0, io.SeekStart)
+		b.Completed.Add(-buw.written)
+		buw.written = 0
+		return fmt.Errorf("http status %d %s: %s", resp.StatusCode, resp.Status, body)
+	}
+
+	if method == http.MethodPatch {
+		b.nextURL <- nextURL
 	}
 
-	return nil, fmt.Errorf("max retries exceeded")
+	return nil
+}
+
+func (b *blobUpload) acquire() {
+	b.references.Add(1)
 }
 
-type ProgressWriter struct {
-	status    string
-	digest    string
-	bucket    int64
-	completed int64
-	total     int64
-	fn        func(api.ProgressResponse)
-	mu        sync.Mutex
+func (b *blobUpload) release() {
+	if b.references.Add(-1) == 0 {
+		b.CancelFunc()
+	}
 }
 
-func (pw *ProgressWriter) Write(b []byte) (int, error) {
-	pw.mu.Lock()
-	defer pw.mu.Unlock()
-
-	n := len(b)
-	pw.bucket += int64(n)
-
-	// throttle status updates to not spam the client
-	if pw.bucket >= 1024*1024 || pw.completed+pw.bucket >= pw.total {
-		pw.completed += pw.bucket
-		pw.fn(api.ProgressResponse{
-			Status:    pw.status,
-			Digest:    pw.digest,
-			Total:     pw.total,
-			Completed: pw.completed,
+func (b *blobUpload) Wait(ctx context.Context, fn func(api.ProgressResponse)) error {
+	b.acquire()
+	defer b.release()
+
+	ticker := time.NewTicker(60 * time.Millisecond)
+	for {
+		select {
+		case <-ticker.C:
+		case <-ctx.Done():
+			return ctx.Err()
+		}
+
+		fn(api.ProgressResponse{
+			Status:    fmt.Sprintf("uploading %s", b.Digest),
+			Digest:    b.Digest,
+			Total:     b.Total,
+			Completed: b.Completed.Load(),
 		})
 
-		pw.bucket = 0
+		if b.done || b.err != nil {
+			return b.err
+		}
 	}
+}
 
+type blobUploadWriter struct {
+	written int64
+	*blobUpload
+}
+
+func (b *blobUploadWriter) Write(p []byte) (n int, err error) {
+	n = len(p)
+	b.written += int64(n)
+	b.Completed.Add(int64(n))
 	return n, nil
 }
+
+func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *RegistryOptions, fn func(api.ProgressResponse)) error {
+	requestURL := mp.BaseURL()
+	requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)
+
+	resp, err := makeRequest(ctx, "HEAD", requestURL, nil, nil, opts)
+	if err != nil {
+		return err
+	}
+	defer resp.Body.Close()
+
+	switch resp.StatusCode {
+	case http.StatusNotFound:
+	case http.StatusOK:
+		fn(api.ProgressResponse{
+			Status:    fmt.Sprintf("uploading %s", layer.Digest),
+			Digest:    layer.Digest,
+			Total:     layer.Size,
+			Completed: layer.Size,
+		})
+
+		return nil
+	default:
+		return fmt.Errorf("unexpected status code %d", resp.StatusCode)
+	}
+
+	data, ok := blobUploadManager.LoadOrStore(layer.Digest, &blobUpload{Layer: layer})
+	upload := data.(*blobUpload)
+	if !ok {
+		requestURL := mp.BaseURL()
+		requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs/uploads/")
+		if err := upload.Prepare(ctx, requestURL, opts); err != nil {
+			blobUploadManager.Delete(layer.Digest)
+			return err
+		}
+
+		go upload.Run(context.Background(), opts)
+	}
+
+	return upload.Wait(ctx, fn)
+}