download.go 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. package server
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "log"
  8. "net/http"
  9. "os"
  10. "path/filepath"
  11. "strconv"
  12. "sync"
  13. "time"
  14. "github.com/jmorganca/ollama/api"
  15. )
  // FileDownload describes one blob download and how far it has progressed.
  type FileDownload struct {
  	// Digest identifies the blob; it is also the key used in inProgress.
  	Digest string
  	// FilePath is the final location of the blob. While the download is
  	// running, data is written to FilePath + "-partial".
  	FilePath string
  	// Total is the expected size of the blob in bytes. It holds a dummy
  	// value of 1 until the real size is known.
  	Total int64
  	// Completed is the number of bytes written so far.
  	Completed int64
  }
  // inProgress tracks the downloads that are currently running. It maps a blob
  // digest to the *FileDownload holding that download's current progress.
  var inProgress sync.Map
  23. type downloadOpts struct {
  24. mp ModelPath
  25. digest string
  26. regOpts *RegistryOptions
  27. fn func(api.ProgressResponse)
  28. retry int // track the number of retries on this download
  29. }
  // maxRetry is the maximum number of times a failed download is retried.
  const maxRetry = 3
  31. // downloadBlob downloads a blob from the registry and stores it in the blobs directory
  32. func downloadBlob(ctx context.Context, opts downloadOpts) error {
  33. fp, err := GetBlobsPath(opts.digest)
  34. if err != nil {
  35. return err
  36. }
  37. if fi, _ := os.Stat(fp); fi != nil {
  38. // we already have the file, so return
  39. opts.fn(api.ProgressResponse{
  40. Digest: opts.digest,
  41. Total: int(fi.Size()),
  42. Completed: int(fi.Size()),
  43. })
  44. return nil
  45. }
  46. fileDownload := &FileDownload{
  47. Digest: opts.digest,
  48. FilePath: fp,
  49. Total: 1, // dummy value to indicate that we don't know the total size yet
  50. Completed: 0,
  51. }
  52. _, downloading := inProgress.LoadOrStore(opts.digest, fileDownload)
  53. if downloading {
  54. // this is another client requesting the server to download the same blob concurrently
  55. return monitorDownload(ctx, opts, fileDownload)
  56. }
  57. if err := doDownload(ctx, opts, fileDownload); err != nil {
  58. if errors.Is(err, errDownload) && opts.retry < maxRetry {
  59. opts.retry++
  60. log.Print(err)
  61. log.Printf("retrying download of %s", opts.digest)
  62. return downloadBlob(ctx, opts)
  63. }
  64. return err
  65. }
  66. return nil
  67. }
  // downloadMu serializes the check in monitorDownload that decides whether a
  // monitored download has finished or has to be resumed, so that only one
  // monitor claims the resume.
  var downloadMu sync.Mutex
  69. // monitorDownload monitors the download progress of a blob and resumes it if it is interrupted
  70. func monitorDownload(ctx context.Context, opts downloadOpts, f *FileDownload) error {
  71. tick := time.NewTicker(time.Second)
  72. for range tick.C {
  73. done, resume, err := func() (bool, bool, error) {
  74. downloadMu.Lock()
  75. defer downloadMu.Unlock()
  76. val, downloading := inProgress.Load(f.Digest)
  77. if !downloading {
  78. // check once again if the download is complete
  79. if fi, _ := os.Stat(f.FilePath); fi != nil {
  80. // successful download while monitoring
  81. opts.fn(api.ProgressResponse{
  82. Digest: f.Digest,
  83. Total: int(fi.Size()),
  84. Completed: int(fi.Size()),
  85. })
  86. return true, false, nil
  87. }
  88. // resume the download
  89. inProgress.Store(f.Digest, f) // store the file download again to claim the resume
  90. return false, true, nil
  91. }
  92. f, ok := val.(*FileDownload)
  93. if !ok {
  94. return false, false, fmt.Errorf("invalid type for in progress download: %T", val)
  95. }
  96. opts.fn(api.ProgressResponse{
  97. Status: fmt.Sprintf("downloading %s", f.Digest),
  98. Digest: f.Digest,
  99. Total: int(f.Total),
  100. Completed: int(f.Completed),
  101. })
  102. return false, false, nil
  103. }()
  104. if err != nil {
  105. return err
  106. }
  107. if done {
  108. // done downloading
  109. return nil
  110. }
  111. if resume {
  112. return doDownload(ctx, opts, f)
  113. }
  114. }
  115. return nil
  116. }
  var (
  	// chunkSize is the number of bytes copied per step of a download (1 MiB).
  	// Partial files are also trimmed to a multiple of it before resuming.
  	chunkSize = 1 << 20

  	// errDownload marks failures of the download itself; downloadBlob retries
  	// errors that wrap it.
  	errDownload = fmt.Errorf("download failed")
  )
  121. // doDownload downloads a blob from the registry and stores it in the blobs directory
  122. func doDownload(ctx context.Context, opts downloadOpts, f *FileDownload) error {
  123. defer inProgress.Delete(f.Digest)
  124. var size int64
  125. fi, err := os.Stat(f.FilePath + "-partial")
  126. switch {
  127. case errors.Is(err, os.ErrNotExist):
  128. // noop, file doesn't exist so create it
  129. case err != nil:
  130. return fmt.Errorf("stat: %w", err)
  131. default:
  132. size = fi.Size()
  133. // Ensure the size is divisible by the chunk size by removing excess bytes
  134. size -= size % int64(chunkSize)
  135. err := os.Truncate(f.FilePath+"-partial", size)
  136. if err != nil {
  137. return fmt.Errorf("truncate: %w", err)
  138. }
  139. }
  140. requestURL := opts.mp.BaseURL()
  141. requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", f.Digest)
  142. headers := make(http.Header)
  143. headers.Set("Range", fmt.Sprintf("bytes=%d-", size))
  144. resp, err := makeRequest(ctx, "GET", requestURL, headers, nil, opts.regOpts)
  145. if err != nil {
  146. log.Printf("couldn't download blob: %v", err)
  147. return fmt.Errorf("%w: %w", errDownload, err)
  148. }
  149. defer resp.Body.Close()
  150. if resp.StatusCode >= http.StatusBadRequest {
  151. body, _ := io.ReadAll(resp.Body)
  152. return fmt.Errorf("%w: on download registry responded with code %d: %v", errDownload, resp.StatusCode, string(body))
  153. }
  154. err = os.MkdirAll(filepath.Dir(f.FilePath), 0o700)
  155. if err != nil {
  156. return fmt.Errorf("make blobs directory: %w", err)
  157. }
  158. remaining, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
  159. f.Completed = size
  160. f.Total = remaining + f.Completed
  161. inProgress.Store(f.Digest, f)
  162. out, err := os.OpenFile(f.FilePath+"-partial", os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
  163. if err != nil {
  164. return fmt.Errorf("open file: %w", err)
  165. }
  166. defer out.Close()
  167. outerLoop:
  168. for {
  169. select {
  170. case <-ctx.Done():
  171. // handle client request cancellation
  172. inProgress.Delete(f.Digest)
  173. return nil
  174. default:
  175. opts.fn(api.ProgressResponse{
  176. Status: fmt.Sprintf("downloading %s", f.Digest),
  177. Digest: f.Digest,
  178. Total: int(f.Total),
  179. Completed: int(f.Completed),
  180. })
  181. if f.Completed >= f.Total {
  182. if err := out.Close(); err != nil {
  183. return err
  184. }
  185. if err := os.Rename(f.FilePath+"-partial", f.FilePath); err != nil {
  186. opts.fn(api.ProgressResponse{
  187. Status: fmt.Sprintf("error renaming file: %v", err),
  188. Digest: f.Digest,
  189. Total: int(f.Total),
  190. Completed: int(f.Completed),
  191. })
  192. return err
  193. }
  194. break outerLoop
  195. }
  196. }
  197. n, err := io.CopyN(out, resp.Body, int64(chunkSize))
  198. if err != nil && !errors.Is(err, io.EOF) {
  199. return fmt.Errorf("%w: %w", errDownload, err)
  200. }
  201. f.Completed += n
  202. inProgress.Store(f.Digest, f)
  203. }
  204. log.Printf("success getting %s\n", f.Digest)
  205. return nil
  206. }