瀏覽代碼

handle concurrent requests for the same blobs

Michael Yang 1 年之前
父節點
當前提交
5b84404c64
共有 1 個文件被更改,包括 183 次插入99 次删除
  1. 183 99
      server/download.go

+ 183 - 99
server/download.go

@@ -12,124 +12,107 @@ import (
 	"os"
 	"os"
 	"path/filepath"
 	"path/filepath"
 	"strconv"
 	"strconv"
+	"sync"
+	"sync/atomic"
+	"time"
 
 
-	"github.com/jmorganca/ollama/api"
 	"golang.org/x/sync/errgroup"
 	"golang.org/x/sync/errgroup"
+
+	"github.com/jmorganca/ollama/api"
 )
 )
 
 
-type BlobDownloadPart struct {
-	Offset    int64
-	Size      int64
-	Completed int64
-}
+var blobDownloadManager sync.Map
 
 
-type downloadOpts struct {
-	mp      ModelPath
-	digest  string
-	regOpts *RegistryOptions
-	fn      func(api.ProgressResponse)
-}
+type blobDownload struct {
+	Name   string
+	Digest string
 
 
-const maxRetries = 3
+	Total     int64
+	Completed atomic.Int64
 
 
-// downloadBlob downloads a blob from the registry and stores it in the blobs directory
-func downloadBlob(ctx context.Context, opts downloadOpts) error {
-	fp, err := GetBlobsPath(opts.digest)
-	if err != nil {
-		return err
-	}
+	*os.File
+	Parts []*blobDownloadPart
 
 
-	fi, err := os.Stat(fp)
-	switch {
-	case errors.Is(err, os.ErrNotExist):
-	case err != nil:
-		return err
-	default:
-		opts.fn(api.ProgressResponse{
-			Status:    fmt.Sprintf("downloading %s", opts.digest),
-			Digest:    opts.digest,
-			Total:     fi.Size(),
-			Completed: fi.Size(),
-		})
+	done chan struct{}
+	context.CancelFunc
+	refCount atomic.Int32
+}
 
 
-		return nil
-	}
+// blobDownloadPart describes one contiguous byte range of a blob that is
+// fetched with its own ranged request. It is encoded to and decoded from JSON
+// in a per-part file so progress can be picked up again later.
+type blobDownloadPart struct {
+	Offset    int64 // byte offset of the first byte of this part within the blob
+	Size      int64 // length of this part in bytes
+	Completed int64 // number of bytes of this part already downloaded
+}
 
 
-	f, err := os.OpenFile(fp+"-partial", os.O_CREATE|os.O_RDWR, 0644)
-	if err != nil {
-		return err
-	}
-	defer f.Close()
+func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
+	b.done = make(chan struct{}, 1)
 
 
-	partFilePaths, err := filepath.Glob(fp + "-partial-*")
+	partFilePaths, err := filepath.Glob(b.Name + "-partial-*")
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
 
 
-	var total, completed int64
-	var parts []BlobDownloadPart
 	for _, partFilePath := range partFilePaths {
 	for _, partFilePath := range partFilePaths {
-		var part BlobDownloadPart
-		partFile, err := os.Open(partFilePath)
+		part, err := b.readPart(partFilePath)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
-		defer partFile.Close()
-
-		if err := json.NewDecoder(partFile).Decode(&part); err != nil {
-			return err
-		}
 
 
-		total += part.Size
-		completed += part.Completed
-
-		parts = append(parts, part)
+		b.Total += part.Size
+		b.Completed.Add(part.Completed)
+		b.Parts = append(b.Parts, part)
 	}
 	}
 
 
-	requestURL := opts.mp.BaseURL()
-	requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", opts.digest)
-
-	if len(parts) == 0 {
-		resp, err := makeRequest(ctx, "HEAD", requestURL, nil, nil, opts.regOpts)
+	if len(b.Parts) == 0 {
+		resp, err := makeRequest(ctx, "HEAD", requestURL, nil, nil, opts)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
 		defer resp.Body.Close()
 		defer resp.Body.Close()
 
 
-		total, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
-
-		// reserve the file
-		f.Truncate(total)
+		b.Total, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
 
 
 		var offset int64
 		var offset int64
 		var size int64 = 64 * 1024 * 1024
 		var size int64 = 64 * 1024 * 1024
 
 
-		for offset < total {
-			if offset+size > total {
-				size = total - offset
+		for offset < b.Total {
+			if offset+size > b.Total {
+				size = b.Total - offset
+			}
+
+			partName := b.Name + "-partial-" + strconv.Itoa(len(b.Parts))
+			part := blobDownloadPart{Offset: offset, Size: size}
+			if err := b.writePart(partName, &part); err != nil {
+				return err
 			}
 			}
 
 
-			parts = append(parts, BlobDownloadPart{
-				Offset: offset,
-				Size:   size,
-			})
+			b.Parts = append(b.Parts, &part)
 
 
 			offset += size
 			offset += size
 		}
 		}
 	}
 	}
 
 
-	pw := &ProgressWriter{
-		status:    fmt.Sprintf("downloading %s", opts.digest),
-		digest:    opts.digest,
-		total:     total,
-		completed: completed,
-		fn:        opts.fn,
+	log.Printf("downloading %s in %d part(s)", b.Digest[7:19], len(b.Parts))
+	return nil
+}
+
+func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) (err error) {
+	defer blobDownloadManager.Delete(b.Digest)
+
+	ctx, b.CancelFunc = context.WithCancel(ctx)
+
+	b.File, err = os.OpenFile(b.Name+"-partial", os.O_CREATE|os.O_RDWR, 0644)
+	if err != nil {
+		return err
 	}
 	}
