Browse Source

Reload model if `num_gpu` changes (#3920)

* reload model if `num_gpu` changes

* dont reload on -1

* fix tests
Jeffrey Morgan 1 year ago
parent
commit
00b0699c75
2 changed files with 15 additions and 6 deletions
  1. 12 6
      server/sched.go
  2. 3 0
      server/sched_test.go

+ 12 - 6
server/sched.go

@@ -421,16 +421,21 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
 	slog.Debug("evaluating already loaded", "model", req.model.ModelPath)
 	runner.refMu.Lock()
 	defer runner.refMu.Unlock()
-	// Ignore the NumGPU settings for comparison
-	optsExisting := runner.Options.Runner
-	optsExisting.NumGPU = -1
-	optsNew := req.opts.Runner
-	optsNew.NumGPU = -1
+
 	timeout := 10 * time.Second
 	if runner.loading {
 		timeout = 2 * time.Minute // Initial load can take a long time for big models on slow systems...
 	}
-	ctx, cancel := context.WithTimeout(ctx, timeout) // BUG -
+
+	// Don't reload runner if num_gpu=-1 was provided
+	optsExisting := runner.Options.Runner
+	optsNew := req.opts.Runner
+	if optsNew.NumGPU < 0 {
+		optsExisting.NumGPU = -1
+		optsNew.NumGPU = -1
+	}
+
+	ctx, cancel := context.WithTimeout(ctx, timeout)
 	defer cancel()
 	if !reflect.DeepEqual(runner.adapters, req.model.AdapterPaths) || // have the adapters changed?
 		!reflect.DeepEqual(runner.projectors, req.model.ProjectorPaths) || // have the projectors changed?
@@ -438,6 +443,7 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
 		runner.llama.Ping(ctx) != nil {
 		return true
 	}
+
 	return false
 }
 

+ 3 - 0
server/sched_test.go

@@ -490,6 +490,9 @@ func TestNeedsReload(t *testing.T) {
 	require.False(t, resp)
 	req.opts.NumGPU = 99
 	resp = runner.needsReload(ctx, req)
+	require.True(t, resp)
+	req.opts.NumGPU = -1
+	resp = runner.needsReload(ctx, req)
 	require.False(t, resp)
 }