download.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. package server
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "log"
  8. "net/http"
  9. "os"
  10. "path"
  11. "strconv"
  12. "sync"
  13. "time"
  14. "github.com/jmorganca/ollama/api"
  15. )
  // FileDownload holds the bookkeeping for one blob download. A pointer to it is
  // stored in inProgress under its Digest so that concurrent requests for the
  // same blob can read the progress.
  type FileDownload struct {
  	// Digest identifies the blob and is the key used in inProgress.
  	// FilePath is the final location of the blob; while the download is
  	// running the data is written to FilePath + "-partial".
  	Digest, FilePath string

  	// Total is the expected size of the blob in bytes and Completed is the
  	// number of bytes accounted for so far. downloadBlob seeds Total with 1
  	// as a placeholder until the registry has reported the real size.
  	Total, Completed int64
  }
  // inProgress maps the digest of every blob that is currently being downloaded
  // to the *FileDownload describing its progress.
  var inProgress sync.Map
  23. // downloadBlob downloads a blob from the registry and stores it in the blobs directory
  24. func downloadBlob(ctx context.Context, mp ModelPath, digest string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
  25. fp, err := GetBlobsPath(digest)
  26. if err != nil {
  27. return err
  28. }
  29. if fi, _ := os.Stat(fp); fi != nil {
  30. // we already have the file, so return
  31. fn(api.ProgressResponse{
  32. Digest: digest,
  33. Total: int(fi.Size()),
  34. Completed: int(fi.Size()),
  35. })
  36. return nil
  37. }
  38. fileDownload := &FileDownload{
  39. Digest: digest,
  40. FilePath: fp,
  41. Total: 1, // dummy value to indicate that we don't know the total size yet
  42. Completed: 0,
  43. }
  44. _, downloading := inProgress.LoadOrStore(digest, fileDownload)
  45. if downloading {
  46. // this is another client requesting the server to download the same blob concurrently
  47. return monitorDownload(ctx, mp, regOpts, fileDownload, fn)
  48. }
  49. resp, err := requestDownload(ctx, mp, regOpts, fileDownload)
  50. if err != nil {
  51. return err
  52. }
  53. return doDownload(ctx, fileDownload, resp, fn)
  54. }
  // downloadMu is held by monitorDownload while it checks whether a download
  // has disappeared from inProgress and, if so, resumes it, so that only one
  // monitor at a time performs that check-and-resume step.
  var downloadMu sync.Mutex
  56. // monitorDownload monitors the download progress of a blob and resumes it if it is interrupted
  57. func monitorDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions, f *FileDownload, fn func(api.ProgressResponse)) error {
  58. tick := time.NewTicker(time.Second)
  59. for range tick.C {
  60. downloadMu.Lock()
  61. val, downloading := inProgress.Load(f.Digest)
  62. if !downloading {
  63. // check once again if the download is complete
  64. if fi, _ := os.Stat(f.FilePath); fi != nil {
  65. downloadMu.Unlock()
  66. // successfull download while monitoring
  67. fn(api.ProgressResponse{
  68. Digest: f.Digest,
  69. Total: int(fi.Size()),
  70. Completed: int(fi.Size()),
  71. })
  72. return nil
  73. }
  74. // resume the download
  75. resp, err := requestDownload(ctx, mp, regOpts, f)
  76. if err != nil {
  77. return fmt.Errorf("resume: %w", err)
  78. }
  79. inProgress.Store(f.Digest, f)
  80. downloadMu.Unlock()
  81. return doDownload(ctx, f, resp, fn)
  82. }
  83. downloadMu.Unlock()
  84. f, ok := val.(*FileDownload)
  85. if !ok {
  86. return fmt.Errorf("invalid type for in progress download: %T", val)
  87. }
  88. fn(api.ProgressResponse{
  89. Status: fmt.Sprintf("downloading %s", f.Digest),
  90. Digest: f.Digest,
  91. Total: int(f.Total),
  92. Completed: int(f.Completed),
  93. })
  94. }
  95. return nil
  96. }
  // chunkSize is the number of bytes copied per iteration of the download loop
  // (1 MiB). Partial files are truncated to a multiple of it before resuming.
  var chunkSize = 1 << 20
  98. // requestDownload requests a blob from the registry and returns the response
  99. func requestDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions, f *FileDownload) (*http.Response, error) {
  100. var size int64
  101. fi, err := os.Stat(f.FilePath + "-partial")
  102. switch {
  103. case errors.Is(err, os.ErrNotExist):
  104. // noop, file doesn't exist so create it
  105. case err != nil:
  106. return nil, fmt.Errorf("stat: %w", err)
  107. default:
  108. size = fi.Size()
  109. // Ensure the size is divisible by the chunk size by removing excess bytes
  110. size -= size % int64(chunkSize)
  111. err := os.Truncate(f.FilePath+"-partial", size)
  112. if err != nil {
  113. return nil, fmt.Errorf("truncate: %w", err)
  114. }
  115. }
  116. url := fmt.Sprintf("%s/v2/%s/blobs/%s", mp.Registry, mp.GetNamespaceRepository(), f.Digest)
  117. headers := map[string]string{
  118. "Range": fmt.Sprintf("bytes=%d-", size),
  119. }
  120. resp, err := makeRequest("GET", url, headers, nil, regOpts)
  121. if err != nil {
  122. log.Printf("couldn't download blob: %v", err)
  123. return nil, err
  124. }
  125. // resp MUST be closed by doDownload, which should follow this function
  126. if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent {
  127. body, _ := io.ReadAll(resp.Body)
  128. return nil, fmt.Errorf("on download registry responded with code %d: %v", resp.StatusCode, string(body))
  129. }
  130. err = os.MkdirAll(path.Dir(f.FilePath), 0o700)
  131. if err != nil {
  132. return nil, fmt.Errorf("make blobs directory: %w", err)
  133. }
  134. remaining, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
  135. f.Completed = size
  136. f.Total = remaining + f.Completed
  137. inProgress.Store(f.Digest, f)
  138. return resp, nil
  139. }
  140. // doDownload downloads a blob from the registry and stores it in the blobs directory
  141. func doDownload(ctx context.Context, f *FileDownload, resp *http.Response, fn func(api.ProgressResponse)) error {
  142. defer resp.Body.Close()
  143. out, err := os.OpenFile(f.FilePath+"-partial", os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
  144. if err != nil {
  145. return fmt.Errorf("open file: %w", err)
  146. }
  147. defer out.Close()
  148. outerLoop:
  149. for {
  150. select {
  151. case <-ctx.Done():
  152. // handle client request cancellation
  153. inProgress.Delete(f.Digest)
  154. return nil
  155. default:
  156. fn(api.ProgressResponse{
  157. Status: fmt.Sprintf("downloading %s", f.Digest),
  158. Digest: f.Digest,
  159. Total: int(f.Total),
  160. Completed: int(f.Completed),
  161. })
  162. if f.Completed >= f.Total {
  163. if err := out.Close(); err != nil {
  164. return err
  165. }
  166. if err := os.Rename(f.FilePath+"-partial", f.FilePath); err != nil {
  167. fn(api.ProgressResponse{
  168. Status: fmt.Sprintf("error renaming file: %v", err),
  169. Digest: f.Digest,
  170. Total: int(f.Total),
  171. Completed: int(f.Completed),
  172. })
  173. return err
  174. }
  175. break outerLoop
  176. }
  177. }
  178. n, err := io.CopyN(out, resp.Body, int64(chunkSize))
  179. if err != nil && !errors.Is(err, io.EOF) {
  180. return err
  181. }
  182. f.Completed += n
  183. inProgress.Store(f.Digest, f)
  184. }
  185. inProgress.Delete(f.Digest)
  186. log.Printf("success getting %s\n", f.Digest)
  187. return nil
  188. }