Bruce MacDonald 1 год назад
Родитель
Сommit
8228d166ce
1 измененных файлов с 46 добавлено и 49 удалено
  1. 46 49
      server/download.go

+ 46 - 49
server/download.go

@@ -55,11 +55,7 @@ func downloadBlob(ctx context.Context, mp ModelPath, digest string, regOpts *Reg
 		// this is another client requesting the server to download the same blob concurrently
 		return monitorDownload(ctx, mp, regOpts, fileDownload, fn)
 	}
-	resp, err := requestDownload(ctx, mp, regOpts, fileDownload)
-	if err != nil {
-		return err
-	}
-	return doDownload(ctx, fileDownload, resp, fn)
+	return doDownload(ctx, mp, regOpts, fileDownload, fn)
 }
 
 var downloadMu sync.Mutex // mutex to check to resume a download while monitoring
@@ -68,49 +64,55 @@ var downloadMu sync.Mutex // mutex to check to resume a download while monitorin
 func monitorDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions, f *FileDownload, fn func(api.ProgressResponse)) error {
 	tick := time.NewTicker(time.Second)
 	for range tick.C {
-		downloadMu.Lock()
-		val, downloading := inProgress.Load(f.Digest)
-		if !downloading {
-			// check once again if the download is complete
-			if fi, _ := os.Stat(f.FilePath); fi != nil {
-				downloadMu.Unlock()
-				// successfull download while monitoring
-				fn(api.ProgressResponse{
-					Digest:    f.Digest,
-					Total:     int(fi.Size()),
-					Completed: int(fi.Size()),
-				})
-				return nil
+		done, resume, err := func() (bool, bool, error) {
+			downloadMu.Lock()
+			defer downloadMu.Unlock()
+			val, downloading := inProgress.Load(f.Digest)
+			if !downloading {
+				// check once again if the download is complete
+				if fi, _ := os.Stat(f.FilePath); fi != nil {
+					// successful download while monitoring
+					fn(api.ProgressResponse{
+						Digest:    f.Digest,
+						Total:     int(fi.Size()),
+						Completed: int(fi.Size()),
+					})
+					return true, false, nil
+				}
+				// resume the download
+				inProgress.Store(f.Digest, f) // store the file download again to claim the resume
+				return false, true, nil
 			}
-			// resume the download
-			resp, err := requestDownload(ctx, mp, regOpts, f)
-			if err != nil {
-				downloadMu.Unlock()
-				return fmt.Errorf("resume: %w", err)
+			f, ok := val.(*FileDownload)
+			if !ok {
+				return false, false, fmt.Errorf("invalid type for in progress download: %T", val)
 			}
-			inProgress.Store(f.Digest, f)
-			downloadMu.Unlock()
-			return doDownload(ctx, f, resp, fn)
+			fn(api.ProgressResponse{
+				Status:    fmt.Sprintf("downloading %s", f.Digest),
+				Digest:    f.Digest,
+				Total:     int(f.Total),
+				Completed: int(f.Completed),
+			})
+			return false, false, nil
+		}()
+		if err != nil {
+			return err
+		}
+		if done {
+			// done downloading
+			return nil
 		}
-		downloadMu.Unlock()
-		f, ok := val.(*FileDownload)
-		if !ok {
-			return fmt.Errorf("invalid type for in progress download: %T", val)
+		if resume {
+			return doDownload(ctx, mp, regOpts, f, fn)
 		}
-		fn(api.ProgressResponse{
-			Status:    fmt.Sprintf("downloading %s", f.Digest),
-			Digest:    f.Digest,
-			Total:     int(f.Total),
-			Completed: int(f.Completed),
-		})
 	}
 	return nil
 }
 
 var chunkSize = 1024 * 1024 // 1 MiB in bytes
 
-// requestDownload requests a blob from the registry and returns the response
-func requestDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions, f *FileDownload) (*http.Response, error) {
+// doDownload downloads a blob from the registry and stores it in the blobs directory
+func doDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions, f *FileDownload, fn func(api.ProgressResponse)) error {
 	var size int64
 
 	fi, err := os.Stat(f.FilePath + "-partial")
@@ -118,7 +120,7 @@ func requestDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions
 	case errors.Is(err, os.ErrNotExist):
 		// noop, file doesn't exist so create it
 	case err != nil:
-		return nil, fmt.Errorf("stat: %w", err)
+		return fmt.Errorf("stat: %w", err)
 	default:
 		size = fi.Size()
 		// Ensure the size is divisible by the chunk size by removing excess bytes
@@ -126,7 +128,7 @@ func requestDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions
 
 		err := os.Truncate(f.FilePath+"-partial", size)
 		if err != nil {
-			return nil, fmt.Errorf("truncate: %w", err)
+			return fmt.Errorf("truncate: %w", err)
 		}
 	}
 
@@ -138,18 +140,18 @@ func requestDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions
 	resp, err := makeRequest("GET", url, headers, nil, regOpts)
 	if err != nil {
 		log.Printf("couldn't download blob: %v", err)
-		return nil, err
+		return err
 	}
-	// resp MUST be closed by doDownload, which should follow this function
+	defer resp.Body.Close()
 
 	if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent {
 		body, _ := io.ReadAll(resp.Body)
-		return nil, fmt.Errorf("on download registry responded with code %d: %v", resp.StatusCode, string(body))
+		return fmt.Errorf("on download registry responded with code %d: %v", resp.StatusCode, string(body))
 	}
 
 	err = os.MkdirAll(path.Dir(f.FilePath), 0o700)
 	if err != nil {
-		return nil, fmt.Errorf("make blobs directory: %w", err)
+		return fmt.Errorf("make blobs directory: %w", err)
 	}
 
 	remaining, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
@@ -157,12 +159,7 @@ func requestDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions
 	f.Total = remaining + f.Completed
 
 	inProgress.Store(f.Digest, f)
-	return resp, nil
-}
 
-// doDownload downloads a blob from the registry and stores it in the blobs directory
-func doDownload(ctx context.Context, f *FileDownload, resp *http.Response, fn func(api.ProgressResponse)) error {
-	defer resp.Body.Close()
 	out, err := os.OpenFile(f.FilePath+"-partial", os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
 	if err != nil {
 		return fmt.Errorf("open file: %w", err)