Sfoglia il codice sorgente

do not reload the running llm when runtime params change (#840)

- only reload the running llm if the model has changed, or the options for loading the running model have changed
- rename loaded llm to runner to differentiate from loaded model image
- remove logic which keeps the first system prompt in the generation context
Bruce MacDonald 1 anno fa
parent
commit
fe6f3b48f7
3 ha cambiato i file con 66 aggiunte e 86 eliminazioni
  1. 27 25
      api/types.go
  2. 5 7
      server/images.go
  3. 34 54
      server/routes.go

+ 27 - 25
api/types.go

@@ -161,15 +161,10 @@ func (r *GenerateResponse) Summary() {
 	}
 	}
 }
 }
 
 
-type Options struct {
-	Seed int `json:"seed,omitempty"`
-
-	// Backend options
-	UseNUMA bool `json:"numa,omitempty"`
-
-	// Model options
+// Runner options which must be set when the model is loaded into memory
+type Runner struct {
+	UseNUMA            bool    `json:"numa,omitempty"`
 	NumCtx             int     `json:"num_ctx,omitempty"`
 	NumCtx             int     `json:"num_ctx,omitempty"`
-	NumKeep            int     `json:"num_keep,omitempty"`
 	NumBatch           int     `json:"num_batch,omitempty"`
 	NumBatch           int     `json:"num_batch,omitempty"`
 	NumGQA             int     `json:"num_gqa,omitempty"`
 	NumGQA             int     `json:"num_gqa,omitempty"`
 	NumGPU             int     `json:"num_gpu,omitempty"`
 	NumGPU             int     `json:"num_gpu,omitempty"`
@@ -183,8 +178,15 @@ type Options struct {
 	EmbeddingOnly      bool    `json:"embedding_only,omitempty"`
 	EmbeddingOnly      bool    `json:"embedding_only,omitempty"`
 	RopeFrequencyBase  float32 `json:"rope_frequency_base,omitempty"`
 	RopeFrequencyBase  float32 `json:"rope_frequency_base,omitempty"`
 	RopeFrequencyScale float32 `json:"rope_frequency_scale,omitempty"`
 	RopeFrequencyScale float32 `json:"rope_frequency_scale,omitempty"`
+	NumThread          int     `json:"num_thread,omitempty"`
+}
 
 
-	// Predict options
+type Options struct {
+	Runner
+
+	// Predict options used at runtime
+	NumKeep          int      `json:"num_keep,omitempty"`
+	Seed             int      `json:"seed,omitempty"`
 	NumPredict       int      `json:"num_predict,omitempty"`
 	NumPredict       int      `json:"num_predict,omitempty"`
 	TopK             int      `json:"top_k,omitempty"`
 	TopK             int      `json:"top_k,omitempty"`
 	TopP             float32  `json:"top_p,omitempty"`
 	TopP             float32  `json:"top_p,omitempty"`
@@ -200,8 +202,6 @@ type Options struct {
 	MirostatEta      float32  `json:"mirostat_eta,omitempty"`
 	MirostatEta      float32  `json:"mirostat_eta,omitempty"`
 	PenalizeNewline  bool     `json:"penalize_newline,omitempty"`
 	PenalizeNewline  bool     `json:"penalize_newline,omitempty"`
 	Stop             []string `json:"stop,omitempty"`
 	Stop             []string `json:"stop,omitempty"`
-
-	NumThread int `json:"num_thread,omitempty"`
 }
 }
 
 
 var ErrInvalidOpts = fmt.Errorf("invalid options")
 var ErrInvalidOpts = fmt.Errorf("invalid options")
@@ -309,20 +309,22 @@ func DefaultOptions() Options {
 		PenalizeNewline:  true,
 		PenalizeNewline:  true,
 		Seed:             -1,
 		Seed:             -1,
 
 
-		// options set when the model is loaded
-		NumCtx:             2048,
-		RopeFrequencyBase:  10000.0,
-		RopeFrequencyScale: 1.0,
-		NumBatch:           512,
-		NumGPU:             -1, // -1 here indicates that NumGPU should be set dynamically
-		NumGQA:             1,
-		NumThread:          0, // let the runtime decide
-		LowVRAM:            false,
-		F16KV:              true,
-		UseMLock:           false,
-		UseMMap:            true,
-		UseNUMA:            false,
-		EmbeddingOnly:      true,
+		Runner: Runner{
+			// options set when the model is loaded
+			NumCtx:             2048,
+			RopeFrequencyBase:  10000.0,
+			RopeFrequencyScale: 1.0,
+			NumBatch:           512,
+			NumGPU:             -1, // -1 here indicates that NumGPU should be set dynamically
+			NumGQA:             1,
+			NumThread:          0, // let the runtime decide
+			LowVRAM:            false,
+			F16KV:              true,
+			UseMLock:           false,
+			UseMMap:            true,
+			UseNUMA:            false,
+			EmbeddingOnly:      true,
+		},
 	}
 	}
 }
 }
 
 

