Ver Fonte

pull fixes

Michael Yang há 1 ano atrás
pai
commit
0944b01e7d
3 ficheiros alterados com 14 adições e 18 exclusões
  1. 3 4
      cmd/cmd.go
  2. 10 13
      server/models.go
  3. 1 1
      server/routes.go

+ 3 - 4
cmd/cmd.go

@@ -59,7 +59,7 @@ func pull(model string) error {
 		&api.PullRequest{Model: model},
 		func(progress api.PullProgress) error {
 			if bar == nil {
-				if progress.Percent == 100 {
+				if progress.Percent >= 100 {
 					// already downloaded
 					return nil
 				}
@@ -73,10 +73,9 @@ func pull(model string) error {
 }
 
 func RunGenerate(_ *cobra.Command, args []string) error {
-	// join all args into a single prompt
-	prompt := strings.Join(args[1:], " ")
 	if len(args) > 1 {
-		return generate(args[0], prompt)
+		// join all args into a single prompt
+		return generate(args[0], strings.Join(args[1:], " "))
 	}
 
 	if term.IsTerminal(int(os.Stdin.Fd())) {

+ 10 - 13
server/models.go

@@ -119,25 +119,22 @@ func saveModel(model *Model, fn func(total, completed int64)) error {
 	}
 	defer out.Close()
 
-	totalSize, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
+	remaining, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
+	completed := size
 
-	totalBytes := size
-	totalSize += size
+	total := remaining + completed
 
 	for {
-		n, err := io.CopyN(out, resp.Body, 8192)
-		if err != nil && !errors.Is(err, io.EOF) {
-			return err
+		fn(total, completed)
+		if completed >= total {
+			return os.Rename(model.TempFile(), model.FullName())
 		}
 
-		if n == 0 {
-			break
+		n , err := io.CopyN(out, resp.Body, 8192)
+		if err != nil && !errors.Is(err, io.EOF) {
+			return err
 		}
 
-		totalBytes += n
-		fn(totalSize, totalBytes)
+		completed += n
 	}
-
-	fn(totalSize, totalSize)
-	return os.Rename(model.TempFile(), model.FullName())
 }

+ 1 - 1
server/routes.go

@@ -112,7 +112,7 @@ func pull(c *gin.Context) {
 		ch <- api.PullProgress{
 			Total:     total,
 			Completed: completed,
-			Percent:   float64(total) / float64(completed) * 100,
+			Percent:   float64(completed) / float64(total) * 100,
 		}
 	}