Browse Source

Prevent loading models larger than total memory

Users may not realize the siny new model they're trying to load
fits on their disk, but can't load into system+GPU memory.  Today
we crash, but with this fix, we'll give them a better error message
before even trying to load it.
Daniel Hiltgen 10 months ago
parent
commit
3c75113e37
2 changed files with 38 additions and 0 deletions
  1. 26 0
      server/sched.go
  2. 12 0
      server/sched_test.go

+ 26 - 0
server/sched.go

@@ -139,6 +139,11 @@ func (s *Scheduler) processPending(ctx context.Context) {
 			}
 
 			for {
+				cpus := s.getCpuFn()
+				var systemMem gpu.GpuInfo
+				if len(cpus) > 0 {
+					systemMem = cpus[0]
+				}
 				var runnerToExpire *runnerRef
 				s.loadedMu.Lock()
 				runner := s.loaded[pending.model.ModelPath]
@@ -192,6 +197,27 @@ func (s *Scheduler) processPending(ctx context.Context) {
 						break
 					}
 
+					// Block attempting to load a model larger than system memory + GPU memory
+					estimate := llm.EstimateGPULayers(gpus, ggml, pending.model.ProjectorPaths, pending.opts)
+					maxSize := systemMem.FreeMemory
+					for _, gpu := range gpus {
+						if gpu.Library == "cpu" {
+							continue
+						}
+						if loadedCount == 0 {
+							// If no other models are loaded, set the limit based on what's available
+							maxSize += gpu.FreeMemory
+						} else {
+							// Other models could be unloaded, favor total memory for limit
+							maxSize += gpu.TotalMemory
+						}
+					}
+					if estimate.TotalSize > maxSize {
+						slog.Warn("model request too large for system", "requested", format.HumanBytes2(estimate.TotalSize), "system", format.HumanBytes2(maxSize))
+						pending.errCh <- fmt.Errorf("requested model (%s) is too large for this system (%s)", format.HumanBytes2(estimate.TotalSize), format.HumanBytes2(maxSize))
+						break
+					}
+
 					// Evaluate if the model will fit in the available system memory, or if we should unload a model first
 					if len(gpus) == 1 && gpus[0].Library == "cpu" {
 						// simplifying assumption of defaultParallel when in CPU mode

+ 12 - 0
server/sched_test.go

@@ -199,6 +199,8 @@ func TestRequests(t *testing.T) {
 		require.Equal(t, resp.llama, scenario1a.srv)
 		require.Empty(t, s.pendingReqCh)
 		require.Empty(t, scenario1a.req.errCh)
+	case err := <-scenario1a.req.errCh:
+		t.Fatal(err.Error())
 	case <-ctx.Done():
 		t.Fatal("timeout")
 	}
@@ -212,6 +214,8 @@ func TestRequests(t *testing.T) {
 		require.Equal(t, resp.llama, scenario1a.srv)
 		require.Empty(t, s.pendingReqCh)
 		require.Empty(t, scenario1b.req.errCh)
+	case err := <-scenario1b.req.errCh:
+		t.Fatal(err.Error())
 	case <-ctx.Done():
 		t.Fatal("timeout")
 	}
@@ -230,6 +234,8 @@ func TestRequests(t *testing.T) {
 		require.Equal(t, resp.llama, scenario2a.srv)
 		require.Empty(t, s.pendingReqCh)
 		require.Empty(t, scenario2a.req.errCh)
+	case err := <-scenario2a.req.errCh:
+		t.Fatal(err.Error())
 	case <-ctx.Done():
 		t.Fatal("timeout")
 	}
@@ -246,6 +252,8 @@ func TestRequests(t *testing.T) {
 		require.Equal(t, resp.llama, scenario3a.srv)
 		require.Empty(t, s.pendingReqCh)
 		require.Empty(t, scenario3a.req.errCh)
+	case err := <-scenario3a.req.errCh:
+		t.Fatal(err.Error())
 	case <-ctx.Done():
 		t.Fatal("timeout")
 	}
@@ -262,6 +270,8 @@ func TestRequests(t *testing.T) {
 		require.Equal(t, resp.llama, scenario3b.srv)
 		require.Empty(t, s.pendingReqCh)
 		require.Empty(t, scenario3b.req.errCh)
+	case err := <-scenario3b.req.errCh:
+		t.Fatal(err.Error())
 	case <-ctx.Done():
 		t.Fatal("timeout")
 	}
@@ -278,6 +288,8 @@ func TestRequests(t *testing.T) {
 		require.Equal(t, resp.llama, scenario3c.srv)
 		require.Empty(t, s.pendingReqCh)
 		require.Empty(t, scenario3c.req.errCh)
+	case err := <-scenario3c.req.errCh:
+		t.Fatal(err.Error())
 	case <-ctx.Done():
 		t.Fatal("timeout")
 	}