
Wire up load progress

This doesn't expose any UX yet, but wires up the initial server-side portion
of progress reporting during model load.
Author: Daniel Hiltgen · 11 months ago
Commit: b37b496a12
3 changed files with 61 additions and 8 deletions
  1. llm/ext_server/server.cpp (+13, -1)
  2. llm/patches/01-load-progress.diff (+31, -0)
  3. llm/server.go (+17, -7)
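
Taken together, the three changes wire llama.cpp's model-load progress callback through to the Go client: server.cpp records the latest progress fraction and reports it from the health endpoint while the model loads, the patch exposes the callback fields on upstream's gpt_params, and server.go reads the value back and uses it to reset a stall timer. For orientation, here is a minimal consumer-side sketch of polling that status in Go; it is not part of this commit, and the endpoint path and port are assumptions for illustration.

```go
package main

import (
	"encoding/json"
	"fmt"
	"net/http"
	"time"
)

// statusResp mirrors the fields added to ServerStatusResp in llm/server.go.
type statusResp struct {
	Status   string  `json:"status"`
	Progress float32 `json:"progress"`
}

func main() {
	for {
		resp, err := http.Get("http://127.0.0.1:8080/health") // assumed address
		if err != nil {
			panic(err)
		}
		var st statusResp
		err = json.NewDecoder(resp.Body).Decode(&st)
		resp.Body.Close()
		if err != nil {
			panic(err)
		}
		// While loading, the server answers 503 with a progress fraction in [0, 1].
		fmt.Printf("%s: %0.2f\n", st.Status, st.Progress)
		if st.Status != "loading model" {
			return
		}
		time.Sleep(250 * time.Millisecond)
	}
}
```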

llm/ext_server/server.cpp (+13, -1)

@@ -334,6 +334,7 @@ struct server_metrics {
 struct llama_server_context
 {
     llama_model *model = nullptr;
+    float modelProgress = 0.0;
     llama_context *ctx = nullptr;
 
     clip_ctx *clp_ctx = nullptr;
@@ -2779,6 +2780,12 @@ inline void signal_handler(int signal) {
     shutdown_handler(signal);
 }
 
+static bool update_load_progress(float progress, void *data)
+{
+    ((llama_server_context*)data)->modelProgress = progress;
+    return true;
+}
+
 #if defined(_WIN32)
 char* wchar_to_char(const wchar_t* wstr) {
     if (wstr == nullptr) return nullptr;
@@ -2884,7 +2891,9 @@ int main(int argc, char **argv) {
                 break;
             }
             case SERVER_STATE_LOADING_MODEL:
-                res.set_content(R"({"status": "loading model"})", "application/json");
+                char buf[128];
+                snprintf(&buf[0], 128, R"({"status": "loading model", "progress": %0.2f})", llama.modelProgress);
+                res.set_content(buf, "application/json");
                 res.status = 503; // HTTP Service Unavailable
                 break;
             case SERVER_STATE_ERROR:
@@ -3079,6 +3088,9 @@ int main(int argc, char **argv) {
             });
 
     // load the model
+    params.progress_callback = update_load_progress;
+    params.progress_callback_user_data = (void*)&llama;
+
     if (!llama.load_model(params))
     {
         state.store(SERVER_STATE_ERROR);

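The server change has two halves: update_load_progress is a plain C-style callback that stashes the latest fraction on the llama_server_context, and the health handler formats that fraction into the existing 503 "loading model" body. A rough server-side analogue in Go, illustrative only and with all names invented, looks like this:

```go
package main

import (
	"fmt"
	"net/http"
	"sync/atomic"
)

// progress holds the latest load fraction; the C++ version keeps an
// equivalent float on llama_server_context, written by the load callback.
var progress atomic.Value

func health(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	if p, _ := progress.Load().(float32); p < 1.0 {
		// Still loading: 503 plus the current progress, as in server.cpp.
		w.WriteHeader(http.StatusServiceUnavailable)
		fmt.Fprintf(w, `{"status": "loading model", "progress": %0.2f}`, p)
		return
	}
	fmt.Fprint(w, `{"status": "ok"}`)
}

func main() {
	progress.Store(float32(0.42)) // would be updated by the load callback
	http.HandleFunc("/health", health)
	http.ListenAndServe("127.0.0.1:8080", nil)
}
```
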
llm/patches/01-load-progress.diff (+31, -0)

@@ -0,0 +1,31 @@
+diff --git a/common/common.cpp b/common/common.cpp
+index ba1ecf0e..cead57cc 100644
+--- a/common/common.cpp
++++ b/common/common.cpp
+@@ -1836,6 +1836,8 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
+     mparams.use_mmap        = params.use_mmap;
+     mparams.use_mlock       = params.use_mlock;
+     mparams.check_tensors   = params.check_tensors;
++    mparams.progress_callback = params.progress_callback;
++    mparams.progress_callback_user_data = params.progress_callback_user_data;
+     if (params.kv_overrides.empty()) {
+         mparams.kv_overrides = NULL;
+     } else {
+diff --git a/common/common.h b/common/common.h
+index d80344f2..71e84834 100644
+--- a/common/common.h
++++ b/common/common.h
+@@ -174,6 +174,13 @@ struct gpt_params {
+     // multimodal models (see examples/llava)
+     std::string mmproj = "";        // path to multimodal projector
+     std::vector<std::string> image; // path to image file(s)
++
++    // Called with a progress value between 0.0 and 1.0. Pass NULL to disable.
++    // If the provided progress_callback returns true, model loading continues.
++    // If it returns false, model loading is immediately aborted.
++    llama_progress_callback progress_callback = NULL;
++    // context pointer passed to the progress callback
++    void * progress_callback_user_data;
+ };
+ 
+ void gpt_params_handle_model_default(gpt_params & params);

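The patch itself only plumbs two new fields through upstream's common layer so they reach llama_model_params. The contract spelled out in the comments, a fraction between 0.0 and 1.0 with a false return aborting the load, is what the server relies on; a hedged Go rendering of it, with invented names, might look like:

```go
package main

import (
	"errors"
	"fmt"
)

// progressFn mirrors llama_progress_callback: it is handed a fraction in
// [0.0, 1.0], and returning false aborts the load immediately.
type progressFn func(progress float32) bool

func loadModel(chunks int, cb progressFn) error {
	for i := 1; i <= chunks; i++ {
		// ... load one chunk of tensor data here ...
		if cb != nil && !cb(float32(i)/float32(chunks)) {
			return errors.New("model load aborted by progress callback")
		}
	}
	return nil
}

func main() {
	err := loadModel(4, func(p float32) bool {
		fmt.Printf("progress %0.2f\n", p)
		return p < 0.75 // abort past 75% just to exercise the contract
	})
	fmt.Println(err)
}
```

Running this prints three progress lines and then the abort error, demonstrating both directions of the boolean return.
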
llm/server.go (+17, -7)

@@ -55,6 +55,7 @@ type llmServer struct {
 	totalLayers    uint64
 	gpuCount       int
 	loadDuration   time.Duration // Record how long it took the model to load
+	loadProgress   float32
 
 	sem *semaphore.Weighted
 }
@@ -425,10 +426,11 @@ func (s ServerStatus) ToString() string {
 }
 
 type ServerStatusResp struct {
-	Status          string `json:"status"`
-	SlotsIdle       int    `json:"slots_idle"`
-	SlotsProcessing int    `json:"slots_processing"`
-	Error           string `json:"error"`
+	Status          string  `json:"status"`
+	SlotsIdle       int     `json:"slots_idle"`
+	SlotsProcessing int     `json:"slots_processing"`
+	Error           string  `json:"error"`
+	Progress        float32 `json:"progress"`
 }
 
 func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
@@ -476,6 +478,7 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
 	case "no slot available":
 		return ServerStatusNoSlotsAvailable, nil
 	case "loading model":
+		s.loadProgress = status.Progress
 		return ServerStatusLoadingModel, nil
 	default:
 		return ServerStatusError, fmt.Errorf("server error: %+v", status)
@@ -516,7 +519,8 @@ func (s *llmServer) Ping(ctx context.Context) error {
 
 func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 	start := time.Now()
-	expiresAt := time.Now().Add(10 * time.Minute) // be generous with timeout, large models can take a while to load
+	stallDuration := 60 * time.Second
+	stallTimer := time.Now().Add(stallDuration) // give up if we stall for this long
 
 	slog.Info("waiting for llama runner to start responding")
 	var lastStatus ServerStatus = -1
@@ -534,13 +538,13 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 			return fmt.Errorf("llama runner process has terminated: %v %s", err, msg)
 		default:
 		}
-		if time.Now().After(expiresAt) {
+		if time.Now().After(stallTimer) {
 			// timeout
 			msg := ""
 			if s.status != nil && s.status.LastErrMsg != "" {
 				msg = s.status.LastErrMsg
 			}
-			return fmt.Errorf("timed out waiting for llama runner to start: %s", msg)
+			return fmt.Errorf("timed out waiting for llama runner to start - progress %0.2f - %s", s.loadProgress, msg)
 		}
 		if s.cmd.ProcessState != nil {
 			msg := ""
@@ -551,6 +555,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 		}
 		ctx, cancel := context.WithTimeout(ctx, 200*time.Millisecond)
 		defer cancel()
+		priorProgress := s.loadProgress
 		status, _ := s.getServerStatus(ctx)
 		if lastStatus != status && status != ServerStatusReady {
 			// Only log on status changes
@@ -563,6 +568,11 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 			return nil
 		default:
 			lastStatus = status
+			// Reset the timer as long as we're making forward progress on the load
+			if priorProgress != s.loadProgress {
+				slog.Debug(fmt.Sprintf("model load progress %0.2f", s.loadProgress))
+				stallTimer = time.Now().Add(stallDuration)
+			}
 			time.Sleep(time.Millisecond * 250)
 			continue
 		}
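
The Go side swaps the old fixed 10-minute deadline for a sliding stall timer: any observed change in the reported progress pushes the deadline out by another stallDuration, so a slow but steadily advancing load never times out while a genuinely stuck one fails within a minute. Here is the pattern in isolation, as a sketch rather than the commit's code:

```go
package main

import (
	"fmt"
	"time"
)

// waitForLoad polls every 250ms and gives up only if the reported progress
// stops advancing for `stall` - the sliding-deadline pattern WaitUntilRunning
// uses above.
func waitForLoad(poll func() float32, stall time.Duration) error {
	deadline := time.Now().Add(stall)
	last := float32(-1)
	for {
		if time.Now().After(deadline) {
			return fmt.Errorf("timed out waiting for load - progress %0.2f", last)
		}
		p := poll()
		if p >= 1.0 {
			return nil // fully loaded
		}
		if p != last {
			// Forward progress: push the deadline out again.
			last = p
			deadline = time.Now().Add(stall)
		}
		time.Sleep(250 * time.Millisecond)
	}
}

func main() {
	start := time.Now()
	// Simulated loader that reaches 100% after two seconds; a one-second
	// stall budget still succeeds because progress keeps moving.
	err := waitForLoad(func() float32 {
		return float32(time.Since(start).Seconds() / 2.0)
	}, time.Second)
	fmt.Println(err)
}
```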