Przeglądaj źródła

fix model reloading

ensure runtime model changes (template, system prompt, messages,
options) are captured on model updates without needing to reload the
server
Michael Yang 10 miesięcy temu
rodzic
commit
ac7a842e55
2 zmienionych plików z 23 dodań i 21 usunięć
  1. 1 1
      llm/server.go
  2. 22 20
      server/routes.go

+ 1 - 1
llm/server.go

@@ -679,7 +679,7 @@ type CompletionRequest struct {
 	Prompt  string
 	Format  string
 	Images  []ImageData
-	Options api.Options
+	Options *api.Options
 }
 
 type CompletionResponse struct {

+ 22 - 20
server/routes.go

@@ -69,23 +69,25 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options
 	return opts, nil
 }
 
-func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (*runnerRef, error) {
+// scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
+// It returns the allocated runner, model instance, and consolidated options if successful and error otherwise.
+func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
 	if name == "" {
-		return nil, fmt.Errorf("model %w", errRequired)
+		return nil, nil, nil, fmt.Errorf("model %w", errRequired)
 	}
 
 	model, err := GetModel(name)
 	if err != nil {
-		return nil, err
+		return nil, nil, nil, err
 	}
 
 	if err := model.CheckCapabilities(caps...); err != nil {
-		return nil, fmt.Errorf("%s %w", name, err)
+		return nil, nil, nil, fmt.Errorf("%s %w", name, err)
 	}
 
 	opts, err := modelOptions(model, requestOpts)
 	if err != nil {
-		return nil, err
+		return nil, nil, nil, err
 	}
 
 	runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive)
@@ -93,10 +95,10 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capabil
 	select {
 	case runner = <-runnerCh:
 	case err = <-errCh:
-		return nil, err
+		return nil, nil, nil, err
 	}
 
-	return runner, nil
+	return runner.llama, model, &opts, nil
 }
 
 func (s *Server) GenerateHandler(c *gin.Context) {
@@ -118,7 +120,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 	}
 	}
 
 
 	caps := []Capability{CapabilityCompletion}
 	caps := []Capability{CapabilityCompletion}
-	r, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
+	r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
 	if errors.Is(err, errCapabilityCompletion) {
 	if errors.Is(err, errCapabilityCompletion) {
 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
 		return
 		return
@@ -147,8 +149,8 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		var msgs []api.Message
 		if req.System != "" {
 			msgs = append(msgs, api.Message{Role: "system", Content: req.System})
-		} else if r.model.System != "" {
-			msgs = append(msgs, api.Message{Role: "system", Content: r.model.System})
+		} else if m.System != "" {
+			msgs = append(msgs, api.Message{Role: "system", Content: m.System})
 		}
 
 		for _, i := range images {
@@ -157,7 +159,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 
 		msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt})
 
-		tmpl := r.model.Template
+		tmpl := m.Template
 		if req.Template != "" {
 			tmpl, err = template.Parse(req.Template)
 			if err != nil {
@@ -168,7 +170,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 
 		var b bytes.Buffer
 		if req.Context != nil {
-			s, err := r.llama.Detokenize(c.Request.Context(), req.Context)
+			s, err := r.Detokenize(c.Request.Context(), req.Context)
 			if err != nil {
 				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 				return
@@ -190,11 +192,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 	ch := make(chan any)
 	go func() {
 		defer close(ch)
-		if err := r.llama.Completion(c.Request.Context(), llm.CompletionRequest{
+		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
 			Prompt:  prompt,
 			Images:  images,
 			Format:  req.Format,
-			Options: *r.Options,
+			Options: opts,
 		}, func(r llm.CompletionResponse) {
 			ch <- api.GenerateResponse{
 				Model:      req.Model,
@@ -254,7 +256,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 		return
 	}
 
-	r, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
+	r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
 	if err != nil {
 		handleScheduleError(c, req.Model, err)
 		return
@@ -266,7 +268,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 		return
 	}
 
-	embedding, err := r.llama.Embedding(c.Request.Context(), req.Prompt)
+	embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
 	if err != nil {
 		slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
 		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
@@ -1130,7 +1132,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	}
 
 	caps := []Capability{CapabilityCompletion}
-	r, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
+	r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
 	if errors.Is(err, errCapabilityCompletion) {
 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
 		return
@@ -1150,7 +1152,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		return
 	}
 
-	prompt, images, err := chatPrompt(c.Request.Context(), r.model, r.llama.Tokenize, r.Options, req.Messages)
+	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
@@ -1161,11 +1163,11 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	ch := make(chan any)
 	go func() {
 		defer close(ch)
-		if err := r.llama.Completion(c.Request.Context(), llm.CompletionRequest{
+		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
 			Prompt:  prompt,
 			Images:  images,
 			Format:  req.Format,
-			Options: *r.Options,
+			Options: opts,
 		}, func(r llm.CompletionResponse) {
 			ch <- api.ChatResponse{
 				Model:      req.Model,