+ 5 - 7
server/images.go

@@ -45,7 +45,6 @@ type Model struct {
 	System        string
 	System        string
 	License       []string
 	License       []string
 	Digest        string
 	Digest        string
-	ConfigDigest  string
 	Options       map[string]interface{}
 	Options       map[string]interface{}
 }
 }
 
 
@@ -166,12 +165,11 @@ func GetModel(name string) (*Model, error) {
 	}
 	}
 
 
 	model := &Model{
 	model := &Model{
-		Name:         mp.GetFullTagname(),
-		ShortName:    mp.GetShortTagname(),
-		Digest:       digest,
-		ConfigDigest: manifest.Config.Digest,
-		Template:     "{{ .Prompt }}",
-		License:      []string{},
+		Name:      mp.GetFullTagname(),
+		ShortName: mp.GetShortTagname(),
+		Digest:    digest,
+		Template:  "{{ .Prompt }}",
+		License:   []string{},
 	}
 	}
 
 
 	for _, layer := range manifest.Layers {
 	for _, layer := range manifest.Layers {

+ 34 - 54
server/routes.go

@@ -46,13 +46,13 @@ func init() {
 var loaded struct {
 var loaded struct {
 	mu sync.Mutex
 	mu sync.Mutex
 
 
-	llm llm.LLM
+	runner llm.LLM
 
 
 	expireAt    time.Time
 	expireAt    time.Time
 	expireTimer *time.Timer
 	expireTimer *time.Timer
 
 
-	digest  string
-	options api.Options
+	*Model
+	*api.Options
 }
 }
 
 
 var defaultSessionDuration = 5 * time.Minute
 var defaultSessionDuration = 5 * time.Minute
@@ -70,59 +70,39 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]
 	}
 	}
 
 
 	// check if the loaded model is still running in a subprocess, in case something unexpected happened
 	// check if the loaded model is still running in a subprocess, in case something unexpected happened
-	if loaded.llm != nil {
-		if err := loaded.llm.Ping(ctx); err != nil {
+	if loaded.runner != nil {
+		if err := loaded.runner.Ping(ctx); err != nil {
 			log.Print("loaded llm process not responding, closing now")
 			log.Print("loaded llm process not responding, closing now")
 			// the subprocess is no longer running, so close it
 			// the subprocess is no longer running, so close it
-			loaded.llm.Close()
-			loaded.llm = nil
-			loaded.digest = ""
+			loaded.runner.Close()
+			loaded.runner = nil
+			loaded.Model = nil
+			loaded.Options = nil
 		}
 		}
 	}
 	}
 
 
