Selaa lähdekoodia

add "stop" command (#6739)

Patrick Devine 7 kuukautta sitten
vanhempi
commit
abed273de3
5 muutettua tiedostoa jossa 172 lisäystä ja 25 poistoa
  1. 54 2
      cmd/cmd.go
  2. 1 22
      cmd/interactive.go
  3. 52 0
      server/routes.go
  4. 19 1
      server/sched.go
  5. 46 0
      server/sched_test.go

+ 54 - 2
cmd/cmd.go

@@ -346,6 +346,39 @@ func (w *progressWriter) Write(p []byte) (n int, err error) {
 	return len(p), nil
 }
 
+func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
+	p := progress.NewProgress(os.Stderr)
+	defer p.StopAndClear()
+
+	spinner := progress.NewSpinner("")
+	p.Add("", spinner)
+
+	client, err := api.ClientFromEnvironment()
+	if err != nil {
+		return err
+	}
+
+	req := &api.GenerateRequest{
+		Model:     opts.Model,
+		KeepAlive: opts.KeepAlive,
+	}
+
+	return client.Generate(cmd.Context(), req, func(api.GenerateResponse) error { return nil })
+}
+
+func StopHandler(cmd *cobra.Command, args []string) error {
+	opts := &runOptions{
+		Model:     args[0],
+		KeepAlive: &api.Duration{Duration: 0},
+	}
+	if err := loadOrUnloadModel(cmd, opts); err != nil {
+		if strings.Contains(err.Error(), "not found") {
+			return fmt.Errorf("couldn't find model \"%s\" to stop", args[0])
+		}
+	}
+	return nil
+}
+
 func RunHandler(cmd *cobra.Command, args []string) error {
 	interactive := true
 
@@ -424,7 +457,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 	opts.ParentModel = info.Details.ParentModel
 
 	if interactive {
-		if err := loadModel(cmd, &opts); err != nil {
+		if err := loadOrUnloadModel(cmd, &opts); err != nil {
 			return err
 		}
 
@@ -615,7 +648,15 @@ func ListRunningHandler(cmd *cobra.Command, args []string) error {
 				cpuPercent := math.Round(float64(sizeCPU) / float64(m.Size) * 100)
 				procStr = fmt.Sprintf("%d%%/%d%% CPU/GPU", int(cpuPercent), int(100-cpuPercent))
 			}
-			data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), procStr, format.HumanTime(m.ExpiresAt, "Never")})
+
+			var until string
+			delta := time.Since(m.ExpiresAt)
+			if delta > 0 {
+				until = "Stopping..."
+			} else {
+				until = format.HumanTime(m.ExpiresAt, "Never")
+			}
+			data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), procStr, until})
 		}
 	}
 
