
Prevent loading models larger than total memory

Users may not realize that the shiny new model they're trying to load
fits on their disk but can't be loaded into system+GPU memory. Today
we crash; with this fix, we give them a better error message before
even trying to load it.
Daniel Hiltgen 10 months ago
parent
commit
3c75113e37
2 changed files with 38 additions and 0 deletions
  1. server/sched.go (+26 -0)
  2. server/sched_test.go (+12 -0)
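
With this change, a request for a model that cannot fit anywhere is rejected up front: before spawning a runner, the scheduler compares the load estimate against free system memory plus GPU memory and answers on the request's error channel instead of crashing mid-load. The message comes from the fmt.Errorf added below; the sizes and exact unit formatting shown here are made up for illustration:

    requested model (49.0 GiB) is too large for this system (40.0 GiB)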

+ 26 - 0
server/sched.go

@@ -139,6 +139,11 @@ func (s *Scheduler) processPending(ctx context.Context) {
 			}
 
 			for {
+				cpus := s.getCpuFn()
+				var systemMem gpu.GpuInfo
+				if len(cpus) > 0 {
+					systemMem = cpus[0]
+				}
 				var runnerToExpire *runnerRef
 				s.loadedMu.Lock()
 				runner := s.loaded[pending.model.ModelPath]
@@ -192,6 +197,27 @@ func (s *Scheduler) processPending(ctx context.Context) {
 						break
 					}
 
+					// Block attempting to load a model larger than system memory + GPU memory
+					estimate := llm.EstimateGPULayers(gpus, ggml, pending.model.ProjectorPaths, pending.opts)
+					maxSize := systemMem.FreeMemory
+					for _, gpu := range gpus {
+						if gpu.Library == "cpu" {
+							continue
+						}
+						if loadedCount == 0 {
+							// If no other models are loaded, set the limit based on what's available
+							maxSize += gpu.FreeMemory
+						} else {
+							// Other models could be unloaded, favor total memory for limit
+							maxSize += gpu.TotalMemory
+						}
+					}
+					if estimate.TotalSize > maxSize {
+						slog.Warn("model request too large for system", "requested", format.HumanBytes2(estimate.TotalSize), "system", format.HumanBytes2(maxSize))
+						pending.errCh <- fmt.Errorf("requested model (%s) is too large for this system (%s)", format.HumanBytes2(estimate.TotalSize), format.HumanBytes2(maxSize))
+						break
+					}
+
 					// Evaluate if the model will fit in the available system memory, or if we should unload a model first
 					if len(gpus) == 1 && gpus[0].Library == "cpu" {
 						// simplifying assumption of defaultParallel when in CPU mode
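
For readability, the ceiling computed in the hunk above can be restated as a small standalone helper. This is a minimal sketch, not code from this commit: the GPU type and maxLoadableBytes are hypothetical stand-ins for gpu.GpuInfo, the scheduler's gpus slice, and the inline loop in processPending.

// GPU is a hypothetical stand-in for the gpu.GpuInfo fields the check reads.
type GPU struct {
	Library     string
	FreeMemory  uint64
	TotalMemory uint64
}

// maxLoadableBytes mirrors the limit above: free system memory, plus each
// real GPU's free memory when nothing is loaded yet, or its total memory
// when already-loaded models could be evicted to make room.
func maxLoadableBytes(systemFree uint64, gpus []GPU, loadedCount int) uint64 {
	limit := systemFree
	for _, g := range gpus {
		if g.Library == "cpu" {
			continue // CPU entries are already counted as system memory
		}
		if loadedCount == 0 {
			limit += g.FreeMemory
		} else {
			limit += g.TotalMemory
		}
	}
	return limit
}

A pending request whose estimated total size exceeds this value is answered on pending.errCh and never handed to a runner.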

+ 12 - 0
server/sched_test.go

@@ -199,6 +199,8 @@ func TestRequests(t *testing.T) {
 		require.Equal(t, resp.llama, scenario1a.srv)
 		require.Empty(t, s.pendingReqCh)
 		require.Empty(t, scenario1a.req.errCh)
+	case err := <-scenario1a.req.errCh:
+		t.Fatal(err.Error())
 	case <-ctx.Done():
 		t.Fatal("timeout")
 	}
@@ -212,6 +214,8 @@ func TestRequests(t *testing.T) {
 		require.Equal(t, resp.llama, scenario1a.srv)
 		require.Empty(t, s.pendingReqCh)
 		require.Empty(t, scenario1b.req.errCh)
+	case err := <-scenario1b.req.errCh:
+		t.Fatal(err.Error())
 	case <-ctx.Done():
 		t.Fatal("timeout")
 	}
@@ -230,6 +234,8 @@ func TestRequests(t *testing.T) {
 		require.Equal(t, resp.llama, scenario2a.srv)
 		require.Empty(t, s.pendingReqCh)
 		require.Empty(t, scenario2a.req.errCh)
+	case err := <-scenario2a.req.errCh:
+		t.Fatal(err.Error())
 	case <-ctx.Done():
 		t.Fatal("timeout")
 	}
@@ -246,6 +252,8 @@ func TestRequests(t *testing.T) {
 		require.Equal(t, resp.llama, scenario3a.srv)
 		require.Empty(t, s.pendingReqCh)
 		require.Empty(t, scenario3a.req.errCh)
+	case err := <-scenario3a.req.errCh:
+		t.Fatal(err.Error())
 	case <-ctx.Done():
 		t.Fatal("timeout")
 	}
@@ -262,6 +270,8 @@ func TestRequests(t *testing.T) {
 		require.Equal(t, resp.llama, scenario3b.srv)
 		require.Empty(t, s.pendingReqCh)
 		require.Empty(t, scenario3b.req.errCh)
+	case err := <-scenario3b.req.errCh:
+		t.Fatal(err.Error())
 	case <-ctx.Done():
 		t.Fatal("timeout")
 	}
@@ -278,6 +288,8 @@ func TestRequests(t *testing.T) {
 		require.Equal(t, resp.llama, scenario3c.srv)
 		require.Empty(t, s.pendingReqCh)
 		require.Empty(t, scenario3c.req.errCh)
+	case err := <-scenario3c.req.errCh:
+		t.Fatal(err.Error())
 	case <-ctx.Done():
 		t.Fatal("timeout")
 	}
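
Each test hunk adds the same third arm to the scenario's select: if the scheduler rejects a load, the test now fails immediately with that error instead of sitting on the ctx.Done() timeout. A minimal sketch of the pattern, assuming "context" and "testing" imports, with a hypothetical helper and readyCh/errCh channels standing in for the scenario fields:

// waitForLoad sketches the select used throughout TestRequests: succeed when
// the runner comes back, fail fast on a scheduler error, or time out.
func waitForLoad(ctx context.Context, t *testing.T, readyCh <-chan struct{}, errCh <-chan error) {
	select {
	case <-readyCh:
		// model was scheduled as expected
	case err := <-errCh:
		t.Fatal(err.Error())
	case <-ctx.Done():
		t.Fatal("timeout")
	}
}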