gpu.go 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
  1. //go:build linux || windows
  2. package gpu
  3. /*
  4. #cgo linux LDFLAGS: -lrt -lpthread -ldl -lstdc++ -lm
  5. #cgo windows LDFLAGS: -lpthread
  6. #include "gpu_info.h"
  7. */
  8. import "C"
  9. import (
  10. "fmt"
  11. "log/slog"
  12. "os"
  13. "path/filepath"
  14. "runtime"
  15. "strings"
  16. "sync"
  17. "unsafe"
  18. )
  19. type handles struct {
  20. cuda *C.cuda_handle_t
  21. rocm *C.rocm_handle_t
  22. }
  23. var gpuMutex sync.Mutex
  24. var gpuHandles *handles = nil
  25. // With our current CUDA compile flags, 5.2 and older will not work properly
  26. const CudaComputeMajorMin = 6
  27. // Possible locations for the nvidia-ml library
  28. var CudaLinuxGlobs = []string{
  29. "/usr/local/cuda/lib64/libnvidia-ml.so*",
  30. "/usr/lib/x86_64-linux-gnu/nvidia/current/libnvidia-ml.so*",
  31. "/usr/lib/x86_64-linux-gnu/libnvidia-ml.so*",
  32. "/usr/lib/wsl/lib/libnvidia-ml.so*",
  33. "/usr/lib/wsl/drivers/*/libnvidia-ml.so*",
  34. "/opt/cuda/lib64/libnvidia-ml.so*",
  35. "/usr/lib*/libnvidia-ml.so*",
  36. "/usr/local/lib*/libnvidia-ml.so*",
  37. "/usr/lib/aarch64-linux-gnu/nvidia/current/libnvidia-ml.so*",
  38. "/usr/lib/aarch64-linux-gnu/libnvidia-ml.so*",
  39. // TODO: are these stubs ever valid?
  40. "/opt/cuda/targets/x86_64-linux/lib/stubs/libnvidia-ml.so*",
  41. }
  42. var CudaWindowsGlobs = []string{
  43. "c:\\Windows\\System32\\nvml.dll",
  44. }
  45. var RocmLinuxGlobs = []string{
  46. "/opt/rocm*/lib*/librocm_smi64.so*",
  47. }
  48. var RocmWindowsGlobs = []string{
  49. "c:\\Windows\\System32\\rocm_smi64.dll",
  50. }
  51. // Note: gpuMutex must already be held
  52. func initGPUHandles() {
  53. // TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing
  54. gpuHandles = &handles{nil, nil}
  55. var cudaMgmtName string
  56. var cudaMgmtPatterns []string
  57. var rocmMgmtName string
  58. var rocmMgmtPatterns []string
  59. switch runtime.GOOS {
  60. case "windows":
  61. cudaMgmtName = "nvml.dll"
  62. cudaMgmtPatterns = make([]string, len(CudaWindowsGlobs))
  63. copy(cudaMgmtPatterns, CudaWindowsGlobs)
  64. rocmMgmtName = "rocm_smi64.dll"
  65. rocmMgmtPatterns = make([]string, len(RocmWindowsGlobs))
  66. copy(rocmMgmtPatterns, RocmWindowsGlobs)
  67. case "linux":
  68. cudaMgmtName = "libnvidia-ml.so"
  69. cudaMgmtPatterns = make([]string, len(CudaLinuxGlobs))
  70. copy(cudaMgmtPatterns, CudaLinuxGlobs)
  71. rocmMgmtName = "librocm_smi64.so"
  72. rocmMgmtPatterns = make([]string, len(RocmLinuxGlobs))
  73. copy(rocmMgmtPatterns, RocmLinuxGlobs)
  74. default:
  75. return
  76. }
  77. slog.Info("Detecting GPU type")
  78. cudaLibPaths := FindGPULibs(cudaMgmtName, cudaMgmtPatterns)
  79. if len(cudaLibPaths) > 0 {
  80. cuda := LoadCUDAMgmt(cudaLibPaths)
  81. if cuda != nil {
  82. slog.Info("Nvidia GPU detected")
  83. gpuHandles.cuda = cuda
  84. return
  85. }
  86. }
  87. rocmLibPaths := FindGPULibs(rocmMgmtName, rocmMgmtPatterns)
  88. if len(rocmLibPaths) > 0 {
  89. rocm := LoadROCMMgmt(rocmLibPaths)
  90. if rocm != nil {
  91. slog.Info("Radeon GPU detected")
  92. gpuHandles.rocm = rocm
  93. return
  94. }
  95. }
  96. }
  97. func GetGPUInfo() GpuInfo {
  98. // TODO - consider exploring lspci (and equivalent on windows) to check for
  99. // GPUs so we can report warnings if we see Nvidia/AMD but fail to load the libraries
  100. gpuMutex.Lock()
  101. defer gpuMutex.Unlock()
  102. if gpuHandles == nil {
  103. initGPUHandles()
  104. }
  105. var memInfo C.mem_info_t
  106. resp := GpuInfo{}
  107. if gpuHandles.cuda != nil {
  108. C.cuda_check_vram(*gpuHandles.cuda, &memInfo)
  109. if memInfo.err != nil {
  110. slog.Info(fmt.Sprintf("error looking up CUDA GPU memory: %s", C.GoString(memInfo.err)))
  111. C.free(unsafe.Pointer(memInfo.err))
  112. } else {
  113. // Verify minimum compute capability
  114. var cc C.cuda_compute_capability_t
  115. C.cuda_compute_capability(*gpuHandles.cuda, &cc)
  116. if cc.err != nil {
  117. slog.Info(fmt.Sprintf("error looking up CUDA GPU compute capability: %s", C.GoString(cc.err)))
  118. C.free(unsafe.Pointer(cc.err))
  119. } else if cc.major >= CudaComputeMajorMin {
  120. slog.Info(fmt.Sprintf("CUDA Compute Capability detected: %d.%d", cc.major, cc.minor))
  121. resp.Library = "cuda"
  122. } else {
  123. slog.Info(fmt.Sprintf("CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor))
  124. }
  125. }
  126. } else if gpuHandles.rocm != nil {
  127. C.rocm_check_vram(*gpuHandles.rocm, &memInfo)
  128. if memInfo.err != nil {
  129. slog.Info(fmt.Sprintf("error looking up ROCm GPU memory: %s", C.GoString(memInfo.err)))
  130. C.free(unsafe.Pointer(memInfo.err))
  131. } else {
  132. resp.Library = "rocm"
  133. var version C.rocm_version_resp_t
  134. C.rocm_get_version(*gpuHandles.rocm, &version)
  135. verString := C.GoString(version.str)
  136. if version.status == 0 {
  137. resp.Variant = "v" + verString
  138. } else {
  139. slog.Info(fmt.Sprintf("failed to look up ROCm version: %s", verString))
  140. }
  141. C.free(unsafe.Pointer(version.str))
  142. }
  143. }
  144. if resp.Library == "" {
  145. C.cpu_check_ram(&memInfo)
  146. resp.Library = "cpu"
  147. resp.Variant = GetCPUVariant()
  148. }
  149. if memInfo.err != nil {
  150. slog.Info(fmt.Sprintf("error looking up CPU memory: %s", C.GoString(memInfo.err)))
  151. C.free(unsafe.Pointer(memInfo.err))
  152. return resp
  153. }
  154. resp.DeviceCount = uint32(memInfo.count)
  155. resp.FreeMemory = uint64(memInfo.free)
  156. resp.TotalMemory = uint64(memInfo.total)
  157. return resp
  158. }
  159. func getCPUMem() (memInfo, error) {
  160. var ret memInfo
  161. var info C.mem_info_t
  162. C.cpu_check_ram(&info)
  163. if info.err != nil {
  164. defer C.free(unsafe.Pointer(info.err))
  165. return ret, fmt.Errorf(C.GoString(info.err))
  166. }
  167. ret.FreeMemory = uint64(info.free)
  168. ret.TotalMemory = uint64(info.total)
  169. return ret, nil
  170. }
  171. func CheckVRAM() (int64, error) {
  172. gpuInfo := GetGPUInfo()
  173. if gpuInfo.FreeMemory > 0 && (gpuInfo.Library == "cuda" || gpuInfo.Library == "rocm") {
  174. // leave 10% or 1024MiB of VRAM free per GPU to handle unaccounted for overhead
  175. overhead := gpuInfo.FreeMemory / 10
  176. gpus := uint64(gpuInfo.DeviceCount)
  177. if overhead < gpus*1024*1024*1024 {
  178. overhead = gpus * 1024 * 1024 * 1024
  179. }
  180. return int64(gpuInfo.FreeMemory - overhead), nil
  181. }
  182. return 0, fmt.Errorf("no GPU detected") // TODO - better handling of CPU based memory determiniation
  183. }
  184. func FindGPULibs(baseLibName string, patterns []string) []string {
  185. // Multiple GPU libraries may exist, and some may not work, so keep trying until we exhaust them
  186. var ldPaths []string
  187. gpuLibPaths := []string{}
  188. slog.Info(fmt.Sprintf("Searching for GPU management library %s", baseLibName))
  189. switch runtime.GOOS {
  190. case "windows":
  191. ldPaths = strings.Split(os.Getenv("PATH"), ";")
  192. case "linux":
  193. ldPaths = strings.Split(os.Getenv("LD_LIBRARY_PATH"), ":")
  194. default:
  195. return gpuLibPaths
  196. }
  197. // Start with whatever we find in the PATH/LD_LIBRARY_PATH
  198. for _, ldPath := range ldPaths {
  199. d, err := filepath.Abs(ldPath)
  200. if err != nil {
  201. continue
  202. }
  203. patterns = append(patterns, filepath.Join(d, baseLibName+"*"))
  204. }
  205. slog.Debug(fmt.Sprintf("gpu management search paths: %v", patterns))
  206. for _, pattern := range patterns {
  207. // Ignore glob discovery errors
  208. matches, _ := filepath.Glob(pattern)
  209. for _, match := range matches {
  210. // Resolve any links so we don't try the same lib multiple times
  211. // and weed out any dups across globs
  212. libPath := match
  213. tmp := match
  214. var err error
  215. for ; err == nil; tmp, err = os.Readlink(libPath) {
  216. if !filepath.IsAbs(tmp) {
  217. tmp = filepath.Join(filepath.Dir(libPath), tmp)
  218. }
  219. libPath = tmp
  220. }
  221. new := true
  222. for _, cmp := range gpuLibPaths {
  223. if cmp == libPath {
  224. new = false
  225. break
  226. }
  227. }
  228. if new {
  229. gpuLibPaths = append(gpuLibPaths, libPath)
  230. }
  231. }
  232. }
  233. slog.Info(fmt.Sprintf("Discovered GPU libraries: %v", gpuLibPaths))
  234. return gpuLibPaths
  235. }
  236. func LoadCUDAMgmt(cudaLibPaths []string) *C.cuda_handle_t {
  237. var resp C.cuda_init_resp_t
  238. resp.ch.verbose = getVerboseState()
  239. for _, libPath := range cudaLibPaths {
  240. lib := C.CString(libPath)
  241. defer C.free(unsafe.Pointer(lib))
  242. C.cuda_init(lib, &resp)
  243. if resp.err != nil {
  244. slog.Info(fmt.Sprintf("Unable to load CUDA management library %s: %s", libPath, C.GoString(resp.err)))
  245. C.free(unsafe.Pointer(resp.err))
  246. } else {
  247. return &resp.ch
  248. }
  249. }
  250. return nil
  251. }
  252. func LoadROCMMgmt(rocmLibPaths []string) *C.rocm_handle_t {
  253. var resp C.rocm_init_resp_t
  254. resp.rh.verbose = getVerboseState()
  255. for _, libPath := range rocmLibPaths {
  256. lib := C.CString(libPath)
  257. defer C.free(unsafe.Pointer(lib))
  258. C.rocm_init(lib, &resp)
  259. if resp.err != nil {
  260. slog.Info(fmt.Sprintf("Unable to load ROCm management library %s: %s", libPath, C.GoString(resp.err)))
  261. C.free(unsafe.Pointer(resp.err))
  262. } else {
  263. return &resp.rh
  264. }
  265. }
  266. return nil
  267. }
  268. func getVerboseState() C.uint16_t {
  269. if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
  270. return C.uint16_t(1)
  271. }
  272. return C.uint16_t(0)
  273. }