@@ -2,218 +2,371 @@ package server

 import (
 	"context"
+	"crypto/md5"
 	"errors"
 	"fmt"
+	"hash"
 	"io"
 	"log"
 	"net/http"
 	"net/url"
 	"os"
-	"strconv"
+	"strings"
 	"sync"
+	"sync/atomic"
+	"time"

 	"github.com/jmorganca/ollama/api"
+	"github.com/jmorganca/ollama/format"
+	"golang.org/x/sync/errgroup"
 )

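+// blobUploadManager deduplicates uploads: at most one blobUpload per layer
+// digest is in flight at any time.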
+var blobUploadManager sync.Map
+
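+// blobUpload tracks one in-flight blob upload: the layer being pushed, its
+// parts, cumulative progress, and a count of callers waiting on it.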
+type blobUpload struct {
+	*Layer
+
+	Total     int64
+	Completed atomic.Int64
+
+	Parts []blobUploadPart
+
+	nextURL chan *url.URL
+
+	context.CancelFunc
+
+	done       bool
+	err        error
+	references atomic.Int32
+}
+
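+// blobUploadPart is one contiguous slice of the blob; its bytes are hashed
+// with MD5 during upload so a multipart ETag can be computed at the end.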
+type blobUploadPart struct {
+	// N is the part number
+	N      int
+	Offset int64
+	Size   int64
+	hash.Hash
+}
+
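+// A blob is split into parts of roughly Total/numUploadParts bytes each,
+// clamped to [minUploadPartSize, maxUploadPartSize].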
 const (
-	redirectChunkSize int64 = 1024 * 1024 * 1024
-	regularChunkSize  int64 = 95 * 1024 * 1024
+	numUploadParts          = 64
+	minUploadPartSize int64 = 95 * 1000 * 1000
+	maxUploadPartSize int64 = 1000 * 1000 * 1000
 )

-func startUpload(ctx context.Context, mp ModelPath, layer *Layer, regOpts *RegistryOptions) (*url.URL, int64, error) {
-	requestURL := mp.BaseURL()
-	requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs/uploads/")
-	if layer.From != "" {
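+// Prepare opens an upload session (requesting a cross-repository mount when
+// b.From is set), sizes the blob, splits it into parts, and seeds nextURL
+// with the session URL returned by the registry.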
+func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
+	p, err := GetBlobsPath(b.Digest)
+	if err != nil {
+		return err
+	}
+
+	if b.From != "" {
 		values := requestURL.Query()
-		values.Add("mount", layer.Digest)
-		values.Add("from", layer.From)
+		values.Add("mount", b.Digest)
+		values.Add("from", b.From)
 		requestURL.RawQuery = values.Encode()
 	}

-	resp, err := makeRequestWithRetry(ctx, "POST", requestURL, nil, nil, regOpts)
+	resp, err := makeRequestWithRetry(ctx, "POST", requestURL, nil, nil, opts)
 	if err != nil {
-		log.Printf("couldn't start upload: %v", err)
-		return nil, 0, err
+		return err
 	}
 	defer resp.Body.Close()

 	location := resp.Header.Get("Docker-Upload-Location")
-	chunkSize := redirectChunkSize
 	if location == "" {
 		location = resp.Header.Get("Location")
-		chunkSize = regularChunkSize
 	}

-	locationURL, err := url.Parse(location)
+	fi, err := os.Stat(p)
 	if err != nil {
-		return nil, 0, err
+		return err
 	}

-	return locationURL, chunkSize, nil
-}
+	b.Total = fi.Size()

-func uploadBlob(ctx context.Context, requestURL *url.URL, layer *Layer, chunkSize int64, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
-	// TODO allow resumability
-	// TODO allow canceling uploads via DELETE
+	var size = b.Total / numUploadParts
+	switch {
+	case size < minUploadPartSize:
+		size = minUploadPartSize
+	case size > maxUploadPartSize:
+		size = maxUploadPartSize
+	}

-	fp, err := GetBlobsPath(layer.Digest)
-	if err != nil {
-		return err
+	var offset int64
+	for offset < fi.Size() {
+		if offset+size > fi.Size() {
+			size = fi.Size() - offset
+		}
+
+		// set part.N to the current number of parts
+		b.Parts = append(b.Parts, blobUploadPart{N: len(b.Parts), Offset: offset, Size: size, Hash: md5.New()})
+		offset += size
 	}

-	f, err := os.Open(fp)
+	log.Printf("uploading %s in %d %s part(s)", b.Digest[7:19], len(b.Parts), format.HumanBytes(size))
+
+	requestURL, err = url.Parse(location)
 	if err != nil {
 		return err
 	}
-	defer f.Close()

-	pw := ProgressWriter{
-		status: fmt.Sprintf("uploading %s", layer.Digest),
-		digest: layer.Digest,
-		total:  layer.Size,
-		fn:     fn,
+	b.nextURL = make(chan *url.URL, 1)
+	b.nextURL <- requestURL
+	return nil
+}
+
+// Run uploads blob parts to the upstream. If the upstream supports redirection, parts will be uploaded
+// in parallel as defined by Prepare. Otherwise, parts will be uploaded serially. Run sets b.err on error.
+func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
+	defer blobUploadManager.Delete(b.Digest)
+	ctx, b.CancelFunc = context.WithCancel(ctx)
+
+	p, err := GetBlobsPath(b.Digest)
+	if err != nil {
+		b.err = err
+		return
 	}

-	for offset := int64(0); offset < layer.Size; {
-		chunk := layer.Size - offset
-		if chunk > chunkSize {
-			chunk = chunkSize
-		}
+	f, err := os.Open(p)
+	if err != nil {
+		b.err = err
+		return
+	}
+	defer f.Close()

-		resp, err := uploadBlobChunk(ctx, http.MethodPatch, requestURL, f, offset, chunk, regOpts, &pw)
-		if err != nil {
-			fn(api.ProgressResponse{
-				Status:    fmt.Sprintf("error uploading chunk: %v", err),
-				Digest:    layer.Digest,
-				Total:     layer.Size,
-				Completed: offset,
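+	// each part takes the session URL from nextURL; uploadChunk publishes the
+	// next URL immediately on redirect, so redirected parts run in parallel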
+	g, inner := errgroup.WithContext(ctx)
+	g.SetLimit(numUploadParts)
+	for i := range b.Parts {
+		part := &b.Parts[i]
+		select {
+		case <-inner.Done():
+		case requestURL := <-b.nextURL:
+			g.Go(func() error {
+				for try := 0; try < maxRetries; try++ {
+					r := io.NewSectionReader(f, part.Offset, part.Size)
+					err := b.uploadChunk(inner, http.MethodPatch, requestURL, r, part, opts)
+					switch {
+					case errors.Is(err, context.Canceled):
+						return err
+					case errors.Is(err, errMaxRetriesExceeded):
+						return err
+					case err != nil:
+						log.Printf("%s part %d attempt %d failed: %v, retrying", b.Digest[7:19], part.N, try, err)
+						continue
+					}
+
+					return nil
+				}
+
+				return errMaxRetriesExceeded
 			})
-
-			return err
 		}
+	}

-		offset += chunk
-		location := resp.Header.Get("Docker-Upload-Location")
-		if location == "" {
-			location = resp.Header.Get("Location")
-		}
+	if err := g.Wait(); err != nil {
+		b.err = err
+		return
+	}

-		requestURL, err = url.Parse(location)
-		if err != nil {
-			return err
-		}
+	requestURL := <-b.nextURL
+
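+	// the etag is the MD5 of the concatenated per-part MD5 digests followed
+	// by a dash and the part count, matching the S3 multipart ETag format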
+	var sb strings.Builder
+	for _, part := range b.Parts {
+		sb.Write(part.Sum(nil))
 	}

+	md5sum := md5.Sum([]byte(sb.String()))
+
 	values := requestURL.Query()
-	values.Add("digest", layer.Digest)
+	values.Add("digest", b.Digest)
+	values.Add("etag", fmt.Sprintf("%x-%d", md5sum, len(b.Parts)))
 	requestURL.RawQuery = values.Encode()

 	headers := make(http.Header)
 	headers.Set("Content-Type", "application/octet-stream")
 	headers.Set("Content-Length", "0")

-	// finish the upload
-	resp, err := makeRequest(ctx, "PUT", requestURL, headers, nil, regOpts)
+	resp, err := makeRequestWithRetry(ctx, "PUT", requestURL, headers, nil, opts)
 	if err != nil {
-		log.Printf("couldn't finish upload: %v", err)
-		return err
+		b.err = err
+		return
 	}
 	defer resp.Body.Close()

-	if resp.StatusCode >= http.StatusBadRequest {
-		body, _ := io.ReadAll(resp.Body)
-		return fmt.Errorf("on finish upload registry responded with code %d: %v", resp.StatusCode, string(body))
-	}
-	return nil
+	b.done = true
 }

-func uploadBlobChunk(ctx context.Context, method string, requestURL *url.URL, r io.ReaderAt, offset, limit int64, opts *RegistryOptions, pw *ProgressWriter) (*http.Response, error) {
-	sectionReader := io.NewSectionReader(r, offset, limit)
-
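+// uploadChunk uploads one part. On a 307 it immediately passes the session
+// URL to the next part and PUTs this part's data to the redirect target;
+// otherwise the URL returned by a successful PATCH is passed along.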
+func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL *url.URL, rs io.ReadSeeker, part *blobUploadPart, opts *RegistryOptions) error {
 	headers := make(http.Header)
 	headers.Set("Content-Type", "application/octet-stream")
-	headers.Set("Content-Length", strconv.Itoa(int(limit)))
+	headers.Set("Content-Length", fmt.Sprintf("%d", part.Size))
 	headers.Set("X-Redirect-Uploads", "1")

 	if method == http.MethodPatch {
-		headers.Set("Content-Range", fmt.Sprintf("%d-%d", offset, offset+sectionReader.Size()-1))
+		headers.Set("Content-Range", fmt.Sprintf("%d-%d", part.Offset, part.Offset+part.Size-1))
 	}

-	for try := 0; try < maxRetries; try++ {
-		resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(sectionReader, pw), opts)
-		if err != nil && !errors.Is(err, io.EOF) {
-			return nil, err
-		}
-		defer resp.Body.Close()
+	buw := blobUploadWriter{blobUpload: b}
+	resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(rs, io.MultiWriter(&buw, part.Hash)), opts)
+	if err != nil {
+		return err
+	}
+	defer resp.Body.Close()

-		switch {
-		case resp.StatusCode == http.StatusTemporaryRedirect:
-			location, err := resp.Location()
-			if err != nil {
-				return nil, err
-			}
+	location := resp.Header.Get("Docker-Upload-Location")
+	if location == "" {
+		location = resp.Header.Get("Location")
+	}
+
+	nextURL, err := url.Parse(location)
+	if err != nil {
+		return err
+	}
+
+	switch {
+	case resp.StatusCode == http.StatusTemporaryRedirect:
+		b.nextURL <- nextURL
+
+		redirectURL, err := resp.Location()
+		if err != nil {
+			return err
+		}

-			pw.completed = offset
-			if _, err := uploadBlobChunk(ctx, http.MethodPut, location, r, offset, limit, nil, pw); err != nil {
-				// retry
-				log.Printf("retrying redirected upload: %v", err)
+		for try := 0; try < maxRetries; try++ {
+			rs.Seek(0, io.SeekStart)
+			b.Completed.Add(-buw.written)
+			buw.written = 0
+			part.Hash = md5.New()
+			err := b.uploadChunk(ctx, http.MethodPut, redirectURL, rs, part, nil)
+			switch {
+			case errors.Is(err, context.Canceled):
+				return err
+			case errors.Is(err, errMaxRetriesExceeded):
+				return err
+			case err != nil:
+				log.Printf("%s part %d attempt %d failed: %v, retrying", b.Digest[7:19], part.N, try, err)
 				continue
 			}

-			return resp, nil
-		case resp.StatusCode == http.StatusUnauthorized:
-			auth := resp.Header.Get("www-authenticate")
-			authRedir := ParseAuthRedirectString(auth)
-			token, err := getAuthToken(ctx, authRedir)
-			if err != nil {
-				return nil, err
-			}
+			return nil
+		}

-			opts.Token = token
+		return errMaxRetriesExceeded

-			pw.completed = offset
-			sectionReader = io.NewSectionReader(r, offset, limit)
-			continue
-		case resp.StatusCode >= http.StatusBadRequest:
-			body, _ := io.ReadAll(resp.Body)
-			return nil, fmt.Errorf("on upload registry responded with code %d: %s", resp.StatusCode, body)
+	case resp.StatusCode == http.StatusUnauthorized:
+		auth := resp.Header.Get("www-authenticate")
+		authRedir := ParseAuthRedirectString(auth)
+		token, err := getAuthToken(ctx, authRedir)
+		if err != nil {
+			return err
 		}

-		return resp, nil
+		opts.Token = token
+		fallthrough
+	case resp.StatusCode >= http.StatusBadRequest:
+		body, err := io.ReadAll(resp.Body)
+		if err != nil {
+			return err
+		}
+
+		rs.Seek(0, io.SeekStart)
+		b.Completed.Add(-buw.written)
+		buw.written = 0
+		return fmt.Errorf("http status %s: %s", resp.Status, body)
+	}
+
+	if method == http.MethodPatch {
+		b.nextURL <- nextURL
 	}

-	return nil, fmt.Errorf("max retries exceeded")
+	return nil
+}
+
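+// acquire registers a waiter; release removes one and cancels the upload
+// when the last waiter detaches.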
+func (b *blobUpload) acquire() {
+	b.references.Add(1)
 }

-type ProgressWriter struct {
-	status    string
-	digest    string
-	bucket    int64
-	completed int64
-	total     int64
-	fn        func(api.ProgressResponse)
-	mu        sync.Mutex
+func (b *blobUpload) release() {
+	if b.references.Add(-1) == 0 {
+		b.CancelFunc()
+	}
 }

-func (pw *ProgressWriter) Write(b []byte) (int, error) {
-	pw.mu.Lock()
-	defer pw.mu.Unlock()
-
-	n := len(b)
-	pw.bucket += int64(n)
-
-	// throttle status updates to not spam the client
-	if pw.bucket >= 1024*1024 || pw.completed+pw.bucket >= pw.total {
-		pw.completed += pw.bucket
-		pw.fn(api.ProgressResponse{
-			Status:    pw.status,
-			Digest:    pw.digest,
-			Total:     pw.total,
-			Completed: pw.completed,
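+// Wait reports progress through fn roughly every 60ms until the upload
+// completes, fails, or ctx is done.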
+func (b *blobUpload) Wait(ctx context.Context, fn func(api.ProgressResponse)) error {
+	b.acquire()
+	defer b.release()
+
+	ticker := time.NewTicker(60 * time.Millisecond)
+	defer ticker.Stop()
+	for {
+		select {
+		case <-ticker.C:
+		case <-ctx.Done():
+			return ctx.Err()
+		}
+
+		fn(api.ProgressResponse{
+			Status:    fmt.Sprintf("uploading %s", b.Digest),
+			Digest:    b.Digest,
+			Total:     b.Total,
+			Completed: b.Completed.Load(),
 		})

-		pw.bucket = 0
+		if b.done || b.err != nil {
+			return b.err
+		}
 	}
+}

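+// blobUploadWriter counts bytes as a part is read so progress can be
+// reported, and unwound when a part is retried.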
+type blobUploadWriter struct {
+	written int64
+	*blobUpload
+}
+
+func (b *blobUploadWriter) Write(p []byte) (n int, err error) {
+	n = len(p)
+	b.written += int64(n)
+	b.Completed.Add(int64(n))
 	return n, nil
 }
+
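+// uploadBlob returns immediately if the layer already exists upstream;
+// otherwise it joins the in-flight upload for this digest, starting one
+// first if necessary.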
+func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *RegistryOptions, fn func(api.ProgressResponse)) error {
+	requestURL := mp.BaseURL()
+	requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)
+
+	resp, err := makeRequest(ctx, "HEAD", requestURL, nil, nil, opts)
+	if err != nil {
+		return err
+	}
+	defer resp.Body.Close()
+
+	switch resp.StatusCode {
+	case http.StatusNotFound:
+	case http.StatusOK:
+		fn(api.ProgressResponse{
+			Status:    fmt.Sprintf("uploading %s", layer.Digest),
+			Digest:    layer.Digest,
+			Total:     layer.Size,
+			Completed: layer.Size,
+		})
+
+		return nil
+	default:
+		return fmt.Errorf("unexpected status code %d", resp.StatusCode)
+	}
+
+	data, ok := blobUploadManager.LoadOrStore(layer.Digest, &blobUpload{Layer: layer})
+	upload := data.(*blobUpload)
+	if !ok {
+		requestURL := mp.BaseURL()
+		requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs/uploads/")
+		if err := upload.Prepare(ctx, requestURL, opts); err != nil {
+			blobUploadManager.Delete(layer.Digest)
+			return err
+		}
+
+		go upload.Run(context.Background(), opts)
+	}
+
+	return upload.Wait(ctx, fn)
+}