Browse Source

cmd: handle sigint globally

This change also updates both client.do and client.stream to return
ctx.Err(). Previously this error is skipped so canceled contexts are
silently ignored
Michael Yang 2 months ago
parent
commit
fcfbb06f1b
3 changed files with 20 additions and 37 deletions
  1. 3 2
      api/client.go
  2. 4 34
      cmd/cmd.go
  3. 13 1
      main.go

+ 3 - 2
api/client.go

@@ -126,7 +126,8 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
 			return err
 			return err
 		}
 		}
 	}
 	}
-	return nil
+
+	return ctx.Err()
 }
 }
 
 
 const maxBufferSize = 512 * format.KiloByte
 const maxBufferSize = 512 * format.KiloByte
@@ -189,7 +190,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
 		}
 		}
 	}
 	}
 
 
-	return nil
+	return ctx.Err()
 }
 }
 
 
 // GenerateResponseFunc is a function that [Client.Generate] invokes every time
 // GenerateResponseFunc is a function that [Client.Generate] invokes every time

+ 4 - 34
cmd/cmd.go

@@ -15,13 +15,11 @@ import (
 	"net"
 	"net"
 	"net/http"
 	"net/http"
 	"os"
 	"os"
-	"os/signal"
 	"path/filepath"
 	"path/filepath"
 	"runtime"
 	"runtime"
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
 	"sync/atomic"
 	"sync/atomic"
-	"syscall"
 	"time"
 	"time"
 
 
 	"github.com/containerd/console"
 	"github.com/containerd/console"
@@ -330,6 +328,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 			if err := PullHandler(cmd, []string{name}); err != nil {
 			if err := PullHandler(cmd, []string{name}); err != nil {
 				return nil, err
 				return nil, err
 			}
 			}
+
 			return client.Show(cmd.Context(), &api.ShowRequest{Name: name})
 			return client.Show(cmd.Context(), &api.ShowRequest{Name: name})
 		}
 		}
 		return info, err
 		return info, err
@@ -858,17 +857,6 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
 	spinner := progress.NewSpinner("")
 	spinner := progress.NewSpinner("")
 	p.Add("", spinner)
 	p.Add("", spinner)
 
 
-	cancelCtx, cancel := context.WithCancel(cmd.Context())
-	defer cancel()
-
-	sigChan := make(chan os.Signal, 1)
-	signal.Notify(sigChan, syscall.SIGINT)
-
-	go func() {
-		<-sigChan
-		cancel()
-	}()
-
 	var state *displayResponseState = &displayResponseState{}
 	var state *displayResponseState = &displayResponseState{}
 	var latest api.ChatResponse
 	var latest api.ChatResponse
 	var fullResponse strings.Builder
 	var fullResponse strings.Builder
@@ -903,10 +891,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
 		req.KeepAlive = opts.KeepAlive
 		req.KeepAlive = opts.KeepAlive
 	}
 	}
 
 
-	if err := client.Chat(cancelCtx, req, fn); err != nil {
-		if errors.Is(err, context.Canceled) {
-			return nil, nil
-		}
+	if err := client.Chat(cmd.Context(), req, fn); err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
@@ -946,17 +931,6 @@ func generate(cmd *cobra.Command, opts runOptions) error {
 		generateContext = []int{}
 		generateContext = []int{}
 	}
 	}
 
 
-	ctx, cancel := context.WithCancel(cmd.Context())
-	defer cancel()
-
-	sigChan := make(chan os.Signal, 1)
-	signal.Notify(sigChan, syscall.SIGINT)
-
-	go func() {
-		<-sigChan
-		cancel()
-	}()
-
 	var state *displayResponseState = &displayResponseState{}
 	var state *displayResponseState = &displayResponseState{}
 
 
 	fn := func(response api.GenerateResponse) error {
 	fn := func(response api.GenerateResponse) error {
@@ -992,10 +966,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
 		KeepAlive: opts.KeepAlive,
 		KeepAlive: opts.KeepAlive,
 	}
 	}
 
 
-	if err := client.Generate(ctx, &request, fn); err != nil {
-		if errors.Is(err, context.Canceled) {
-			return nil
-		}
+	if err := client.Generate(cmd.Context(), &request, fn); err != nil {
 		return err
 		return err
 	}
 	}
 
 
@@ -1017,8 +988,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
 		latest.Summary()
 		latest.Summary()
 	}
 	}
 
 
-	ctx = context.WithValue(cmd.Context(), generateContextKey("context"), latest.Context)
-	cmd.SetContext(ctx)
+	cmd.SetContext(context.WithValue(cmd.Context(), generateContextKey("context"), latest.Context))
 
 
 	return nil
 	return nil
 }
 }

+ 13 - 1
main.go

@@ -2,6 +2,8 @@ package main
 
 
 import (
 import (
 	"context"
 	"context"
+	"os"
+	"os/signal"
 
 
 	"github.com/spf13/cobra"
 	"github.com/spf13/cobra"
 
 
@@ -9,5 +11,15 @@ import (
 )
 )
 
 
 func main() {
 func main() {
-	cobra.CheckErr(cmd.NewCLI().ExecuteContext(context.Background()))
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	sigChan := make(chan os.Signal, 1)
+	signal.Notify(sigChan, os.Interrupt)
+	go func() {
+		<-sigChan
+		cancel()
+	}()
+
+	cobra.CheckErr(cmd.NewCLI().ExecuteContext(ctx))
 }
 }