浏览代码

allow the user to cancel generating with ctrl-C (#641)

Patrick Devine 1 年之前
父节点
当前提交
76db4a49cf
共有 1 个文件被更改,包括 24 次插入3 次删除
  1. 24 3
      cmd/cmd.go

+ 24 - 3
cmd/cmd.go

@@ -13,9 +13,11 @@ import (
 	"net"
 	"os"
 	"os/exec"
+	"os/signal"
 	"path/filepath"
 	"runtime"
 	"strings"
+	"syscall"
 	"time"
 
 	"github.com/dustin/go-humanize"
@@ -43,7 +45,7 @@ func (p Painter) Paint(line []rune, _ int) []rune {
 		if p.IsMultiLine {
 			prompt = "Use \"\"\" to end multi-line input"
 		} else {
-			prompt = "Send a message (/? for help, /bye to exit)"
+			prompt = "Send a message (/? for help)"
 		}
 		return []rune(fmt.Sprintf("\033[38;5;245m%s\033[%dD\033[0m", prompt, len(prompt)))
 	}
@@ -426,6 +428,19 @@ func generate(cmd *cobra.Command, model, prompt string) error {
 		wrapTerm = false
 	}
 
+	cancelCtx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	sigChan := make(chan os.Signal, 1)
+	signal.Notify(sigChan, syscall.SIGINT)
+	var abort bool
+
+	go func() {
+		<-sigChan
+		cancel()
+		abort = true
+	}()
+
 	var currentLineLength int
 	var wordBuffer string
 
@@ -465,7 +480,7 @@ func generate(cmd *cobra.Command, model, prompt string) error {
 		return nil
 	}
 
-	if err := client.Generate(context.Background(), &request, fn); err != nil {
+	if err := client.Generate(cancelCtx, &request, fn); err != nil {
 		if strings.Contains(err.Error(), "failed to load model") {
 			// tell the user to check the server log, if it exists locally
 			home, nestedErr := os.UserHomeDir()
@@ -477,6 +492,9 @@ func generate(cmd *cobra.Command, model, prompt string) error {
 			if _, nestedErr := os.Stat(logPath); nestedErr == nil {
 				err = fmt.Errorf("%w\nFor more details, check the error logs at %s", err, logPath)
 			}
+		} else if strings.Contains(err.Error(), "context canceled") && abort {
+			spinner.Finish()
+			return nil
 		}
 		return err
 	}
@@ -486,6 +504,9 @@ func generate(cmd *cobra.Command, model, prompt string) error {
 	}
 
 	if !latest.Done {
+		if abort {
+			return nil
+		}
 		return errors.New("unexpected end of response")
 	}
 
@@ -568,7 +589,7 @@ func generateInteractive(cmd *cobra.Command, model string) error {
 			return nil
 		case errors.Is(err, readline.ErrInterrupt):
 			if line == "" {
-				return nil
+				fmt.Println("Use Ctrl-D or /bye to exit.")
 			}
 
 			continue