
Ollama `ps` command for showing currently loaded models (#4327)

Patrick Devine, 11 months ago
Commit 6845988807
10 files changed, 193 insertions(+), 50 deletions(-)
  1. api/client.go (+9 -0)
  2. api/types.go (+3 -1)
  3. cmd/cmd.go (+75 -0)
  4. cmd/interactive.go (+5 -0)
  5. format/time.go (+3 -1)
  6. format/time_test.go (+10 -0)
  7. llm/server.go (+5 -0)
  8. server/routes.go (+29 -0)
  9. server/sched.go (+42 -38)
  10. server/sched_test.go (+12 -10)

api/client.go (+9 -0)

@@ -354,6 +354,15 @@ func (c *Client) List(ctx context.Context) (*ListResponse, error) {
	return &lr, nil
}

+// ListRunning lists running models.
+func (c *Client) ListRunning(ctx context.Context) (*ListResponse, error) {
+	var lr ListResponse
+	if err := c.do(ctx, http.MethodGet, "/api/ps", nil, &lr); err != nil {
+		return nil, err
+	}
+	return &lr, nil
+}
+
// Copy copies a model - creating a model with another name from an existing
// model.
func (c *Client) Copy(ctx context.Context, req *CopyRequest) error {
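As a usage sketch (not part of this commit), the new method is called like the existing List helper; the import path assumes the github.com/ollama/ollama module layout at the time of this change:

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}
	// Calls the new GET /api/ps route added below in server/routes.go.
	running, err := client.ListRunning(context.Background())
	if err != nil {
		log.Fatal(err)
	}
	for _, m := range running.Models {
		fmt.Printf("%s: %d bytes in VRAM, expires %s\n", m.Name, m.SizeVRAM, m.ExpiresAt)
	}
}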

api/types.go (+3 -1)

@@ -289,10 +289,12 @@ type ListResponse struct {
type ModelResponse struct {
	Name       string       `json:"name"`
	Model      string       `json:"model"`
-	ModifiedAt time.Time    `json:"modified_at"`
+	ModifiedAt time.Time    `json:"modified_at,omitempty"`
	Size       int64        `json:"size"`
	Digest     string       `json:"digest"`
	Details    ModelDetails `json:"details,omitempty"`
+	ExpiresAt  time.Time    `json:"expires_at,omitempty"`
+	SizeVRAM   int64        `json:"size_vram,omitempty"`
}

type TokenResponse struct {
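The omitempty tags let the same ModelResponse back both /api/tags (which fills modified_at) and the new /api/ps (which fills expires_at and size_vram). An illustrative /api/ps entry, with made-up values:

{
  "name": "llama3:latest",
  "model": "llama3:latest",
  "size": 5137025024,
  "size_vram": 5137025024,
  "digest": "365c0bd3c000...",
  "details": {
    "family": "llama",
    "parameter_size": "8B",
    "quantization_level": "Q4_0"
  },
  "expires_at": "2024-05-13T17:04:05Z"
}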

cmd/cmd.go (+75 -0)

@@ -12,6 +12,7 @@ import (
	"fmt"
	"io"
	"log"
+	"math"
	"net"
	"net/http"
	"os"
@@ -324,6 +325,18 @@ func RunHandler(cmd *cobra.Command, args []string) error {
	}
	opts.Format = format

+	keepAlive, err := cmd.Flags().GetString("keepalive")
+	if err != nil {
+		return err
+	}
+	if keepAlive != "" {
+		d, err := time.ParseDuration(keepAlive)
+		if err != nil {
+			return err
+		}
+		opts.KeepAlive = &api.Duration{Duration: d}
+	}
+
	prompts := args[1:]
	// prepend stdin to the prompt if provided
	if !term.IsTerminal(int(os.Stdin.Fd())) {
@@ -496,6 +509,52 @@ func ListHandler(cmd *cobra.Command, args []string) error {
	return nil
}

+func ListRunningHandler(cmd *cobra.Command, args []string) error {
+	client, err := api.ClientFromEnvironment()
+	if err != nil {
+		return err
+	}
+
+	models, err := client.ListRunning(cmd.Context())
+	if err != nil {
+		return err
+	}
+
+	var data [][]string
+
+	for _, m := range models.Models {
+		if len(args) == 0 || strings.HasPrefix(m.Name, args[0]) {
+			var procStr string
+			switch {
+			case m.SizeVRAM == 0:
+				procStr = "100% CPU"
+			case m.SizeVRAM == m.Size:
+				procStr = "100% GPU"
+			case m.SizeVRAM > m.Size || m.Size == 0:
+				procStr = "Unknown"
+			default:
+				sizeCPU := m.Size - m.SizeVRAM
+				cpuPercent := math.Round(float64(sizeCPU) / float64(m.Size) * 100)
+				procStr = fmt.Sprintf("%d%%/%d%% CPU/GPU", int(cpuPercent), int(100-cpuPercent))
+			}
+			data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), procStr, format.HumanTime(m.ExpiresAt, "Never")})
+		}
+	}
+
+	table := tablewriter.NewWriter(os.Stdout)
+	table.SetHeader([]string{"NAME", "ID", "SIZE", "PROCESSOR", "UNTIL"})
+	table.SetHeaderAlignment(tablewriter.ALIGN_LEFT)
+	table.SetAlignment(tablewriter.ALIGN_LEFT)
+	table.SetHeaderLine(false)
+	table.SetBorder(false)
+	table.SetNoWhiteSpace(true)
+	table.SetTablePadding("\t")
+	table.AppendBulk(data)
+	table.Render()
+
+	return nil
+}
+
func DeleteHandler(cmd *cobra.Command, args []string) error {
	client, err := api.ClientFromEnvironment()
	if err != nil {
@@ -672,6 +731,7 @@ type runOptions struct {
	Images      []api.ImageData
	Options     map[string]interface{}
	MultiModal  bool
+	KeepAlive   *api.Duration
}

type displayResponseState struct {
@@ -766,6 +826,10 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
		Options:  opts.Options,
	}

+	if opts.KeepAlive != nil {
+		req.KeepAlive = opts.KeepAlive
+	}
+
	if err := client.Chat(cancelCtx, req, fn); err != nil {
		if errors.Is(err, context.Canceled) {
			return nil, nil
@@ -1075,6 +1139,7 @@ func NewCLI() *cobra.Command {
		RunE:    RunHandler,
	}

+	runCmd.Flags().String("keepalive", "", "Duration to keep a model loaded (e.g. 5m)")
	runCmd.Flags().Bool("verbose", false, "Show timings for response")
	runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
	runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
@@ -1123,6 +1188,14 @@ Environment Variables:
		PreRunE: checkServerHeartbeat,
		RunE:    ListHandler,
	}
+
+	psCmd := &cobra.Command{
+		Use:     "ps",
+		Short:   "List running models",
+		PreRunE: checkServerHeartbeat,
+		RunE:    ListRunningHandler,
+	}
+
	copyCmd := &cobra.Command{
		Use:     "cp SOURCE DESTINATION",
		Short:   "Copy a model",
@@ -1146,6 +1219,7 @@ Environment Variables:
		pullCmd,
		pushCmd,
		listCmd,
+		psCmd,
		copyCmd,
		deleteCmd,
	} {
@@ -1160,6 +1234,7 @@ Environment Variables:
		pullCmd,
		pushCmd,
		listCmd,
+		psCmd,
		copyCmd,
		deleteCmd,
	)
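Taken together, the CLI changes add a --keepalive flag to run and a new ps subcommand. An illustrative session (model name, ID, size, and expiry are invented):

$ ollama run llama3 --keepalive 30m
>>> /bye
$ ollama ps
NAME            ID            SIZE    PROCESSOR    UNTIL
llama3:latest   365c0bd3c000  5.4 GB  100% GPU     29 minutes from now

The PROCESSOR column follows the switch in ListRunningHandler above: an all-VRAM model reports 100% GPU, an all-RAM model 100% CPU, and a split model reports round((size - size_vram) / size * 100)% CPU with the remainder on GPU.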

cmd/interactive.go (+5 -0)

@@ -56,6 +56,11 @@ func loadModel(cmd *cobra.Command, opts *runOptions) error {
		Model:    opts.Model,
		Messages: []api.Message{},
	}
+
+	if opts.KeepAlive != nil {
+		chatReq.KeepAlive = opts.KeepAlive
+	}
+
	err = client.Chat(cmd.Context(), chatReq, func(resp api.ChatResponse) error {
		p.StopAndClear()
		if len(opts.Messages) > 0 {

format/time.go (+3 -1)

@@ -60,7 +60,9 @@ func humanTime(t time.Time, zeroValue string) string {
	}

	delta := time.Since(t)
-	if delta < 0 {
+	if int(delta.Hours())/24/365 < -20 {
+		return "Forever"
+	} else if delta < 0 {
		return humanDuration(-delta) + " from now"
	}


format/time_test.go (+10 -0)

@@ -32,4 +32,14 @@ func TestHumanTime(t *testing.T) {
		v := now.Add(800 * time.Millisecond)
		assertEqual(t, HumanTime(v, ""), "Less than a second from now")
	})
+
+	t.Run("time way in the future", func(t *testing.T) {
+		v := now.Add(24 * time.Hour * 365 * 200)
+		assertEqual(t, HumanTime(v, ""), "Forever")
+	})
+
+	t.Run("time way in the future lowercase", func(t *testing.T) {
+		v := now.Add(24 * time.Hour * 365 * 200)
+		assertEqual(t, HumanTimeLower(v, ""), "forever")
+	})
}

llm/server.go (+5 -0)

@@ -38,6 +38,7 @@ type LlamaServer interface {
	Detokenize(ctx context.Context, tokens []int) (string, error)
	Close() error
	EstimatedVRAM() uint64
+	EstimatedTotal() uint64
}

// llmServer is an instance of the llama.cpp server
@@ -955,6 +956,10 @@ func (s *llmServer) EstimatedVRAM() uint64 {
	return s.estimatedVRAM
}

+func (s *llmServer) EstimatedTotal() uint64 {
+	return s.estimatedTotal
+}
+
func parseDurationMs(ms float64) time.Duration {
	dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
	if err != nil {
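For orientation: EstimatedVRAM is the GPU-resident share of a loaded model and the new EstimatedTotal is its full footprint. The scheduler below snapshots both onto the runner (estimatedVRAM, estimatedTotal), and the /api/ps handler reports them as size_vram and size, which is exactly the pair the CLI divides to compute the PROCESSOR split.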

server/routes.go (+29 -0)

@@ -979,6 +979,7 @@ func (s *Server) GenerateRoutes() http.Handler {
	r.POST("/api/show", s.ShowModelHandler)
	r.POST("/api/blobs/:digest", s.CreateBlobHandler)
	r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
+	r.GET("/api/ps", s.ProcessHandler)

	// Compatibility endpoints
	r.POST("/v1/chat/completions", openai.Middleware(), s.ChatHandler)
@@ -1137,6 +1138,34 @@ func streamResponse(c *gin.Context, ch chan any) {
	})
}

+func (s *Server) ProcessHandler(c *gin.Context) {
+	models := []api.ModelResponse{}
+
+	for _, v := range s.sched.loaded {
+		model := v.model
+		modelDetails := api.ModelDetails{
+			Format:            model.Config.ModelFormat,
+			Family:            model.Config.ModelFamily,
+			Families:          model.Config.ModelFamilies,
+			ParameterSize:     model.Config.ModelType,
+			QuantizationLevel: model.Config.FileType,
+		}
+
+		mr := api.ModelResponse{
+			Model:     model.ShortName,
+			Name:      model.ShortName,
+			Size:      int64(v.estimatedTotal),
+			SizeVRAM:  int64(v.estimatedVRAM),
+			Digest:    model.Digest,
+			Details:   modelDetails,
+			ExpiresAt: v.expiresAt,
+		}
+		models = append(models, mr)
+	}
+
+	c.JSON(http.StatusOK, api.ListResponse{Models: models})
+}
+
// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
func chatPrompt(ctx context.Context, runner *runnerRef, template string, messages []api.Message, numCtx int) (string, error) {
	encode := func(s string) ([]int, error) {
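A quick way to exercise the new route once a model is loaded (response abbreviated; field values are invented, see the api/types.go example above for the full shape):

$ curl -s http://localhost:11434/api/ps
{"models":[{"name":"llama3:latest","model":"llama3:latest","size":5137025024,"size_vram":5137025024,"expires_at":"2024-05-13T17:04:05Z", ...}]}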

server/sched.go (+42 -38)

@@ -177,7 +177,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
				}
				// Trigger an expiration to unload once it's done
				runnerToExpire.refMu.Lock()
-				slog.Debug("resetting model to expire immediately to make room", "model", runnerToExpire.model, "refCount", runnerToExpire.refCount)
+				slog.Debug("resetting model to expire immediately to make room", "modelPath", runnerToExpire.modelPath, "refCount", runnerToExpire.refCount)
				if runnerToExpire.expireTimer != nil {
					runnerToExpire.expireTimer.Stop()
					runnerToExpire.expireTimer = nil
@@ -190,13 +190,13 @@ func (s *Scheduler) processPending(ctx context.Context) {
				// Wait for the unload to happen
				// Note: at this point we're queueing up all incoming requests, even if they were for
				// a different model that's loaded and not scheduled to be removed.
-				slog.Debug("waiting for pending requests to complete and unload to occur", "model", runnerToExpire.model)
+				slog.Debug("waiting for pending requests to complete and unload to occur", "modelPath", runnerToExpire.modelPath)
				select {
				case <-ctx.Done():
					slog.Debug("shutting down scheduler pending loop")
					return
				case <-s.unloadedCh:
-					slog.Debug("unload completed", "model", runnerToExpire.model)
+					slog.Debug("unload completed", "modelPath", runnerToExpire.modelPath)
					continue
				}
			}
@@ -219,23 +219,23 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
			runner := s.loaded[finished.model.ModelPath]
			s.loadedMu.Unlock()
			if runner == nil {
-				slog.Error("finished requeset signal received after model unloaded", "model", finished.model.ModelPath)
+				slog.Error("finished request signal received after model unloaded", "modelPath", finished.model.ModelPath)
				continue
			}
			runner.refMu.Lock()
			runner.refCount--
			if runner.refCount <= 0 {
				if runner.sessionDuration <= 0 {
-					slog.Debug("runner with zero duration has gone idle, expiring to unload", "model", runner.model)
+					slog.Debug("runner with zero duration has gone idle, expiring to unload", "modelPath", runner.modelPath)
					if runner.expireTimer != nil {
						runner.expireTimer.Stop()
						runner.expireTimer = nil
					}
					s.expiredCh <- runner
				} else if runner.expireTimer == nil {
-					slog.Debug("runner with non-zero duration has gone idle, adding timer", "model", runner.model, "duration", runner.sessionDuration)
+					slog.Debug("runner with non-zero duration has gone idle, adding timer", "modelPath", runner.modelPath, "duration", runner.sessionDuration)
					runner.expireTimer = time.AfterFunc(runner.sessionDuration, func() {
-						slog.Debug("timer expired, expiring to unload", "model", runner.model)
+						slog.Debug("timer expired, expiring to unload", "modelPath", runner.modelPath)
						runner.refMu.Lock()
						defer runner.refMu.Unlock()
						if runner.expireTimer != nil {
@@ -244,19 +244,21 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
						}
						s.expiredCh <- runner
					})
+					runner.expiresAt = time.Now().Add(runner.sessionDuration)
				} else {
-					slog.Debug("runner with non-zero duration has gone idle, resetting timer", "model", runner.model, "duration", runner.sessionDuration)
+					slog.Debug("runner with non-zero duration has gone idle, resetting timer", "modelPath", runner.modelPath, "duration", runner.sessionDuration)
					runner.expireTimer.Reset(runner.sessionDuration)
+					runner.expiresAt = time.Now().Add(runner.sessionDuration)
				}
			}
-			slog.Debug("after processing request finished event", "model", runner.model, "refCount", runner.refCount)
+			slog.Debug("after processing request finished event", "modelPath", runner.modelPath, "refCount", runner.refCount)
			runner.refMu.Unlock()
		case runner := <-s.expiredCh:
-			slog.Debug("runner expired event received", "model", runner.model)
+			slog.Debug("runner expired event received", "modelPath", runner.modelPath)
			runner.refMu.Lock()
			if runner.refCount > 0 {
				// Shouldn't happen, but safeguard to ensure no leaked runners
-				slog.Debug("expired event with positive ref count, retrying", "model", runner.model, "refCount", runner.refCount)
+				slog.Debug("expired event with positive ref count, retrying", "modelPath", runner.modelPath, "refCount", runner.refCount)
				go func(runner *runnerRef) {
					// We can't unload yet, but want to as soon as the current request completes
					// So queue up another expired event
@@ -268,16 +270,16 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
			}

			s.loadedMu.Lock()
-			slog.Debug("got lock to unload", "model", runner.model)
+			slog.Debug("got lock to unload", "modelPath", runner.modelPath)
			finished := runner.waitForVRAMRecovery()
			runner.unload()
-			delete(s.loaded, runner.model)
+			delete(s.loaded, runner.modelPath)
			s.loadedMu.Unlock()
-			slog.Debug("runner released", "model", runner.model)
+			slog.Debug("runner released", "modelPath", runner.modelPath)
			runner.refMu.Unlock()

			<-finished
-			slog.Debug("sending an unloaded event", "model", runner.model)
+			slog.Debug("sending an unloaded event", "modelPath", runner.modelPath)
			s.unloadedCh <- struct{}{}
		}
	}
@@ -316,18 +318,20 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList)
		req.errCh <- err
		return
	}
-	runner := &runnerRef{}
-	runner.model = req.model.ModelPath
-	runner.adapters = req.model.AdapterPaths
-	runner.projectors = req.model.ProjectorPaths
-	runner.llama = llama
-	runner.Options = &req.opts
-	runner.sessionDuration = req.sessionDuration
-	runner.gpus = gpus
-	runner.estimatedVRAM = llama.EstimatedVRAM()
-	runner.loading = true
-	runner.refCount = 1
+	runner := &runnerRef{
+		model:           req.model,
+		modelPath:       req.model.ModelPath,
+		llama:           llama,
+		Options:         &req.opts,
+		sessionDuration: req.sessionDuration,
+		gpus:            gpus,
+		estimatedVRAM:   llama.EstimatedVRAM(),
+		estimatedTotal:  llama.EstimatedTotal(),
+		loading:         true,
+		refCount:        1,
+	}
	runner.refMu.Lock()
+
	s.loadedMu.Lock()
	s.loaded[req.model.ModelPath] = runner
	slog.Info("loaded runners", "count", len(s.loaded))
@@ -339,7 +343,7 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList)
			slog.Error("error loading llama server", "error", err)
			runner.refCount--
			req.errCh <- err
-			slog.Debug("triggering expiration for failed load", "model", runner.model)
+			slog.Debug("triggering expiration for failed load", "model", runner.modelPath)
			s.expiredCh <- runner
			return
		}
@@ -408,17 +412,18 @@ type runnerRef struct {
	refCount uint // prevent unloading if > 0
	// unloading bool      // set to true when we are trying to unload the runner

-	llama         llm.LlamaServer
-	loading       bool            // True only during initial load, then false forever
-	gpus          gpu.GpuInfoList // Recorded at time of provisioning
-	estimatedVRAM uint64
+	llama          llm.LlamaServer
+	loading        bool            // True only during initial load, then false forever
+	gpus           gpu.GpuInfoList // Recorded at time of provisioning
+	estimatedVRAM  uint64
+	estimatedTotal uint64

	sessionDuration time.Duration
	expireTimer     *time.Timer
+	expiresAt       time.Time

-	model      string
-	adapters   []string
-	projectors []string
+	model     *Model
+	modelPath string
	*api.Options
}
 
 
@@ -431,9 +436,8 @@ func (runner *runnerRef) unload() {
	if runner.llama != nil {
		runner.llama.Close()
	}
+	runner.model = nil
	runner.llama = nil
-	runner.adapters = nil
-	runner.projectors = nil
	runner.Options = nil
	runner.gpus = nil
}
@@ -462,8 +466,8 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool

	ctx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()
-	if !reflect.DeepEqual(runner.adapters, req.model.AdapterPaths) || // have the adapters changed?
-		!reflect.DeepEqual(runner.projectors, req.model.ProjectorPaths) || // have the projectors changed?
+	if !reflect.DeepEqual(runner.model.AdapterPaths, req.model.AdapterPaths) || // have the adapters changed?
+		!reflect.DeepEqual(runner.model.ProjectorPaths, req.model.ProjectorPaths) || // have the projectors changed?
		!reflect.DeepEqual(optsExisting, optsNew) || // have the runner options changed?
		runner.llama.Ping(ctx) != nil {
		return true
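The scheduler change that powers the UNTIL column: whenever an idle timer is armed or reset, expiresAt is recorded next to it, so the status endpoint can report the expiry without inspecting the timer. A minimal standalone sketch of that pattern (simplified names, not the actual scheduler):

package main

import (
	"fmt"
	"time"
)

type runner struct {
	expireTimer *time.Timer
	expiresAt   time.Time
}

func main() {
	r := &runner{}
	session := 50 * time.Millisecond
	done := make(chan struct{})

	// Arm (or re-arm) the idle timer and record when it will fire.
	if r.expireTimer == nil {
		r.expireTimer = time.AfterFunc(session, func() { close(done) })
	} else {
		r.expireTimer.Reset(session)
	}
	r.expiresAt = time.Now().Add(session)

	fmt.Println("expires at:", r.expiresAt.Format(time.RFC3339Nano))
	<-done
	fmt.Println("expired; the real scheduler would unload the model here")
}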

server/sched_test.go (+12 -10)

@@ -164,7 +164,8 @@ func TestRequests(t *testing.T) {

	// simple reload of same model
	scenario2a := newScenario(t, ctx, "ollama-model-1", 20)
-	scenario2a.req.model = scenario1a.req.model
+	tmpModel := *scenario1a.req.model
+	scenario2a.req.model = &tmpModel
	scenario2a.ggml = scenario1a.ggml

	// Multiple loaded models
@@ -496,10 +497,9 @@ func TestNeedsReload(t *testing.T) {
	llm := &mockLlm{}
	do := api.DefaultOptions()
	runner := &runnerRef{
-		adapters:   []string{"adapter1"},
-		projectors: []string{"projector1"},
-		Options:    &do,
-		llama:      llm,
+		model:   &Model{AdapterPaths: []string{"adapter1"}, ProjectorPaths: []string{"projector1"}},
+		Options: &do,
+		llama:   llm,
	}
	req := &LlmRequest{
		model: &Model{
@@ -510,10 +510,10 @@ func TestNeedsReload(t *testing.T) {
	}
	resp := runner.needsReload(ctx, req)
	require.True(t, resp)
-	req.model.AdapterPaths = runner.adapters
+	req.model.AdapterPaths = runner.model.AdapterPaths
	resp = runner.needsReload(ctx, req)
	require.True(t, resp)
-	req.model.ProjectorPaths = runner.projectors
+	req.model.ProjectorPaths = runner.model.ProjectorPaths
	runner.loading = true
	req.opts.NumBatch = 1234
	resp = runner.needsReload(ctx, req)
@@ -558,11 +558,11 @@ func TestUnloadAllRunners(t *testing.T) {
func TestUnload(t *testing.T) {
	llm1 := &mockLlm{}
	r1 := &runnerRef{llama: llm1}
-	r2 := &runnerRef{adapters: []string{"A"}}
+	r2 := &runnerRef{model: &Model{AdapterPaths: []string{"A"}}}
	r1.unload()
	require.True(t, llm1.closeCalled)
	r2.unload()
-	require.Nil(t, r2.adapters)
+	require.Nil(t, r2.model)
}

type mockLlm struct {
@@ -578,6 +578,7 @@ type mockLlm struct {
	closeResp         error
	closeCalled       bool
	estimatedVRAM     uint64
+	estimatedTotal    uint64
}

func (s *mockLlm) Ping(ctx context.Context) error             { return s.pingResp }
@@ -598,4 +599,5 @@ func (s *mockLlm) Close() error {
	s.closeCalled = true
	return s.closeResp
}
-func (s *mockLlm) EstimatedVRAM() uint64 { return s.estimatedVRAM }
+func (s *mockLlm) EstimatedVRAM() uint64  { return s.estimatedVRAM }
+func (s *mockLlm) EstimatedTotal() uint64 { return s.estimatedTotal }