@@ -1294,6 +1335,15 @@ func NewCLI() *cobra.Command {
 	runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
 	runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
 	runCmd.Flags().String("format", "", "Response format (e.g. json)")
+
+	stopCmd := &cobra.Command{
+		Use:     "stop MODEL",
+		Short:   "Stop a running model",
+		Args:    cobra.ExactArgs(1),
+		PreRunE: checkServerHeartbeat,
+		RunE:    StopHandler,
+	}
+
 	serveCmd := &cobra.Command{
 		Use:     "serve",
 		Aliases: []string{"start"},
@@ -1361,6 +1411,7 @@ func NewCLI() *cobra.Command {
 		createCmd,
 		showCmd,
 		runCmd,
+		stopCmd,
 		pullCmd,
 		pushCmd,
 		listCmd,
@@ -1400,6 +1451,7 @@ func NewCLI() *cobra.Command {
 		createCmd,
 		showCmd,
 		runCmd,
+		stopCmd,
 		pullCmd,
 		pushCmd,
 		listCmd,

+ 1 - 22
cmd/interactive.go

@@ -18,7 +18,6 @@ import (
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/parser"
-	"github.com/ollama/ollama/progress"
 	"github.com/ollama/ollama/readline"
 	"github.com/ollama/ollama/types/errtypes"
 )
@@ -31,26 +30,6 @@ const (
 	MultilineSystem
 )
 
-func loadModel(cmd *cobra.Command, opts *runOptions) error {
-	p := progress.NewProgress(os.Stderr)
-	defer p.StopAndClear()
-
-	spinner := progress.NewSpinner("")
-	p.Add("", spinner)
-
-	client, err := api.ClientFromEnvironment()
-	if err != nil {
-		return err
-	}
-
-	chatReq := &api.ChatRequest{
-		Model:     opts.Model,
-		KeepAlive: opts.KeepAlive,
-	}
-
-	return client.Chat(cmd.Context(), chatReq, func(api.ChatResponse) error { return nil })
-}
-
 func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 	usage := func() {
 		fmt.Fprintln(os.Stderr, "Available Commands:")
@@ -217,7 +196,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 			opts.Model = args[1]
 			opts.Messages = []api.Message{}
 			fmt.Printf("Loading model '%s'\n", opts.Model)
-			if err := loadModel(cmd, &opts); err != nil {
+			if err := loadOrUnloadModel(cmd, &opts); err != nil {
 				return err
 			}
 			continue

+ 52 - 0
server/routes.go

@@ -117,6 +117,32 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		return
 	}
 
+	// expire the runner
+	if req.Prompt == "" && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 {
+		model, err := GetModel(req.Model)
+		if err != nil {
+			switch {
+			case os.IsNotExist(err):
+				c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
+			case err.Error() == "invalid model name":
+				c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+			default:
+				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			}
+			return
+		}
+		s.sched.expireRunner(model)
+
+		c.JSON(http.StatusOK, api.GenerateResponse{
+			Model:      req.Model,
+			CreatedAt:  time.Now().UTC(),
+			Response:   "",
+			Done:       true,
+			DoneReason: "unload",
+		})
+		return
+	}
+
 	if req.Format != "" && req.Format != "json" {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be empty or \"json\""})
 		return
@@ -1322,6 +1348,32 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		return
 	}
 
+	// expire the runner
+	if len(req.Messages) == 0 && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 {
+		model, err := GetModel(req.Model)
+		if err != nil {
+			switch {
+			case os.IsNotExist(err):
+				c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
+			case err.Error() == "invalid model name":
+				c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+			default:
+				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			}
+			return
+		}
+		s.sched.expireRunner(model)
+
+		c.JSON(http.StatusOK, api.ChatResponse{
+			Model:      req.Model,
+			CreatedAt:  time.Now().UTC(),
+			Message:    api.Message{Role: "assistant"},
+			Done:       true,
+			DoneReason: "unload",
+		})
+		return
+	}
+
 	caps := []Capability{CapabilityCompletion}
 	if len(req.Tools) > 0 {
 		caps = append(caps, CapabilityTools)

+ 19 - 1
server/sched.go

@@ -360,7 +360,6 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
 			slog.Debug("runner expired event received", "modelPath", runner.modelPath)
 			runner.refMu.Lock()
 			if runner.refCount > 0 {
-				// Shouldn't happen, but safeguard to ensure no leaked runners
 				slog.Debug("expired event with positive ref count, retrying", "modelPath", runner.modelPath, "refCount", runner.refCount)
 				go func(runner *runnerRef) {
 					// We can't unload yet, but want to as soon as the current request completes
@@ -802,6 +801,25 @@ func (s *Scheduler) unloadAllRunners() {
 	}
 }
 
+func (s *Scheduler) expireRunner(model *Model) {
+	s.loadedMu.Lock()
+	defer s.loadedMu.Unlock()
+	runner, ok := s.loaded[model.ModelPath]
+	if ok {
+		runner.refMu.Lock()
+		runner.expiresAt = time.Now()
+		if runner.expireTimer != nil {
+			runner.expireTimer.Stop()
+			runner.expireTimer = nil
+		}
+		runner.sessionDuration = 0
+		if runner.refCount <= 0 {
+			s.expiredCh <- runner
+		}
+		runner.refMu.Unlock()
+	}
+}
+
 // If other runners are loaded, make sure the pending request will fit in system memory
 // If not, pick a runner to unload, else return nil and the request can be loaded
 func (s *Scheduler) maybeFindCPURunnerToUnload(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList) *runnerRef {

+ 46 - 0
server/sched_test.go

@@ -406,6 +406,52 @@ func TestGetRunner(t *testing.T) {
 	b.ctxDone()
 }
 
+func TestExpireRunner(t *testing.T) {
+	ctx, done := context.WithTimeout(context.Background(), 20*time.Millisecond)
+	defer done()
+	s := InitScheduler(ctx)
+	req := &LlmRequest{
+		ctx:             ctx,
+		model:           &Model{ModelPath: "foo"},
+		opts:            api.DefaultOptions(),
+		successCh:       make(chan *runnerRef, 1),
+		errCh:           make(chan error, 1),
+		sessionDuration: &api.Duration{Duration: 2 * time.Minute},
+	}
+
+	var ggml *llm.GGML
+	gpus := gpu.GpuInfoList{}
+	server := &mockLlm{estimatedVRAM: 10, estimatedVRAMByGPU: map[string]uint64{}}
+	s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
+		return server, nil
+	}
+	s.load(req, ggml, gpus, 0)
+
+	select {
+	case err := <-req.errCh:
+		if err != nil {
+			t.Fatalf("expected no errors when loading, got '%s'", err.Error())
+		}
+	case resp := <-req.successCh:
+		s.loadedMu.Lock()
+		if resp.refCount != uint(1) || len(s.loaded) != 1 {
+			t.Fatalf("expected a model to be loaded")
+		}
+		s.loadedMu.Unlock()
+	}
+
+	s.expireRunner(&Model{ModelPath: "foo"})
+
+	s.finishedReqCh <- req
+	s.processCompleted(ctx)
+
+	s.loadedMu.Lock()
+	if len(s.loaded) != 0 {
+		t.Fatalf("expected model to be unloaded")
+	}
+	s.loadedMu.Unlock()
+}
+
 // TODO - add one scenario that triggers the bogus finished event with positive ref count
 func TestPrematureExpired(t *testing.T) {
 	ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)