download.go 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. package server
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "log"
  8. "net/http"
  9. "os"
  10. "path"
  11. "strconv"
  12. "sync"
  13. "time"
  14. "github.com/jmorganca/ollama/api"
  15. )
  // FileDownload records one blob download and how far it has progressed.
// A pointer to it is shared through inProgress, so requests other than the one
// performing the download may read its progress fields.
//
// NOTE(review): doDownload writes Total and Completed without synchronization
// while monitorDownload reads them under downloadMu only; under the Go memory
// model that is a data race (go test -race would report it).
type FileDownload struct {
	Digest   string // blob digest; also the key of this download in inProgress
	FilePath string // where the finished blob is stored

	// Total is the expected size in bytes (1 is a placeholder until the size is
	// known); Completed is the number of bytes written so far.
	Total, Completed int64
}

// inProgress maps the digest of each blob currently being downloaded to its
// *FileDownload, so concurrent requests for one blob share a single download.
var inProgress sync.Map
  23. // downloadBlob downloads a blob from the registry and stores it in the blobs directory
  24. func downloadBlob(ctx context.Context, mp ModelPath, digest string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
  25. fp, err := GetBlobsPath(digest)
  26. if err != nil {
  27. return err
  28. }
  29. if fi, _ := os.Stat(fp); fi != nil {
  30. // we already have the file, so return
  31. fn(api.ProgressResponse{
  32. Digest: digest,
  33. Total: int(fi.Size()),
  34. Completed: int(fi.Size()),
  35. })
  36. return nil
  37. }
  38. fileDownload := &FileDownload{
  39. Digest: digest,
  40. FilePath: fp,
  41. Total: 1, // dummy value to indicate that we don't know the total size yet
  42. Completed: 0,
  43. }
  44. _, downloading := inProgress.LoadOrStore(digest, fileDownload)
  45. if downloading {
  46. // this is another client requesting the server to download the same blob concurrently
  47. return monitorDownload(ctx, mp, regOpts, fileDownload, fn)
  48. }
  49. return doDownload(ctx, mp, regOpts, fileDownload, fn)
  50. }
  51. var downloadMu sync.Mutex // mutex to check to resume a download while monitoring
  52. // monitorDownload monitors the download progress of a blob and resumes it if it is interrupted
  53. func monitorDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions, f *FileDownload, fn func(api.ProgressResponse)) error {
  54. tick := time.NewTicker(time.Second)
  55. for range tick.C {
  56. done, resume, err := func() (bool, bool, error) {
  57. downloadMu.Lock()
  58. defer downloadMu.Unlock()
  59. val, downloading := inProgress.Load(f.Digest)
  60. if !downloading {
  61. // check once again if the download is complete
  62. if fi, _ := os.Stat(f.FilePath); fi != nil {
  63. // successful download while monitoring
  64. fn(api.ProgressResponse{
  65. Digest: f.Digest,
  66. Total: int(fi.Size()),
  67. Completed: int(fi.Size()),
  68. })
  69. return true, false, nil
  70. }
  71. // resume the download
  72. inProgress.Store(f.Digest, f) // store the file download again to claim the resume
  73. return false, true, nil
  74. }
  75. f, ok := val.(*FileDownload)
  76. if !ok {
  77. return false, false, fmt.Errorf("invalid type for in progress download: %T", val)
  78. }
  79. fn(api.ProgressResponse{
  80. Status: fmt.Sprintf("downloading %s", f.Digest),
  81. Digest: f.Digest,
  82. Total: int(f.Total),
  83. Completed: int(f.Completed),
  84. })
  85. return false, false, nil
  86. }()
  87. if err != nil {
  88. return err
  89. }
  90. if done {
  91. // done downloading
  92. return nil
  93. }
  94. if resume {
  95. return doDownload(ctx, mp, regOpts, f, fn)
  96. }
  97. }
  98. return nil
  99. }
  100. var chunkSize = 1024 * 1024 // 1 MiB in bytes
  101. // doDownload downloads a blob from the registry and stores it in the blobs directory
  102. func doDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions, f *FileDownload, fn func(api.ProgressResponse)) error {
  103. var size int64
  104. fi, err := os.Stat(f.FilePath + "-partial")
  105. switch {
  106. case errors.Is(err, os.ErrNotExist):
  107. // noop, file doesn't exist so create it
  108. case err != nil:
  109. return fmt.Errorf("stat: %w", err)
  110. default:
  111. size = fi.Size()
  112. // Ensure the size is divisible by the chunk size by removing excess bytes
  113. size -= size % int64(chunkSize)
  114. err := os.Truncate(f.FilePath+"-partial", size)
  115. if err != nil {
  116. return fmt.Errorf("truncate: %w", err)
  117. }
  118. }
  119. url := fmt.Sprintf("%s/v2/%s/blobs/%s", mp.Registry, mp.GetNamespaceRepository(), f.Digest)
  120. headers := map[string]string{
  121. "Range": fmt.Sprintf("bytes=%d-", size),
  122. }
  123. resp, err := makeRequest(ctx, "GET", url, headers, nil, regOpts)
  124. if err != nil {
  125. log.Printf("couldn't download blob: %v", err)
  126. return err
  127. }
  128. defer resp.Body.Close()
  129. if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent {
  130. body, _ := io.ReadAll(resp.Body)
  131. return fmt.Errorf("on download registry responded with code %d: %v", resp.StatusCode, string(body))
  132. }
  133. err = os.MkdirAll(path.Dir(f.FilePath), 0o700)
  134. if err != nil {
  135. return fmt.Errorf("make blobs directory: %w", err)
  136. }
  137. remaining, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
  138. f.Completed = size
  139. f.Total = remaining + f.Completed
  140. inProgress.Store(f.Digest, f)
  141. out, err := os.OpenFile(f.FilePath+"-partial", os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
  142. if err != nil {
  143. return fmt.Errorf("open file: %w", err)
  144. }
  145. defer out.Close()
  146. outerLoop:
  147. for {
  148. select {
  149. case <-ctx.Done():
  150. // handle client request cancellation
  151. inProgress.Delete(f.Digest)
  152. return nil
  153. default:
  154. fn(api.ProgressResponse{
  155. Status: fmt.Sprintf("downloading %s", f.Digest),
  156. Digest: f.Digest,
  157. Total: int(f.Total),
  158. Completed: int(f.Completed),
  159. })
  160. if f.Completed >= f.Total {
  161. if err := out.Close(); err != nil {
  162. return err
  163. }
  164. if err := os.Rename(f.FilePath+"-partial", f.FilePath); err != nil {
  165. fn(api.ProgressResponse{
  166. Status: fmt.Sprintf("error renaming file: %v", err),
  167. Digest: f.Digest,
  168. Total: int(f.Total),
  169. Completed: int(f.Completed),
  170. })
  171. return err
  172. }
  173. break outerLoop
  174. }
  175. }
  176. n, err := io.CopyN(out, resp.Body, int64(chunkSize))
  177. if err != nil && !errors.Is(err, io.EOF) {
  178. return err
  179. }
  180. f.Completed += n
  181. inProgress.Store(f.Digest, f)
  182. }
  183. inProgress.Delete(f.Digest)
  184. log.Printf("success getting %s\n", f.Digest)
  185. return nil
  186. }