浏览代码

resumable downloads

Bruce MacDonald 1 年之前
父节点
当前提交
c9f45abef3
共有 4 个文件被更改,包括 42 次插入17 次删除
  1. 20 7
      api/client.go
  2. 0 4
      api/types.go
  3. 6 3
      cmd/cmd.go
  4. 16 3
      server/models.go

+ 20 - 7
api/client.go

@@ -9,11 +9,12 @@ import (
 	"io"
 	"net/http"
 	"strings"
+	"sync"
 )
 
 type Client struct {
-	URL        string
-	HTTP       http.Client
+	URL  string
+	HTTP http.Client
 }
 
 func checkError(resp *http.Response, body []byte) error {
@@ -64,7 +65,14 @@ func (c *Client) stream(ctx context.Context, method string, path string, reqData
 	for {
 		line, err := reader.ReadBytes('\n')
 		if err != nil {
-			break
+			if err == io.EOF {
+				break
+			} else {
+				return err // Handle other errors
+			}
+		}
+		if err := checkError(res, line); err != nil {
+			return err
 		}
 		callback(bytes.TrimSuffix(line, []byte("\n")))
 	}
@@ -128,8 +136,9 @@ func (c *Client) Generate(ctx context.Context, req *GenerateRequest, callback fu
 	return &res, nil
 }
 
-func (c *Client) Pull(ctx context.Context, req *PullRequest, callback func(progress PullProgress)) (*PullResponse, error) {
-	var res PullResponse
+func (c *Client) Pull(ctx context.Context, req *PullRequest, callback func(progress PullProgress)) error {
+	var wg sync.WaitGroup
+	wg.Add(1)
 	if err := c.stream(ctx, http.MethodPost, "/api/pull", req, func(progressBytes []byte) {
 		/*
 			Events have the following format for progress:
@@ -148,10 +157,14 @@ func (c *Client) Pull(ctx context.Context, req *PullRequest, callback func(progr
 			fmt.Println(err)
 			return
 		}
+		if progress.Completed >= progress.Total {
+			wg.Done()
+		}
 		callback(progress)
 	}); err != nil {
-		return nil, err
+		return err
 	}
 
-	return &res, nil
+	wg.Wait()
+	return nil
 }

+ 0 - 4
api/types.go

@@ -28,10 +28,6 @@ type PullProgress struct {
 	Percent   float64 `json:"percent"`
 }
 
-type PullResponse struct {
-	Response string `json:"response"`
-}
-
 type GenerateRequest struct {
 	Model  string `json:"model"`
 	Prompt string `json:"prompt"`

+ 6 - 3
cmd/cmd.go

@@ -40,7 +40,7 @@ func run(model string) error {
 	mutex := &sync.Mutex{}
 	var progressData api.PullProgress
 
-	callback := func(progress api.PullProgress) {
+	pullCallback := func(progress api.PullProgress) {
 		mutex.Lock()
 		progressData = progress
 		if bar == nil {
@@ -60,8 +60,11 @@ func run(model string) error {
 		bar.Set(int(progress.Completed))
 		mutex.Unlock()
 	}
-	_, err = client.Pull(context.Background(), &pr, callback)
-	return err
+	if err := client.Pull(context.Background(), &pr, pullCallback); err != nil {
+		return err
+	}
+	fmt.Println("Up to date.")
+	return nil
 }
 
 func serve() error {

+ 16 - 3
server/models.go

@@ -79,6 +79,7 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
 		panic(err)
 	}
 	// check for resume
+	alreadyDownloaded := 0
 	fileInfo, err := os.Stat(fileName)
 	if err != nil {
 		if !os.IsNotExist(err) {
@@ -86,7 +87,8 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
 		}
 		// file doesn't exist, create it now
 	} else {
-		req.Header.Add("Range", "bytes="+strconv.FormatInt(fileInfo.Size(), 10)+"-")
+		alreadyDownloaded = int(fileInfo.Size())
+		req.Header.Add("Range", "bytes="+strconv.Itoa(alreadyDownloaded)+"-")
 	}
 
 	resp, err := client.Do(req)
@@ -96,7 +98,17 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
 
 	defer resp.Body.Close()
 
-	if resp.StatusCode != http.StatusOK {
+	if resp.StatusCode == http.StatusRequestedRangeNotSatisfiable {
+		// already downloaded
+		progressCh <- api.PullProgress{
+			Total:     alreadyDownloaded,
+			Completed: alreadyDownloaded,
+			Percent:   100,
+		}
+		return nil
+	}
+
+	if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent {
 		return fmt.Errorf("failed to download model: %s", resp.Status)
 	}
 
@@ -109,7 +121,8 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
 	totalSize, _ := strconv.Atoi(resp.Header.Get("Content-Length"))
 
 	buf := make([]byte, 1024)
-	totalBytes := 0
+	totalBytes := alreadyDownloaded
+	totalSize += alreadyDownloaded
 
 	for {
 		n, err := resp.Body.Read(buf)