Browse Source

check total (system + video) memory

Michael Yang 1 year ago
parent
commit
4a8931f634
1 changed files with 14 additions and 5 deletions
  1. 14 5
      llm/llm.go

+ 14 - 5
llm/llm.go

@@ -60,7 +60,7 @@ func New(workDir, model string, adapters []string, opts api.Options) (LLM, error
 
 	var requiredMemory int64
 	var f16Multiplier int64 = 2
-	totalResidentMemory := int64(memory.TotalMemory())
+
 	switch ggml.ModelType() {
 	case "3B", "7B":
 		requiredMemory = 8 * format.GigaByte
@@ -75,10 +75,19 @@ func New(workDir, model string, adapters []string, opts api.Options) (LLM, error
 		f16Multiplier = 4
 	}
 
-	if ggml.FileType() == "F16" && requiredMemory*f16Multiplier > totalResidentMemory {
-		return nil, fmt.Errorf("F16 model requires at least %s of memory", format.HumanBytes(requiredMemory))
-	} else if requiredMemory > totalResidentMemory {
-		return nil, fmt.Errorf("model requires at least %s of memory", format.HumanBytes(requiredMemory))
+	systemMemory := int64(memory.TotalMemory())
+
+	videoMemory, err := CheckVRAM()
+	if err != nil{
+		videoMemory = 0
+	}
+
+	totalMemory := systemMemory + videoMemory
+
+	if ggml.FileType() == "F16" && requiredMemory*f16Multiplier > totalMemory {
+		return nil, fmt.Errorf("F16 model requires at least %s of total memory", format.HumanBytes(requiredMemory))
+	} else if requiredMemory > totalMemory {
+		return nil, fmt.Errorf("model requires at least %s of total memory", format.HumanBytes(requiredMemory))
 	}
 
 	switch ggml.Name() {