Pārlūkot izejas kodu

improve vram safety with 5% vram memory buffer (#724)

* check free memory not total
* wait for subprocess to exit
Bruce MacDonald 1 gadu atpakaļ
vecāks
revīzija
f2ba1311aa
1 mainītis faili ar 13 papildinājumiem un 7 dzēšanām
  1. 13 7
      llm/llama.go

+ 13 - 7
llm/llama.go

@@ -191,7 +191,7 @@ var errNoGPU = errors.New("nvidia-smi command failed")
 
 // CheckVRAM returns the available VRAM in MiB on Linux machines with NVIDIA GPUs
 func CheckVRAM() (int64, error) {
-	cmd := exec.Command("nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader,nounits")
+	cmd := exec.Command("nvidia-smi", "--query-gpu=memory.free", "--format=csv,noheader,nounits")
 	var stdout bytes.Buffer
 	cmd.Stdout = &stdout
 	err := cmd.Run()
@@ -199,7 +199,7 @@ func CheckVRAM() (int64, error) {
 		return 0, errNoGPU
 	}
 
-	var total int64
+	var free int64
 	scanner := bufio.NewScanner(&stdout)
 	for scanner.Scan() {
 		line := scanner.Text()
@@ -208,10 +208,10 @@ func CheckVRAM() (int64, error) {
 			return 0, fmt.Errorf("failed to parse available VRAM: %v", err)
 		}
 
-		total += vram
+		free += vram
 	}
 
-	return total, nil
+	return free, nil
 }
 
 func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int {
@@ -228,14 +228,14 @@ func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int {
 			return 0
 		}
 
-		totalVramBytes := int64(vramMib) * 1024 * 1024 // 1 MiB = 1024^2 bytes
+		freeVramBytes := int64(vramMib) * 1024 * 1024 // 1 MiB = 1024^2 bytes
 
 		// Calculate bytes per layer
 		// TODO: this is a rough heuristic, better would be to calculate this based on number of layers and context size
 		bytesPerLayer := fileSizeBytes / numLayer
 
-		// max number of layers we can fit in VRAM
-		layers := int(totalVramBytes / bytesPerLayer)
+		// max number of layers we can fit in VRAM, subtract 5% to prevent consuming all available VRAM and running out of memory
+		layers := int(freeVramBytes/bytesPerLayer) * 95 / 100
 		log.Printf("%d MiB VRAM available, loading up to %d GPU layers", vramMib, layers)
 
 		return layers
@@ -367,7 +367,13 @@ func waitForServer(llm *llama) error {
 }
 
 func (llm *llama) Close() {
+	// signal the sub-process to terminate
 	llm.Cancel()
+
+	// wait for the command to exit to prevent race conditions with the next run
+	if err := llm.Cmd.Wait(); err != nil {
+		log.Printf("llama runner exited: %v", err)
+	}
 }
 
 func (llm *llama) SetOptions(opts api.Options) {