123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236 |
- package server
- import (
- "context"
- "errors"
- "fmt"
- "io"
- "log"
- "net/http"
- "os"
- "path"
- "strconv"
- "sync"
- "time"
- "github.com/jmorganca/ollama/api"
- )
// FileDownload records the progress of a single blob download. The record for
// the request that owns a download is stored in inProgress, where a request
// joining the same digest later looks it up to report progress.
type FileDownload struct {
	Digest, FilePath string // blob digest and the path the blob is saved to
	Total, Completed int64  // expected size in bytes (placeholder until known) and bytes written so far
}

// inProgress maps the digest of every blob currently being downloaded to its
// *FileDownload progress record.
var inProgress sync.Map
- type downloadOpts struct {
- mp ModelPath
- digest string
- regOpts *RegistryOptions
- fn func(api.ProgressResponse)
- retry int // track the number of retries on this download
- }
- const maxRetry = 3
- // downloadBlob downloads a blob from the registry and stores it in the blobs directory
- func downloadBlob(ctx context.Context, opts downloadOpts) error {
- fp, err := GetBlobsPath(opts.digest)
- if err != nil {
- return err
- }
- if fi, _ := os.Stat(fp); fi != nil {
- // we already have the file, so return
- opts.fn(api.ProgressResponse{
- Digest: opts.digest,
- Total: int(fi.Size()),
- Completed: int(fi.Size()),
- })
- return nil
- }
- fileDownload := &FileDownload{
- Digest: opts.digest,
- FilePath: fp,
- Total: 1, // dummy value to indicate that we don't know the total size yet
- Completed: 0,
- }
- _, downloading := inProgress.LoadOrStore(opts.digest, fileDownload)
- if downloading {
- // this is another client requesting the server to download the same blob concurrently
- return monitorDownload(ctx, opts, fileDownload)
- }
- if err := doDownload(ctx, opts, fileDownload); err != nil {
- if errors.Is(err, errDownload) && opts.retry < maxRetry {
- opts.retry++
- log.Print(err)
- log.Printf("retrying download of %s", opts.digest)
- return downloadBlob(ctx, opts)
- }
- return err
- }
- return nil
- }
- var downloadMu sync.Mutex // mutex to check to resume a download while monitoring
- // monitorDownload monitors the download progress of a blob and resumes it if it is interrupted
- func monitorDownload(ctx context.Context, opts downloadOpts, f *FileDownload) error {
- tick := time.NewTicker(time.Second)
- for range tick.C {
- done, resume, err := func() (bool, bool, error) {
- downloadMu.Lock()
- defer downloadMu.Unlock()
- val, downloading := inProgress.Load(f.Digest)
- if !downloading {
- // check once again if the download is complete
- if fi, _ := os.Stat(f.FilePath); fi != nil {
- // successful download while monitoring
- opts.fn(api.ProgressResponse{
- Digest: f.Digest,
- Total: int(fi.Size()),
- Completed: int(fi.Size()),
- })
- return true, false, nil
- }
- // resume the download
- inProgress.Store(f.Digest, f) // store the file download again to claim the resume
- return false, true, nil
- }
- f, ok := val.(*FileDownload)
- if !ok {
- return false, false, fmt.Errorf("invalid type for in progress download: %T", val)
- }
- opts.fn(api.ProgressResponse{
- Status: fmt.Sprintf("downloading %s", f.Digest),
- Digest: f.Digest,
- Total: int(f.Total),
- Completed: int(f.Completed),
- })
- return false, false, nil
- }()
- if err != nil {
- return err
- }
- if done {
- // done downloading
- return nil
- }
- if resume {
- return doDownload(ctx, opts, f)
- }
- }
- return nil
- }
var (
	// chunkSize is how many bytes doDownload copies between progress updates.
	// A resumed partial file is also truncated down to a multiple of it.
	chunkSize = 1024 * 1024 // 1 MiB in bytes

	// errDownload marks download failures that downloadBlob will retry.
	errDownload = errors.New("download failed")
)
- // doDownload downloads a blob from the registry and stores it in the blobs directory
- func doDownload(ctx context.Context, opts downloadOpts, f *FileDownload) error {
- defer inProgress.Delete(f.Digest)
- var size int64
- fi, err := os.Stat(f.FilePath + "-partial")
- switch {
- case errors.Is(err, os.ErrNotExist):
- // noop, file doesn't exist so create it
- case err != nil:
- return fmt.Errorf("stat: %w", err)
- default:
- size = fi.Size()
- // Ensure the size is divisible by the chunk size by removing excess bytes
- size -= size % int64(chunkSize)
- err := os.Truncate(f.FilePath+"-partial", size)
- if err != nil {
- return fmt.Errorf("truncate: %w", err)
- }
- }
- url := fmt.Sprintf("%s/v2/%s/blobs/%s", opts.mp.Registry, opts.mp.GetNamespaceRepository(), f.Digest)
- headers := make(http.Header)
- headers.Set("Range", fmt.Sprintf("bytes=%d-", size))
- resp, err := makeRequest(ctx, "GET", url, headers, nil, opts.regOpts)
- if err != nil {
- log.Printf("couldn't download blob: %v", err)
- return fmt.Errorf("%w: %w", errDownload, err)
- }
- defer resp.Body.Close()
- if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent {
- body, _ := io.ReadAll(resp.Body)
- return fmt.Errorf("%w: on download registry responded with code %d: %v", errDownload, resp.StatusCode, string(body))
- }
- err = os.MkdirAll(path.Dir(f.FilePath), 0o700)
- if err != nil {
- return fmt.Errorf("make blobs directory: %w", err)
- }
- remaining, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
- f.Completed = size
- f.Total = remaining + f.Completed
- inProgress.Store(f.Digest, f)
- out, err := os.OpenFile(f.FilePath+"-partial", os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
- if err != nil {
- return fmt.Errorf("open file: %w", err)
- }
- defer out.Close()
- outerLoop:
- for {
- select {
- case <-ctx.Done():
- // handle client request cancellation
- inProgress.Delete(f.Digest)
- return nil
- default:
- opts.fn(api.ProgressResponse{
- Status: fmt.Sprintf("downloading %s", f.Digest),
- Digest: f.Digest,
- Total: int(f.Total),
- Completed: int(f.Completed),
- })
- if f.Completed >= f.Total {
- if err := out.Close(); err != nil {
- return err
- }
- if err := os.Rename(f.FilePath+"-partial", f.FilePath); err != nil {
- opts.fn(api.ProgressResponse{
- Status: fmt.Sprintf("error renaming file: %v", err),
- Digest: f.Digest,
- Total: int(f.Total),
- Completed: int(f.Completed),
- })
- return err
- }
- break outerLoop
- }
- }
- n, err := io.CopyN(out, resp.Body, int64(chunkSize))
- if err != nil && !errors.Is(err, io.EOF) {
- return fmt.Errorf("%w: %w", errDownload, err)
- }
- f.Completed += n
- inProgress.Store(f.Digest, f)
- }
- log.Printf("success getting %s\n", f.Digest)
- return nil
- }
|