Sfoglia il codice sorgente

rename partial file

Michael Yang 1 anno fa
parent
commit
948323fa78
1 ha cambiato i file con 44 aggiunte e 31 eliminazioni
  1. 44 31
      server/models.go

+ 44 - 31
server/models.go

@@ -2,6 +2,7 @@ package server
 
 import (
 	"encoding/json"
+	"errors"
 	"fmt"
 	"io"
 	"net/http"
@@ -34,6 +35,14 @@ func (m *Model) FullName() string {
 	return path.Join(home, ".ollama", "models", m.Name+".bin")
 }
 
+func (m *Model) TempFile() string {
+	fullName := m.FullName()
+	return path.Join(
+		path.Dir(fullName),
+		fmt.Sprintf(".%s.part", path.Base(fullName)),
+	)
+}
+
 func getRemote(model string) (*Model, error) {
 	// resolve the model download from our directory
 	resp, err := http.Get(directoryURL)
@@ -66,37 +75,45 @@ func saveModel(model *Model, fn func(total, completed int64)) error {
 	if err != nil {
 		return fmt.Errorf("failed to download model: %w", err)
 	}
-	// check for resume
-	alreadyDownloaded := int64(0)
-	fileInfo, err := os.Stat(model.FullName())
-	if err != nil {
-		if !os.IsNotExist(err) {
-			return fmt.Errorf("failed to check resume model file: %w", err)
-		}
-		// file doesn't exist, create it now
-	} else {
-		alreadyDownloaded = fileInfo.Size()
-		req.Header.Add("Range", fmt.Sprintf("bytes=%d-", alreadyDownloaded))
+
+	// check if completed file exists
+	fi, err := os.Stat(model.FullName())
+	switch {
+	case errors.Is(err, os.ErrNotExist):
+		// noop, file doesn't exist so create it
+	case err != nil:
+		return fmt.Errorf("stat: %w", err)
+	default:
+		fn(fi.Size(), fi.Size())
+		return nil
+	}
+
+	var size int64
+
+	// completed file doesn't exist, check partial file
+	fi, err = os.Stat(model.TempFile())
+	switch {
+	case errors.Is(err, os.ErrNotExist):
+		// noop, file doesn't exist so create it
+	case err != nil:
+		return fmt.Errorf("stat: %w", err)
+	default:
+		size = fi.Size()
 	}
 
+	req.Header.Add("Range", fmt.Sprintf("bytes=%d-", size))
+
 	resp, err := client.Do(req)
 	if err != nil {
 		return fmt.Errorf("failed to download model: %w", err)
 	}
-
 	defer resp.Body.Close()
 
-	if resp.StatusCode == http.StatusRequestedRangeNotSatisfiable {
-		// already downloaded
-		fn(alreadyDownloaded, alreadyDownloaded)
-		return nil
-	}
-
-	if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent {
+	if resp.StatusCode >= 400 {
 		return fmt.Errorf("failed to download model: %s", resp.Status)
 	}
 
-	out, err := os.OpenFile(model.FullName(), os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
+	out, err := os.OpenFile(model.TempFile(), os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
 	if err != nil {
 		panic(err)
 	}
@@ -104,27 +121,23 @@ func saveModel(model *Model, fn func(total, completed int64)) error {
 
 	totalSize, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
 
-	buf := make([]byte, 1024)
-	totalBytes := alreadyDownloaded
-	totalSize += alreadyDownloaded
+	totalBytes := size
+	totalSize += size
 
 	for {
-		n, err := resp.Body.Read(buf)
-		if err != nil && err != io.EOF {
+		n, err := io.CopyN(out, resp.Body, 8192)
+		if err != nil && !errors.Is(err, io.EOF) {
 			return err
 		}
+
 		if n == 0 {
 			break
 		}
-		if _, err := out.Write(buf[:n]); err != nil {
-			return err
-		}
-
-		totalBytes += int64(n)
 
+		totalBytes += n
 		fn(totalSize, totalBytes)
 	}
 
 	fn(totalSize, totalSize)
-	return nil
+	return os.Rename(model.TempFile(), model.FullName())
 }