gpu.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. //go:build linux || windows
  2. package gpu
  3. /*
  4. #cgo linux LDFLAGS: -lrt -lpthread -ldl -lstdc++ -lm
  5. #cgo windows LDFLAGS: -lpthread
  6. #include "gpu_info.h"
  7. */
  8. import "C"
  9. import (
  10. "fmt"
  11. "log"
  12. "runtime"
  13. "sync"
  14. "unsafe"
  15. )
  16. type handles struct {
  17. cuda *C.cuda_handle_t
  18. rocm *C.rocm_handle_t
  19. }
  20. var gpuMutex sync.Mutex
  21. var gpuHandles *handles = nil
  22. // TODO verify this is the correct min version
  23. const CudaComputeMajorMin = 5
  24. // Note: gpuMutex must already be held
  25. func initGPUHandles() {
  26. // TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing
  27. log.Printf("Detecting GPU type")
  28. gpuHandles = &handles{nil, nil}
  29. var resp C.cuda_init_resp_t
  30. C.cuda_init(&resp)
  31. if resp.err != nil {
  32. log.Printf("CUDA not detected: %s", C.GoString(resp.err))
  33. C.free(unsafe.Pointer(resp.err))
  34. var resp C.rocm_init_resp_t
  35. C.rocm_init(&resp)
  36. if resp.err != nil {
  37. log.Printf("ROCm not detected: %s", C.GoString(resp.err))
  38. C.free(unsafe.Pointer(resp.err))
  39. } else {
  40. log.Printf("Radeon GPU detected")
  41. rocm := resp.rh
  42. gpuHandles.rocm = &rocm
  43. }
  44. } else {
  45. log.Printf("Nvidia GPU detected")
  46. cuda := resp.ch
  47. gpuHandles.cuda = &cuda
  48. }
  49. }
  50. func GetGPUInfo() GpuInfo {
  51. // TODO - consider exploring lspci (and equivalent on windows) to check for
  52. // GPUs so we can report warnings if we see Nvidia/AMD but fail to load the libraries
  53. gpuMutex.Lock()
  54. defer gpuMutex.Unlock()
  55. if gpuHandles == nil {
  56. initGPUHandles()
  57. }
  58. var memInfo C.mem_info_t
  59. resp := GpuInfo{}
  60. if gpuHandles.cuda != nil {
  61. C.cuda_check_vram(*gpuHandles.cuda, &memInfo)
  62. if memInfo.err != nil {
  63. log.Printf("error looking up CUDA GPU memory: %s", C.GoString(memInfo.err))
  64. C.free(unsafe.Pointer(memInfo.err))
  65. } else {
  66. // Verify minimum compute capability
  67. var cc C.cuda_compute_capability_t
  68. C.cuda_compute_capability(*gpuHandles.cuda, &cc)
  69. if cc.err != nil {
  70. log.Printf("error looking up CUDA GPU compute capability: %s", C.GoString(cc.err))
  71. C.free(unsafe.Pointer(cc.err))
  72. } else if cc.major >= CudaComputeMajorMin {
  73. log.Printf("CUDA Compute Capability detected: %d.%d", cc.major, cc.minor)
  74. resp.Library = "cuda"
  75. } else {
  76. log.Printf("CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor)
  77. }
  78. }
  79. } else if gpuHandles.rocm != nil {
  80. C.rocm_check_vram(*gpuHandles.rocm, &memInfo)
  81. if memInfo.err != nil {
  82. log.Printf("error looking up ROCm GPU memory: %s", C.GoString(memInfo.err))
  83. C.free(unsafe.Pointer(memInfo.err))
  84. } else {
  85. resp.Library = "rocm"
  86. }
  87. }
  88. if resp.Library == "" {
  89. C.cpu_check_ram(&memInfo)
  90. // In the future we may offer multiple CPU variants to tune CPU features
  91. if runtime.GOOS == "windows" {
  92. resp.Library = "cpu"
  93. } else {
  94. resp.Library = "default"
  95. }
  96. }
  97. if memInfo.err != nil {
  98. log.Printf("error looking up CPU memory: %s", C.GoString(memInfo.err))
  99. C.free(unsafe.Pointer(memInfo.err))
  100. return resp
  101. }
  102. resp.FreeMemory = uint64(memInfo.free)
  103. resp.TotalMemory = uint64(memInfo.total)
  104. return resp
  105. }
  106. func getCPUMem() (memInfo, error) {
  107. var ret memInfo
  108. var info C.mem_info_t
  109. C.cpu_check_ram(&info)
  110. if info.err != nil {
  111. defer C.free(unsafe.Pointer(info.err))
  112. return ret, fmt.Errorf(C.GoString(info.err))
  113. }
  114. ret.FreeMemory = uint64(info.free)
  115. ret.TotalMemory = uint64(info.total)
  116. return ret, nil
  117. }
  118. func CheckVRAM() (int64, error) {
  119. gpuInfo := GetGPUInfo()
  120. if gpuInfo.FreeMemory > 0 && (gpuInfo.Library == "cuda" || gpuInfo.Library == "rocm") {
  121. // leave 10% or 400MiB of VRAM free for overhead
  122. overhead := gpuInfo.FreeMemory / 10
  123. minOverhead := 400 * 1024 * 1024
  124. if overhead < minOverhead {
  125. overhead = minOverhead
  126. }
  127. return int64(gpuInfo.FreeMemory - overhead), nil
  128. }
  129. return 0, fmt.Errorf("no GPU detected") // TODO - better handling of CPU based memory determiniation
  130. }