Browse Source

app: gracefully shut down `ollama serve` on windows (#3641)

* app: gracefully shut down `ollama serve` on windows

* fix linter errors

* bring back `HideWindow`

* remove creation flags

* restore `windows.CREATE_NEW_PROCESS_GROUP`
Jeffrey Morgan 1 năm trước cách đây
mục cha
commit
7027f264fb
3 tập tin đã thay đổi với 118 bổ sung7 xóa
  1. 15 6
      app/lifecycle/server.go
  2. 26 0
      app/lifecycle/server_unix.go
  3. 77 1
      app/lifecycle/server_windows.go

+ 15 - 6
app/lifecycle/server.go

@@ -9,7 +9,6 @@ import (
 	"os"
 	"os/exec"
 	"path/filepath"
-	"syscall"
 	"time"
 
 	"github.com/ollama/ollama/api"
@@ -87,19 +86,29 @@ func SpawnServer(ctx context.Context, command string) (chan int, error) {
 	// Re-wire context done behavior to attempt a graceful shutdown of the server
 	cmd.Cancel = func() error {
 		if cmd.Process != nil {
-			cmd.Process.Signal(os.Interrupt) //nolint:errcheck
+			err := terminate(cmd)
+			if err != nil {
+				slog.Warn("error trying to gracefully terminate server", "err", err)
+				return cmd.Process.Kill()
+			}
+
 			tick := time.NewTicker(10 * time.Millisecond)
 			defer tick.Stop()
+
 			for {
 				select {
 				case <-tick.C:
-					// OS agnostic "is it still running"
-					if proc, err := os.FindProcess(int(cmd.Process.Pid)); err != nil || errors.Is(proc.Signal(syscall.Signal(0)), os.ErrProcessDone) {
-						return nil //nolint:nilerr
+					exited, err := isProcessExited(cmd.Process.Pid)
+					if err != nil {
+						return err
+					}
+
+					if exited {
+						return nil
 					}
 				case <-time.After(5 * time.Second):
 					slog.Warn("graceful server shutdown timeout, killing", "pid", cmd.Process.Pid)
-					cmd.Process.Kill() //nolint:errcheck
+					return cmd.Process.Kill()
 				}
 			}
 		}

+ 26 - 0
app/lifecycle/server_unix.go

@@ -4,9 +4,35 @@ package lifecycle
 
 import (
 	"context"
+	"errors"
+	"fmt"
+	"os"
 	"os/exec"
+	"syscall"
 )
 
 func getCmd(ctx context.Context, cmd string) *exec.Cmd {
 	return exec.CommandContext(ctx, cmd, "serve")
 }
+
+func terminate(cmd *exec.Cmd) error {
+	return cmd.Process.Signal(os.Interrupt)
+}
+
+func isProcessExited(pid int) (bool, error) {
+	proc, err := os.FindProcess(pid)
+	if err != nil {
+		return false, fmt.Errorf("failed to find process: %v", err)
+	}
+
+	err = proc.Signal(syscall.Signal(0))
+	if err != nil {
+		if errors.Is(err, os.ErrProcessDone) || errors.Is(err, syscall.ESRCH) {
+			return true, nil
+		}
+
+		return false, fmt.Errorf("error signaling process: %v", err)
+	}
+
+	return false, nil
+}

+ 77 - 1
app/lifecycle/server_windows.go

@@ -2,12 +2,88 @@ package lifecycle
 
 import (
 	"context"
+	"fmt"
 	"os/exec"
 	"syscall"
+
+	"golang.org/x/sys/windows"
 )
 
 func getCmd(ctx context.Context, exePath string) *exec.Cmd {
 	cmd := exec.CommandContext(ctx, exePath, "serve")
-	cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true, CreationFlags: 0x08000000}
+	cmd.SysProcAttr = &syscall.SysProcAttr{
+		HideWindow:    true,
+		CreationFlags: windows.CREATE_NEW_PROCESS_GROUP,
+	}
+
 	return cmd
 }
+
+func terminate(cmd *exec.Cmd) error {
+	dll, err := windows.LoadDLL("kernel32.dll")
+	if err != nil {
+		return err
+	}
+	defer dll.Release() // nolint: errcheck
+
+	pid := cmd.Process.Pid
+
+	f, err := dll.FindProc("AttachConsole")
+	if err != nil {
+		return err
+	}
+
+	r1, _, err := f.Call(uintptr(pid))
+	if r1 == 0 && err != syscall.ERROR_ACCESS_DENIED {
+		return err
+	}
+
+	f, err = dll.FindProc("SetConsoleCtrlHandler")
+	if err != nil {
+		return err
+	}
+
+	r1, _, err = f.Call(0, 1)
+	if r1 == 0 {
+		return err
+	}
+
+	f, err = dll.FindProc("GenerateConsoleCtrlEvent")
+	if err != nil {
+		return err
+	}
+
+	r1, _, err = f.Call(windows.CTRL_BREAK_EVENT, uintptr(pid))
+	if r1 == 0 {
+		return err
+	}
+
+	r1, _, err = f.Call(windows.CTRL_C_EVENT, uintptr(pid))
+	if r1 == 0 {
+		return err
+	}
+
+	return nil
+}
+
+const STILL_ACTIVE = 259
+
+func isProcessExited(pid int) (bool, error) {
+	hProcess, err := windows.OpenProcess(windows.PROCESS_QUERY_INFORMATION, false, uint32(pid))
+	if err != nil {
+		return false, fmt.Errorf("failed to open process: %v", err)
+	}
+	defer windows.CloseHandle(hProcess) // nolint: errcheck
+
+	var exitCode uint32
+	err = windows.GetExitCodeProcess(hProcess, &exitCode)
+	if err != nil {
+		return false, fmt.Errorf("failed to get exit code: %v", err)
+	}
+
+	if exitCode == STILL_ACTIVE {
+		return false, nil
+	}
+
+	return true, nil
+}