소스 검색

fix model reloading

ensure runtime model changes (template, system prompt, messages,
options) are captured on model updates without needing to reload the
server
Michael Yang 10 달 전
부모
커밋
ac7a842e55
2개의 변경된 파일23개의 추가작업 그리고 21개의 파일을 삭제
  1. 1 1
      llm/server.go
  2. 22 20
      server/routes.go

+ 1 - 1
llm/server.go

@@ -679,7 +679,7 @@ type CompletionRequest struct {
 	Prompt  string
 	Format  string
 	Images  []ImageData
-	Options api.Options
+	Options *api.Options
 }
 
 type CompletionResponse struct {

+ 22 - 20
server/routes.go

@@ -69,23 +69,25 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options
 	return opts, nil
 }
 
-func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (*runnerRef, error) {
+// scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
+// It returns the allocated runner, model instance, and consolidated options if successful and error otherwise.
+func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
 	if name == "" {
-		return nil, fmt.Errorf("model %w", errRequired)
+		return nil, nil, nil, fmt.Errorf("model %w", errRequired)
 	}
 
 	model, err := GetModel(name)
 	if err != nil {
-		return nil, err
+		return nil, nil, nil, err
 	}
 
 	if err := model.CheckCapabilities(caps...); err != nil {
-		return nil, fmt.Errorf("%s %w", name, err)
+		return nil, nil, nil, fmt.Errorf("%s %w", name, err)
 	}
 
 	opts, err := modelOptions(model, requestOpts)
 	if err != nil {
-		return nil, err
+		return nil, nil, nil, err
 	}
 
 	runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive)
@@ -93,10 +95,10 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capabil
 	select {
 	case runner = <-runnerCh:
 	case err = <-errCh:
-		return nil, err
+		return nil, nil, nil, err
 	}
 
-	return runner, nil
+	return runner.llama, model, &opts, nil
 }
 
 func (s *Server) GenerateHandler(c *gin.Context) {
@@ -118,7 +120,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 	}
 
 	caps := []Capability{CapabilityCompletion}
-	r, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
+	r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
 	if errors.Is(err, errCapabilityCompletion) {
 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
 		return
@@ -147,8 +149,8 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		var msgs []api.Message
 		if req.System != "" {
 			msgs = append(msgs, api.Message{Role: "system", Content: req.System})
-		} else if r.model.System != "" {
-			msgs = append(msgs, api.Message{Role: "system", Content: r.model.System})
+		} else if m.System != "" {
+			msgs = append(msgs, api.Message{Role: "system", Content: m.System})
 		}
 
 		for _, i := range images {
@@ -157,7 +159,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 
 		msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt})
 
-		tmpl := r.model.Template
+		tmpl := m.Template
 		if req.Template != "" {
 			tmpl, err = template.Parse(req.Template)
 			if err != nil {
@@ -168,7 +170,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 
 		var b bytes.Buffer
 		if req.Context != nil {
-			s, err := r.llama.Detokenize(c.Request.Context(), req.Context)
+			s, err := r.Detokenize(c.Request.Context(), req.Context)
 			if err != nil {
 				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 				return
@@ -190,11 +192,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 	ch := make(chan any)
 	go func() {
 		defer close(ch)
-		if err := r.llama.Completion(c.Request.Context(), llm.CompletionRequest{
+		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
 			Prompt:  prompt,
 			Images:  images,
 			Format:  req.Format,
-			Options: *r.Options,
+			Options: opts,
 		}, func(r llm.CompletionResponse) {
 			ch <- api.GenerateResponse{
 				Model:      req.Model,
@@ -254,7 +256,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 		return
 	}
 
-	r, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
+	r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
 	if err != nil {
 		handleScheduleError(c, req.Model, err)
 		return
@@ -266,7 +268,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 		return
 	}
 
-	embedding, err := r.llama.Embedding(c.Request.Context(), req.Prompt)
+	embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
 	if err != nil {
 		slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
 		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
@@ -1130,7 +1132,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	}
 
 	caps := []Capability{CapabilityCompletion}
-	r, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
+	r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
 	if errors.Is(err, errCapabilityCompletion) {
 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
 		return
@@ -1150,7 +1152,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		return
 	}
 
-	prompt, images, err := chatPrompt(c.Request.Context(), r.model, r.llama.Tokenize, r.Options, req.Messages)
+	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
@@ -1161,11 +1163,11 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	ch := make(chan any)
 	go func() {
 		defer close(ch)
-		if err := r.llama.Completion(c.Request.Context(), llm.CompletionRequest{
+		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
 			Prompt:  prompt,
 			Images:  images,
 			Format:  req.Format,
-			Options: *r.Options,
+			Options: opts,
 		}, func(r llm.CompletionResponse) {
 			ch <- api.ChatResponse{
 				Model:      req.Model,