-	if model.Digest != loaded.digest || !reflect.DeepEqual(loaded.options, opts) {
-		if loaded.llm != nil {
+	needLoad := loaded.runner == nil || // is there a model loaded?
+		loaded.ModelPath != model.ModelPath || // has the base model changed?
+		!reflect.DeepEqual(loaded.AdapterPaths, model.AdapterPaths) || // have the adapters changed?
+		!reflect.DeepEqual(loaded.Options.Runner, opts.Runner) // have the runner options changed?
+
+	if needLoad {
+		if loaded.runner != nil {
 			log.Println("changing loaded model")
 			log.Println("changing loaded model")
-			loaded.llm.Close()
-			loaded.llm = nil
-			loaded.digest = ""
+			loaded.runner.Close()
+			loaded.runner = nil
+			loaded.Model = nil
+			loaded.Options = nil
 		}
 		}
 
 
-		llmModel, err := llm.New(workDir, model.ModelPath, model.AdapterPaths, opts)
+		llmRunner, err := llm.New(workDir, model.ModelPath, model.AdapterPaths, opts)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
 
 
-		// set cache values before modifying opts
-		loaded.llm = llmModel
-		loaded.digest = model.Digest
-		loaded.options = opts
-
-		if opts.NumKeep < 0 {
-			promptWithSystem, err := model.Prompt(api.GenerateRequest{})
-			if err != nil {
-				return err
-			}
-
-			promptNoSystem, err := model.Prompt(api.GenerateRequest{Context: []int{0}})
-			if err != nil {
-				return err
-			}
-
-			tokensWithSystem, err := llmModel.Encode(ctx, promptWithSystem)
-			if err != nil {
-				return err
-			}
-
-			tokensNoSystem, err := llmModel.Encode(ctx, promptNoSystem)
-			if err != nil {
-				return err
-			}
-
-			opts.NumKeep = len(tokensWithSystem) - len(tokensNoSystem)
-
-			llmModel.SetOptions(opts)
-		}
+		loaded.Model = model
+		loaded.runner = llmRunner
+		loaded.Options = &opts
 	}
 	}
 
 
 	loaded.expireAt = time.Now().Add(sessionDuration)
 	loaded.expireAt = time.Now().Add(sessionDuration)
@@ -136,13 +116,13 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]
 				return
 				return
 			}
 			}
 
 
-			if loaded.llm == nil {
-				return
+			if loaded.runner != nil {
+				loaded.runner.Close()
 			}
 			}
 
 
-			loaded.llm.Close()
-			loaded.llm = nil
-			loaded.digest = ""
+			loaded.runner = nil
+			loaded.Model = nil
+			loaded.Options = nil
 		})
 		})
 	}
 	}
 
 
@@ -215,7 +195,7 @@ func GenerateHandler(c *gin.Context) {
 		if req.Prompt == "" && req.Template == "" && req.System == "" {
 		if req.Prompt == "" && req.Template == "" && req.System == "" {
 			ch <- api.GenerateResponse{Model: req.Model, Done: true}
 			ch <- api.GenerateResponse{Model: req.Model, Done: true}
 		} else {
 		} else {
-			if err := loaded.llm.Predict(c.Request.Context(), req.Context, prompt, fn); err != nil {
+			if err := loaded.runner.Predict(c.Request.Context(), req.Context, prompt, fn); err != nil {
 				ch <- gin.H{"error": err.Error()}
 				ch <- gin.H{"error": err.Error()}
 			}
 			}
 		}
 		}
@@ -263,12 +243,12 @@ func EmbeddingHandler(c *gin.Context) {
 		return
 		return
 	}
 	}
 
 
-	if !loaded.options.EmbeddingOnly {
+	if !loaded.Options.EmbeddingOnly {
 		c.JSON(http.StatusBadRequest, gin.H{"error": "embedding option must be set to true"})
 		c.JSON(http.StatusBadRequest, gin.H{"error": "embedding option must be set to true"})
 		return
 		return
 	}
 	}
 
 
-	embedding, err := loaded.llm.Embedding(c.Request.Context(), req.Prompt)
+	embedding, err := loaded.runner.Embedding(c.Request.Context(), req.Prompt)
 	if err != nil {
 	if err != nil {
 		log.Printf("embedding generation failed: %v", err)
 		log.Printf("embedding generation failed: %v", err)
 		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
 		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
@@ -599,8 +579,8 @@ func Serve(ln net.Listener, allowOrigins []string) error {
 	signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
 	signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
 	go func() {
 	go func() {
 		<-signals
 		<-signals
-		if loaded.llm != nil {
-			loaded.llm.Close()
+		if loaded.runner != nil {
+			loaded.runner.Close()
 		}
 		}
 		os.RemoveAll(workDir)
 		os.RemoveAll(workDir)
 		os.Exit(0)
 		os.Exit(0)