gpu.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. //go:build linux || windows
  2. package gpu
  3. /*
  4. #cgo linux LDFLAGS: -lrt -lpthread -ldl -lstdc++ -lm
  5. #cgo windows LDFLAGS: -lpthread
  6. #include "gpu_info.h"
  7. */
  8. import "C"
  9. import (
  10. "fmt"
  11. "log"
  12. "runtime"
  13. "sync"
  14. "unsafe"
  15. "github.com/jmorganca/ollama/api"
  16. )
// handles caches the lazily-initialized GPU runtime library handles.
// After initGPUHandles runs, at most one field is non-nil: CUDA is
// probed first, ROCm only as a fallback (see initGPUHandles).
type handles struct {
	cuda *C.cuda_handle_t // non-nil when the CUDA management library loaded
	rocm *C.rocm_handle_t // non-nil when the ROCm management library loaded
}
// gpuMutex serializes all access to gpuHandles across goroutines.
var gpuMutex sync.Mutex

// gpuHandles is nil until the first GetGPUInfo call populates it via
// initGPUHandles (done once, under gpuMutex).
var gpuHandles *handles = nil

// CudaComputeMajorMin is the minimum CUDA compute-capability major
// version accepted before falling back to CPU inference.
// TODO verify this is the correct min version
const CudaComputeMajorMin = 5
  25. // Note: gpuMutex must already be held
  26. func initGPUHandles() {
  27. // TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing
  28. log.Printf("Detecting GPU type")
  29. gpuHandles = &handles{nil, nil}
  30. var resp C.cuda_init_resp_t
  31. C.cuda_init(&resp)
  32. if resp.err != nil {
  33. log.Printf("CUDA not detected: %s", C.GoString(resp.err))
  34. C.free(unsafe.Pointer(resp.err))
  35. var resp C.rocm_init_resp_t
  36. C.rocm_init(&resp)
  37. if resp.err != nil {
  38. log.Printf("ROCm not detected: %s", C.GoString(resp.err))
  39. C.free(unsafe.Pointer(resp.err))
  40. } else {
  41. log.Printf("Radeon GPU detected")
  42. rocm := resp.rh
  43. gpuHandles.rocm = &rocm
  44. }
  45. } else {
  46. log.Printf("Nvidia GPU detected")
  47. cuda := resp.ch
  48. gpuHandles.cuda = &cuda
  49. }
  50. }
  51. func GetGPUInfo() GpuInfo {
  52. // TODO - consider exploring lspci (and equivalent on windows) to check for
  53. // GPUs so we can report warnings if we see Nvidia/AMD but fail to load the libraries
  54. gpuMutex.Lock()
  55. defer gpuMutex.Unlock()
  56. if gpuHandles == nil {
  57. initGPUHandles()
  58. }
  59. var memInfo C.mem_info_t
  60. resp := GpuInfo{}
  61. if gpuHandles.cuda != nil {
  62. C.cuda_check_vram(*gpuHandles.cuda, &memInfo)
  63. if memInfo.err != nil {
  64. log.Printf("error looking up CUDA GPU memory: %s", C.GoString(memInfo.err))
  65. C.free(unsafe.Pointer(memInfo.err))
  66. } else {
  67. // Verify minimum compute capability
  68. var cc C.cuda_compute_capability_t
  69. C.cuda_compute_capability(*gpuHandles.cuda, &cc)
  70. if cc.err != nil {
  71. log.Printf("error looking up CUDA GPU compute capability: %s", C.GoString(cc.err))
  72. C.free(unsafe.Pointer(cc.err))
  73. } else if cc.major >= CudaComputeMajorMin {
  74. log.Printf("CUDA Compute Capability detected: %d.%d", cc.major, cc.minor)
  75. resp.Library = "cuda"
  76. } else {
  77. log.Printf("CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor)
  78. }
  79. }
  80. } else if gpuHandles.rocm != nil {
  81. C.rocm_check_vram(*gpuHandles.rocm, &memInfo)
  82. if memInfo.err != nil {
  83. log.Printf("error looking up ROCm GPU memory: %s", C.GoString(memInfo.err))
  84. C.free(unsafe.Pointer(memInfo.err))
  85. } else {
  86. resp.Library = "rocm"
  87. }
  88. }
  89. if resp.Library == "" {
  90. C.cpu_check_ram(&memInfo)
  91. // In the future we may offer multiple CPU variants to tune CPU features
  92. if runtime.GOOS == "windows" {
  93. resp.Library = "cpu"
  94. } else {
  95. resp.Library = "default"
  96. }
  97. }
  98. if memInfo.err != nil {
  99. log.Printf("error looking up CPU memory: %s", C.GoString(memInfo.err))
  100. C.free(unsafe.Pointer(memInfo.err))
  101. return resp
  102. }
  103. resp.FreeMemory = uint64(memInfo.free)
  104. resp.TotalMemory = uint64(memInfo.total)
  105. return resp
  106. }
  107. func getCPUMem() (memInfo, error) {
  108. var ret memInfo
  109. var info C.mem_info_t
  110. C.cpu_check_ram(&info)
  111. if info.err != nil {
  112. defer C.free(unsafe.Pointer(info.err))
  113. return ret, fmt.Errorf(C.GoString(info.err))
  114. }
  115. ret.FreeMemory = uint64(info.free)
  116. ret.TotalMemory = uint64(info.total)
  117. return ret, nil
  118. }
  119. func CheckVRAM() (int64, error) {
  120. gpuInfo := GetGPUInfo()
  121. if gpuInfo.FreeMemory > 0 && (gpuInfo.Library == "cuda" || gpuInfo.Library == "rocm") {
  122. return int64(gpuInfo.FreeMemory), nil
  123. }
  124. return 0, fmt.Errorf("no GPU detected") // TODO - better handling of CPU based memory determiniation
  125. }
  126. func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int {
  127. if opts.NumGPU != -1 {
  128. return opts.NumGPU
  129. }
  130. info := GetGPUInfo()
  131. if info.Library == "cpu" || info.Library == "default" {
  132. return 0
  133. }
  134. /*
  135. Calculate bytes per layer, this will roughly be the size of the model file divided by the number of layers.
  136. We can store the model weights and the kv cache in vram,
  137. to enable kv chache vram storage add two additional layers to the number of layers retrieved from the model file.
  138. */
  139. bytesPerLayer := uint64(fileSizeBytes / numLayer)
  140. // 75% of the absolute max number of layers we can fit in available VRAM, off-loading too many layers to the GPU can cause OOM errors
  141. layers := int(info.FreeMemory/bytesPerLayer) * 3 / 4
  142. log.Printf("%d MB VRAM available, loading up to %d %s GPU layers out of %d", info.FreeMemory/(1024*1024), layers, info.Library, numLayer)
  143. return layers
  144. }