+	defer b.Close()
+
+	b.Truncate(b.Total)
 
 
 	g, ctx := errgroup.WithContext(ctx)
 	g, ctx := errgroup.WithContext(ctx)
 	g.SetLimit(64)
 	g.SetLimit(64)
-	for i := range parts {
-		part := parts[i]
+	for i := range b.Parts {
+		part := b.Parts[i]
 		if part.Completed == part.Size {
 		if part.Completed == part.Size {
 			continue
 			continue
 		}
 		}
@@ -137,12 +120,16 @@ func downloadBlob(ctx context.Context, opts downloadOpts) error {
 		i := i
 		i := i
 		g.Go(func() error {
 		g.Go(func() error {
 			for try := 0; try < maxRetries; try++ {
 			for try := 0; try < maxRetries; try++ {
-				if err := downloadBlobChunk(ctx, f, requestURL, parts, i, pw, opts); err != nil {
-					log.Printf("%s part %d attempt %d failed: %v, retrying", opts.digest[7:19], i, try, err)
+				err := b.downloadChunk(ctx, requestURL, i, opts)
+				switch {
+				case errors.Is(err, context.Canceled):
+					return err
+				case err != nil:
+					log.Printf("%s part %d attempt %d failed: %v, retrying", b.Digest[7:19], i, try, err)
 					continue
 					continue
+				default:
+					return nil
 				}
 				}
-
-				return nil
 			}
 			}
 
 
 			return errors.New("max retries exceeded")
 			return errors.New("max retries exceeded")
@@ -153,52 +140,67 @@ func downloadBlob(ctx context.Context, opts downloadOpts) error {
 		return err
 		return err
 	}
 	}
 
 
-	if err := f.Close(); err != nil {
+	if err := b.Close(); err != nil {
 		return err
 		return err
 	}
 	}
 
 
-	for i := range parts {
-		if err := os.Remove(f.Name() + "-" + strconv.Itoa(i)); err != nil {
+	for i := range b.Parts {
+		if err := os.Remove(b.File.Name() + "-" + strconv.Itoa(i)); err != nil {
 			return err
 			return err
 		}
 		}
 	}
 	}
 
 
-	return os.Rename(f.Name(), fp)
-}
-
-func downloadBlobChunk(ctx context.Context, f *os.File, requestURL *url.URL, parts []BlobDownloadPart, i int, pw *ProgressWriter, opts downloadOpts) error {
-	part := &parts[i]
-
-	partName := f.Name() + "-" + strconv.Itoa(i)
-	if err := flushPart(partName, part); err != nil {
+	if err := os.Rename(b.File.Name(), b.Name); err != nil {
 		return err
 		return err
 	}
 	}
 
 
+	close(b.done)
+	return nil
+}
+
+func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, i int, opts *RegistryOptions) error {
+	part := b.Parts[i]
+
+	partName := b.File.Name() + "-" + strconv.Itoa(i)
 	offset := part.Offset + part.Completed
 	offset := part.Offset + part.Completed
-	w := io.NewOffsetWriter(f, offset)
+	w := io.NewOffsetWriter(b.File, offset)
 
 
 	headers := make(http.Header)
 	headers := make(http.Header)
 	headers.Set("Range", fmt.Sprintf("bytes=%d-%d", offset, part.Offset+part.Size-1))
 	headers.Set("Range", fmt.Sprintf("bytes=%d-%d", offset, part.Offset+part.Size-1))
-	resp, err := makeRequest(ctx, "GET", requestURL, headers, nil, opts.regOpts)
+	resp, err := makeRequest(ctx, "GET", requestURL, headers, nil, opts)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
 	defer resp.Body.Close()
 	defer resp.Body.Close()
 
 
-	n, err := io.Copy(w, io.TeeReader(resp.Body, pw))
+	n, err := io.Copy(w, io.TeeReader(resp.Body, b))
 	if err != nil && !errors.Is(err, io.EOF) {
 	if err != nil && !errors.Is(err, io.EOF) {
-		// rollback progress bar
-		pw.completed -= n
+		// rollback progress
+		b.Completed.Add(-n)
 		return err
 		return err
 	}
 	}
 
 
 	part.Completed += n
 	part.Completed += n
+	return b.writePart(partName, part)
+}
+
+func (b *blobDownload) readPart(partName string) (*blobDownloadPart, error) {
+	var part blobDownloadPart
+	partFile, err := os.Open(partName)
+	if err != nil {
+		return nil, err
+	}
+	defer partFile.Close()
+
+	if err := json.NewDecoder(partFile).Decode(&part); err != nil {
+		return nil, err
+	}
 
 
-	return flushPart(partName, part)
+	return &part, nil
 }
 }
 
 
-func flushPart(name string, part *BlobDownloadPart) error {
-	partFile, err := os.OpenFile(name, os.O_CREATE|os.O_RDWR, 0644)
+func (b *blobDownload) writePart(partName string, part *blobDownloadPart) error {
+	partFile, err := os.OpenFile(partName, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0644)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -206,3 +208,85 @@ func flushPart(name string, part *BlobDownloadPart) error {
 
 
 	return json.NewEncoder(partFile).Encode(part)
 	return json.NewEncoder(partFile).Encode(part)
 }
 }
+
+// Write lets b be used as an io.Writer purely for progress accounting: the
+// bytes in p are not stored anywhere, only their count is added to Completed.
+// It always reports the full length of p as written and never fails.
+func (b *blobDownload) Write(p []byte) (n int, err error) {
+	b.Completed.Add(int64(len(p)))
+	return len(p), nil
+}
+
+// Wait registers the caller as a waiter on b and blocks until the download has
+// finished or ctx is done. While waiting it calls fn roughly every 60ms with
+// the current progress.
+//
+// It returns nil once every byte has been accounted for and Run has closed
+// b.done. If ctx is done first it returns ctx.Err(); the last waiter to leave
+// this way cancels the shared download.
+func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse)) error {
+	b.refCount.Add(1)
+
+	// release drops this waiter's reference and, when no waiters remain,
+	// cancels the download.
+	release := func() error {
+		if b.refCount.Add(-1) == 0 {
+			// Run installs CancelFunc from its own goroutine, so it can still
+			// be unset if this waiter is cancelled right after the download
+			// was spawned. Calling a nil func would panic.
+			if b.CancelFunc != nil {
+				b.CancelFunc()
+			}
+		}
+
+		return ctx.Err()
+	}
+
+	ticker := time.NewTicker(60 * time.Millisecond)
+	// Stop the ticker on every exit path so its resources are released.
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-ticker.C:
+		case <-ctx.Done():
+			return release()
+		}
+
+		fn(api.ProgressResponse{
+			Status:    fmt.Sprintf("downloading %s", b.Digest),
+			Digest:    b.Digest,
+			Total:     b.Total,
+			Completed: b.Completed.Load(),
+		})
+
+		if b.Completed.Load() >= b.Total {
+			// All bytes are counted; wait for Run to close done. Keep honoring
+			// ctx here: if Run returned an error, done is never closed and an
+			// unconditional receive would block this waiter forever.
+			select {
+			case <-b.done:
+				return nil
+			case <-ctx.Done():
+				return release()
+			}
+		}
+	}
+}
+
+// downloadOpts bundles the arguments of downloadBlob.
+type downloadOpts struct {
+	mp      ModelPath                  // supplies the registry base URL and namespace/repository for the blob URL
+	digest  string                     // digest of the blob; used for the blob path, the request URL and de-duplication
+	regOpts *RegistryOptions           // passed through to the registry requests
+	fn      func(api.ProgressResponse) // progress callback
+}
+
+// maxRetries is the number of attempts made to download a single part before
+// the download fails with "max retries exceeded".
+const maxRetries = 3
+
+// downloadBlob downloads a blob from the registry and stores it in the blobs directory.
+//
+// If the blob file already exists, a single completed progress update is sent
+// to opts.fn and nil is returned. Otherwise concurrent calls for the same
+// digest share one download: the first caller prepares and starts it in the
+// background, and every caller (including the first) waits on it and receives
+// progress through opts.fn until it finishes or ctx is done.
+func downloadBlob(ctx context.Context, opts downloadOpts) error {
+	fp, err := GetBlobsPath(opts.digest)
+	if err != nil {
+		return err
+	}
+
+	fi, err := os.Stat(fp)
+	switch {
+	case errors.Is(err, os.ErrNotExist):
+		// not downloaded yet; fall through to start or join a download
+	case err != nil:
+		return err
+	default:
+		opts.fn(api.ProgressResponse{
+			Status:    fmt.Sprintf("downloading %s", opts.digest),
+			Digest:    opts.digest,
+			Total:     fi.Size(),
+			Completed: fi.Size(),
+		})
+
+		return nil
+	}
+
+	value, loaded := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: fp, Digest: opts.digest})
+	download := value.(*blobDownload)
+	if !loaded {
+		requestURL := opts.mp.BaseURL()
+		requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", opts.digest)
+		if err := download.Prepare(ctx, requestURL, opts.regOpts); err != nil {
+			// Run, which normally removes the entry, is never started on this
+			// path. Drop the entry here so later requests for this digest
+			// retry instead of waiting on a download that never began.
+			blobDownloadManager.Delete(opts.digest)
+			return err
+		}
+
+		// The download runs detached from the caller's ctx so it survives
+		// while other waiters remain.
+		go download.Run(context.Background(), requestURL, opts.regOpts)
+	}
+
+	return download.Wait(ctx, opts.fn)
+}