download.go 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. package server
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "log"
  8. "net/http"
  9. "os"
  10. "path"
  11. "strconv"
  12. "sync"
  13. "time"
  14. "github.com/jmorganca/ollama/api"
  15. )
  16. type FileDownload struct {
  17. Digest string
  18. FilePath string
  19. Total int64
  20. Completed int64
  21. }
  22. var inProgress sync.Map // map of digests currently being downloaded to their current download progress
  23. type downloadOpts struct {
  24. mp ModelPath
  25. digest string
  26. regOpts *RegistryOptions
  27. fn func(api.ProgressResponse)
  28. retry int // track the number of retries on this download
  29. }
  30. const maxRetry = 3
  31. // downloadBlob downloads a blob from the registry and stores it in the blobs directory
  32. func downloadBlob(ctx context.Context, opts downloadOpts) error {
  33. fp, err := GetBlobsPath(opts.digest)
  34. if err != nil {
  35. return err
  36. }
  37. if fi, _ := os.Stat(fp); fi != nil {
  38. // we already have the file, so return
  39. opts.fn(api.ProgressResponse{
  40. Digest: opts.digest,
  41. Total: int(fi.Size()),
  42. Completed: int(fi.Size()),
  43. })
  44. return nil
  45. }
  46. fileDownload := &FileDownload{
  47. Digest: opts.digest,
  48. FilePath: fp,
  49. Total: 1, // dummy value to indicate that we don't know the total size yet
  50. Completed: 0,
  51. }
  52. _, downloading := inProgress.LoadOrStore(opts.digest, fileDownload)
  53. if downloading {
  54. // this is another client requesting the server to download the same blob concurrently
  55. return monitorDownload(ctx, opts, fileDownload)
  56. }
  57. if err := doDownload(ctx, opts, fileDownload); err != nil {
  58. if errors.Is(err, errDownload) && opts.retry < maxRetry {
  59. opts.retry++
  60. log.Print(err)
  61. log.Printf("retrying download of %s", opts.digest)
  62. return downloadBlob(ctx, opts)
  63. }
  64. return err
  65. }
  66. return nil
  67. }
  68. var downloadMu sync.Mutex // mutex to check to resume a download while monitoring
  69. // monitorDownload monitors the download progress of a blob and resumes it if it is interrupted
  70. func monitorDownload(ctx context.Context, opts downloadOpts, f *FileDownload) error {
  71. tick := time.NewTicker(time.Second)
  72. for range tick.C {
  73. done, resume, err := func() (bool, bool, error) {
  74. downloadMu.Lock()
  75. defer downloadMu.Unlock()
  76. val, downloading := inProgress.Load(f.Digest)
  77. if !downloading {
  78. // check once again if the download is complete
  79. if fi, _ := os.Stat(f.FilePath); fi != nil {
  80. // successful download while monitoring
  81. opts.fn(api.ProgressResponse{
  82. Digest: f.Digest,
  83. Total: int(fi.Size()),
  84. Completed: int(fi.Size()),
  85. })
  86. return true, false, nil
  87. }
  88. // resume the download
  89. inProgress.Store(f.Digest, f) // store the file download again to claim the resume
  90. return false, true, nil
  91. }
  92. f, ok := val.(*FileDownload)
  93. if !ok {
  94. return false, false, fmt.Errorf("invalid type for in progress download: %T", val)
  95. }
  96. opts.fn(api.ProgressResponse{
  97. Status: fmt.Sprintf("downloading %s", f.Digest),
  98. Digest: f.Digest,
  99. Total: int(f.Total),
  100. Completed: int(f.Completed),
  101. })
  102. return false, false, nil
  103. }()
  104. if err != nil {
  105. return err
  106. }
  107. if done {
  108. // done downloading
  109. return nil
  110. }
  111. if resume {
  112. return doDownload(ctx, opts, f)
  113. }
  114. }
  115. return nil
  116. }
  117. var (
  118. chunkSize = 1024 * 1024 // 1 MiB in bytes
  119. errDownload = fmt.Errorf("download failed")
  120. )
  121. // doDownload downloads a blob from the registry and stores it in the blobs directory
  122. func doDownload(ctx context.Context, opts downloadOpts, f *FileDownload) error {
  123. defer inProgress.Delete(f.Digest)
  124. var size int64
  125. fi, err := os.Stat(f.FilePath + "-partial")
  126. switch {
  127. case errors.Is(err, os.ErrNotExist):
  128. // noop, file doesn't exist so create it
  129. case err != nil:
  130. return fmt.Errorf("stat: %w", err)
  131. default:
  132. size = fi.Size()
  133. // Ensure the size is divisible by the chunk size by removing excess bytes
  134. size -= size % int64(chunkSize)
  135. err := os.Truncate(f.FilePath+"-partial", size)
  136. if err != nil {
  137. return fmt.Errorf("truncate: %w", err)
  138. }
  139. }
  140. url := fmt.Sprintf("%s/v2/%s/blobs/%s", opts.mp.Registry, opts.mp.GetNamespaceRepository(), f.Digest)
  141. headers := make(http.Header)
  142. headers.Set("Range", fmt.Sprintf("bytes=%d-", size))
  143. resp, err := makeRequest(ctx, "GET", url, headers, nil, opts.regOpts)
  144. if err != nil {
  145. log.Printf("couldn't download blob: %v", err)
  146. return fmt.Errorf("%w: %w", errDownload, err)
  147. }
  148. defer resp.Body.Close()
  149. if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent {
  150. body, _ := io.ReadAll(resp.Body)
  151. return fmt.Errorf("%w: on download registry responded with code %d: %v", errDownload, resp.StatusCode, string(body))
  152. }
  153. err = os.MkdirAll(path.Dir(f.FilePath), 0o700)
  154. if err != nil {
  155. return fmt.Errorf("make blobs directory: %w", err)
  156. }
  157. remaining, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
  158. f.Completed = size
  159. f.Total = remaining + f.Completed
  160. inProgress.Store(f.Digest, f)
  161. out, err := os.OpenFile(f.FilePath+"-partial", os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
  162. if err != nil {
  163. return fmt.Errorf("open file: %w", err)
  164. }
  165. defer out.Close()
  166. outerLoop:
  167. for {
  168. select {
  169. case <-ctx.Done():
  170. // handle client request cancellation
  171. inProgress.Delete(f.Digest)
  172. return nil
  173. default:
  174. opts.fn(api.ProgressResponse{
  175. Status: fmt.Sprintf("downloading %s", f.Digest),
  176. Digest: f.Digest,
  177. Total: int(f.Total),
  178. Completed: int(f.Completed),
  179. })
  180. if f.Completed >= f.Total {
  181. if err := out.Close(); err != nil {
  182. return err
  183. }
  184. if err := os.Rename(f.FilePath+"-partial", f.FilePath); err != nil {
  185. opts.fn(api.ProgressResponse{
  186. Status: fmt.Sprintf("error renaming file: %v", err),
  187. Digest: f.Digest,
  188. Total: int(f.Total),
  189. Completed: int(f.Completed),
  190. })
  191. return err
  192. }
  193. break outerLoop
  194. }
  195. }
  196. n, err := io.CopyN(out, resp.Body, int64(chunkSize))
  197. if err != nil && !errors.Is(err, io.EOF) {
  198. return fmt.Errorf("%w: %w", errDownload, err)
  199. }
  200. f.Completed += n
  201. inProgress.Store(f.Digest, f)
  202. }
  203. log.Printf("success getting %s\n", f.Digest)
  204. return nil
  205. }