gpu.go

//go:build linux || windows
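
// Package gpu probes the host for CUDA or ROCm support and reports memory
// information used elsewhere in ollama to decide how many model layers can be
// offloaded to the GPU.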
package gpu

/*
#cgo linux LDFLAGS: -lrt -lpthread -ldl -lstdc++ -lm
#cgo windows LDFLAGS: -lpthread
#include "gpu_info.h"
*/
import "C"
import (
	"fmt"
	"log"
	"runtime"
	"sync"
	"unsafe"

	"github.com/jmorganca/ollama/api"
)

type handles struct {
	cuda *C.cuda_handle_t
	rocm *C.rocm_handle_t
}

var gpuMutex sync.Mutex
var gpuHandles *handles = nil

// Note: gpuMutex must already be held
func initGPUHandles() {
	// TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing
	log.Printf("Detecting GPU type")
	gpuHandles = &handles{nil, nil}
	var resp C.cuda_init_resp_t
	C.cuda_init(&resp)
	if resp.err != nil {
		log.Printf("CUDA not detected: %s", C.GoString(resp.err))
		C.free(unsafe.Pointer(resp.err))

		// CUDA is unavailable; probe for ROCm instead. The inner resp
		// intentionally shadows the CUDA response above.
		var resp C.rocm_init_resp_t
		C.rocm_init(&resp)
		if resp.err != nil {
			log.Printf("ROCm not detected: %s", C.GoString(resp.err))
			C.free(unsafe.Pointer(resp.err))
		} else {
			log.Printf("Radeon GPU detected")
			rocm := resp.rh
			gpuHandles.rocm = &rocm
		}
	} else {
		log.Printf("Nvidia GPU detected")
		cuda := resp.ch
		gpuHandles.cuda = &cuda
	}
}
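
// GetGPUInfo reports which acceleration library was detected ("cuda", "rocm",
// or a CPU fallback) together with the free and total memory of that device,
// or of system RAM when no usable GPU is found.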
func GetGPUInfo() GpuInfo {
	// TODO - consider exploring lspci (and equivalent on windows) to check for
	// GPUs so we can report warnings if we see Nvidia/AMD but fail to load the libraries
	gpuMutex.Lock()
	defer gpuMutex.Unlock()
	if gpuHandles == nil {
		initGPUHandles()
	}

	var memInfo C.mem_info_t
	resp := GpuInfo{"", 0, 0}
	if gpuHandles.cuda != nil {
		C.cuda_check_vram(*gpuHandles.cuda, &memInfo)
		if memInfo.err != nil {
			log.Printf("error looking up CUDA GPU memory: %s", C.GoString(memInfo.err))
			C.free(unsafe.Pointer(memInfo.err))
		} else {
			resp.Library = "cuda"
		}
	} else if gpuHandles.rocm != nil {
		C.rocm_check_vram(*gpuHandles.rocm, &memInfo)
		if memInfo.err != nil {
			log.Printf("error looking up ROCm GPU memory: %s", C.GoString(memInfo.err))
			C.free(unsafe.Pointer(memInfo.err))
		} else {
			resp.Library = "rocm"
		}
	}
	if resp.Library == "" {
		// No usable GPU was found; fall back to reporting system RAM.
		C.cpu_check_ram(&memInfo)
		// In the future we may offer multiple CPU variants to tune CPU features
		if runtime.GOOS == "windows" {
			resp.Library = "cpu"
		} else {
			resp.Library = "default"
		}
	}
	if memInfo.err != nil {
		log.Printf("error looking up CPU memory: %s", C.GoString(memInfo.err))
		C.free(unsafe.Pointer(memInfo.err))
		return resp
	}
	resp.FreeMemory = uint64(memInfo.free)
	resp.TotalMemory = uint64(memInfo.total)
	return resp
}
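
// CheckVRAM returns the free VRAM in bytes when a CUDA or ROCm GPU was
// detected, and an error otherwise.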
func CheckVRAM() (int64, error) {
	gpuInfo := GetGPUInfo()
	if gpuInfo.FreeMemory > 0 && (gpuInfo.Library == "cuda" || gpuInfo.Library == "rocm") {
		return int64(gpuInfo.FreeMemory), nil
	}
	return 0, fmt.Errorf("no GPU detected") // TODO - better handling of CPU based memory determination
}
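
// NumGPU estimates how many of the model's layers should be offloaded to the
// detected GPU. An explicit opts.NumGPU override is honored, and 0 is returned
// when only a CPU backend is available.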
func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int {
	if opts.NumGPU != -1 {
		return opts.NumGPU
	}
	info := GetGPUInfo()
	if info.Library == "cpu" || info.Library == "default" {
		return 0
	}

	/*
		Calculate bytes per layer; this will roughly be the size of the model file divided by the number of layers.
		We can store the model weights and the kv cache in VRAM;
		to enable kv cache VRAM storage, add two additional layers to the number of layers retrieved from the model file.
	*/
	bytesPerLayer := uint64(fileSizeBytes / numLayer)

	// 75% of the absolute max number of layers we can fit in available VRAM; off-loading too many layers to the GPU can cause OOM errors
	layers := int(info.FreeMemory/bytesPerLayer) * 3 / 4
	log.Printf("%d MB VRAM available, loading up to %d %s GPU layers out of %d", info.FreeMemory/(1024*1024), layers, info.Library, numLayer)
	return layers
}
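
// Worked example of the heuristic above (illustrative numbers only): a
// hypothetical 4 GiB model file with 32 layers gives bytesPerLayer =
// 4 GiB / 32 = 128 MiB; with 8 GiB of free VRAM, FreeMemory/bytesPerLayer = 64
// and layers = 64 * 3 / 4 = 48. Note the result is not clamped to numLayer
// here. A sketch of a caller, assuming api.DefaultOptions() leaves NumGPU at
// its -1 "auto" default:
//
//	opts := api.DefaultOptions()
//	layers := NumGPU(32, 4<<30, opts) // 32-layer, 4 GiB model
//	log.Printf("planning to offload up to %d layers", layers)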