upload.go 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. package server
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "log"
  8. "net/http"
  9. "net/url"
  10. "os"
  11. "strconv"
  12. "sync"
  13. "github.com/jmorganca/ollama/api"
  14. )
  // Upload chunk sizes, in bytes, selected by startUpload depending on which
  // upload-location header the registry returns.
const (
	// redirectChunkSize (1 GiB) is used when the registry answers with a
	// Docker-Upload-Location header.
	redirectChunkSize int64 = 1 << 30

	// regularChunkSize (95 MiB) is used when only the standard Location
	// header is present.
	regularChunkSize int64 = 95 << 20
)
  19. func startUpload(ctx context.Context, mp ModelPath, layer *Layer, regOpts *RegistryOptions) (*url.URL, int64, error) {
  20. requestURL := mp.BaseURL()
  21. requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs/uploads/")
  22. if layer.From != "" {
  23. values := requestURL.Query()
  24. values.Add("mount", layer.Digest)
  25. values.Add("from", layer.From)
  26. requestURL.RawQuery = values.Encode()
  27. }
  28. resp, err := makeRequestWithRetry(ctx, "POST", requestURL, nil, nil, regOpts)
  29. if err != nil {
  30. log.Printf("couldn't start upload: %v", err)
  31. return nil, 0, err
  32. }
  33. defer resp.Body.Close()
  34. location := resp.Header.Get("Docker-Upload-Location")
  35. chunkSize := redirectChunkSize
  36. if location == "" {
  37. location = resp.Header.Get("Location")
  38. chunkSize = regularChunkSize
  39. }
  40. locationURL, err := url.Parse(location)
  41. if err != nil {
  42. return nil, 0, err
  43. }
  44. return locationURL, chunkSize, nil
  45. }
  46. func uploadBlob(ctx context.Context, requestURL *url.URL, layer *Layer, chunkSize int64, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
  47. // TODO allow resumability
  48. // TODO allow canceling uploads via DELETE
  49. fp, err := GetBlobsPath(layer.Digest)
  50. if err != nil {
  51. return err
  52. }
  53. f, err := os.Open(fp)
  54. if err != nil {
  55. return err
  56. }
  57. defer f.Close()
  58. pw := ProgressWriter{
  59. status: fmt.Sprintf("uploading %s", layer.Digest),
  60. digest: layer.Digest,
  61. total: layer.Size,
  62. fn: fn,
  63. }
  64. for offset := int64(0); offset < layer.Size; {
  65. chunk := layer.Size - offset
  66. if chunk > chunkSize {
  67. chunk = chunkSize
  68. }
  69. resp, err := uploadBlobChunk(ctx, http.MethodPatch, requestURL, f, offset, chunk, regOpts, &pw)
  70. if err != nil {
  71. fn(api.ProgressResponse{
  72. Status: fmt.Sprintf("error uploading chunk: %v", err),
  73. Digest: layer.Digest,
  74. Total: layer.Size,
  75. Completed: offset,
  76. })
  77. return err
  78. }
  79. offset += chunk
  80. location := resp.Header.Get("Docker-Upload-Location")
  81. if location == "" {
  82. location = resp.Header.Get("Location")
  83. }
  84. requestURL, err = url.Parse(location)
  85. if err != nil {
  86. return err
  87. }
  88. }
  89. values := requestURL.Query()
  90. values.Add("digest", layer.Digest)
  91. requestURL.RawQuery = values.Encode()
  92. headers := make(http.Header)
  93. headers.Set("Content-Type", "application/octet-stream")
  94. headers.Set("Content-Length", "0")
  95. // finish the upload
  96. resp, err := makeRequest(ctx, "PUT", requestURL, headers, nil, regOpts)
  97. if err != nil {
  98. log.Printf("couldn't finish upload: %v", err)
  99. return err
  100. }
  101. defer resp.Body.Close()
  102. if resp.StatusCode >= http.StatusBadRequest {
  103. body, _ := io.ReadAll(resp.Body)
  104. return fmt.Errorf("on finish upload registry responded with code %d: %v", resp.StatusCode, string(body))
  105. }
  106. return nil
  107. }
  108. func uploadBlobChunk(ctx context.Context, method string, requestURL *url.URL, r io.ReaderAt, offset, limit int64, opts *RegistryOptions, pw *ProgressWriter) (*http.Response, error) {
  109. sectionReader := io.NewSectionReader(r, offset, limit)
  110. headers := make(http.Header)
  111. headers.Set("Content-Type", "application/octet-stream")
  112. headers.Set("Content-Length", strconv.Itoa(int(limit)))
  113. headers.Set("X-Redirect-Uploads", "1")
  114. if method == http.MethodPatch {
  115. headers.Set("Content-Range", fmt.Sprintf("%d-%d", offset, offset+sectionReader.Size()-1))
  116. }
  117. for try := 0; try < maxRetries; try++ {
  118. resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(sectionReader, pw), opts)
  119. if err != nil && !errors.Is(err, io.EOF) {
  120. return nil, err
  121. }
  122. defer resp.Body.Close()
  123. switch {
  124. case resp.StatusCode == http.StatusTemporaryRedirect:
  125. location, err := resp.Location()
  126. if err != nil {
  127. return nil, err
  128. }
  129. pw.completed = offset
  130. if _, err := uploadBlobChunk(ctx, http.MethodPut, location, r, offset, limit, nil, pw); err != nil {
  131. // retry
  132. log.Printf("retrying redirected upload: %v", err)
  133. continue
  134. }
  135. return resp, nil
  136. case resp.StatusCode == http.StatusUnauthorized:
  137. auth := resp.Header.Get("www-authenticate")
  138. authRedir := ParseAuthRedirectString(auth)
  139. token, err := getAuthToken(ctx, authRedir)
  140. if err != nil {
  141. return nil, err
  142. }
  143. opts.Token = token
  144. pw.completed = offset
  145. sectionReader = io.NewSectionReader(r, offset, limit)
  146. continue
  147. case resp.StatusCode >= http.StatusBadRequest:
  148. body, _ := io.ReadAll(resp.Body)
  149. return nil, fmt.Errorf("on upload registry responded with code %d: %s", resp.StatusCode, body)
  150. }
  151. return resp, nil
  152. }
  153. return nil, fmt.Errorf("max retries exceeded")
  154. }
  155. type ProgressWriter struct {
  156. status string
  157. digest string
  158. bucket int64
  159. completed int64
  160. total int64
  161. fn func(api.ProgressResponse)
  162. mu sync.Mutex
  163. }
  164. func (pw *ProgressWriter) Write(b []byte) (int, error) {
  165. pw.mu.Lock()
  166. defer pw.mu.Unlock()
  167. n := len(b)
  168. pw.bucket += int64(n)
  169. // throttle status updates to not spam the client
  170. if pw.bucket >= 1024*1024 || pw.completed+pw.bucket >= pw.total {
  171. pw.completed += pw.bucket
  172. pw.fn(api.ProgressResponse{
  173. Status: pw.status,
  174. Digest: pw.digest,
  175. Total: pw.total,
  176. Completed: pw.completed,
  177. })
  178. pw.bucket = 0
  179. }
  180. return n, nil
  181. }