gpu.go

//go:build linux || windows

package gpu

/*
#cgo linux LDFLAGS: -lrt -lpthread -ldl -lstdc++ -lm
#cgo windows LDFLAGS: -lpthread
#include "gpu_info.h"
*/
import "C"
import (
	"fmt"
	"log"
	"sync"
	"unsafe"

	"github.com/jmorganca/ollama/api"
)
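
// handles holds the cgo handles for whichever GPU runtime was detected; a nil
// entry means that runtime is unavailable.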
type handles struct {
	cuda *C.cuda_handle_t
	rocm *C.rocm_handle_t
}

var gpuMutex sync.Mutex
var gpuHandles *handles = nil

// Note: gpuMutex must already be held
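// initGPUHandles probes for CUDA first and falls back to ROCm; failures are
// logged and the corresponding handle is left nil.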
func initGPUHandles() {
	// TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing
	log.Printf("Detecting GPU type")
	gpuHandles = &handles{nil, nil}
	var resp C.cuda_init_resp_t
	C.cuda_init(&resp)
	if resp.err != nil {
		log.Printf("CUDA not detected: %s", C.GoString(resp.err))
		C.free(unsafe.Pointer(resp.err))
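
		// CUDA was not usable, so try ROCm next. The inner resp below shadows the
		// CUDA response; from here on resp refers to the ROCm init result.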
		var resp C.rocm_init_resp_t
		C.rocm_init(&resp)
		if resp.err != nil {
			log.Printf("ROCm not detected: %s", C.GoString(resp.err))
			C.free(unsafe.Pointer(resp.err))
		} else {
			log.Printf("Radeon GPU detected")
			rocm := resp.rh
			gpuHandles.rocm = &rocm
		}
	} else {
		log.Printf("Nvidia GPU detected")
		cuda := resp.ch
		gpuHandles.cuda = &cuda
	}
}
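
// GetGPUInfo reports which driver is in use (CUDA, ROCM, or CPU) along with the
// free and total memory visible to it, falling back to system RAM when no GPU
// library can be loaded or queried.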
func GetGPUInfo() GpuInfo {
	// TODO - consider exploring lspci (and equivalent on windows) to check for
	// GPUs so we can report warnings if we see Nvidia/AMD but fail to load the libraries
	gpuMutex.Lock()
	defer gpuMutex.Unlock()
	if gpuHandles == nil {
		initGPUHandles()
	}

	var memInfo C.mem_info_t
	resp := GpuInfo{"", 0, 0}
	if gpuHandles.cuda != nil {
		C.cuda_check_vram(*gpuHandles.cuda, &memInfo)
		if memInfo.err != nil {
			log.Printf("error looking up CUDA GPU memory: %s", C.GoString(memInfo.err))
			C.free(unsafe.Pointer(memInfo.err))
		} else {
			resp.Driver = "CUDA"
		}
	} else if gpuHandles.rocm != nil {
		C.rocm_check_vram(*gpuHandles.rocm, &memInfo)
		if memInfo.err != nil {
			log.Printf("error looking up ROCm GPU memory: %s", C.GoString(memInfo.err))
			C.free(unsafe.Pointer(memInfo.err))
		} else {
			resp.Driver = "ROCM"
		}
	}
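
	// No GPU memory info was obtained (no handle, or the query failed), so fall
	// back to reporting system RAM.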
	if resp.Driver == "" {
		C.cpu_check_ram(&memInfo)
		resp.Driver = "CPU"
	}
	if memInfo.err != nil {
		log.Printf("error looking up CPU memory: %s", C.GoString(memInfo.err))
		C.free(unsafe.Pointer(memInfo.err))
		return resp
	}

	resp.FreeMemory = uint64(memInfo.free)
	resp.TotalMemory = uint64(memInfo.total)
	return resp
}
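
// CheckVRAM returns the free VRAM in bytes when a GPU driver is active, or an
// error when only the CPU path is available.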
func CheckVRAM() (int64, error) {
	gpuInfo := GetGPUInfo()
	if gpuInfo.FreeMemory > 0 && gpuInfo.Driver != "CPU" {
		return int64(gpuInfo.FreeMemory), nil
	}
	return 0, fmt.Errorf("no GPU detected") // TODO - better handling of CPU based memory determination
}
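
// NumGPU estimates how many of the model's layers can be offloaded to the GPU.
// An explicit opts.NumGPU (any value other than -1) overrides the heuristic.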
func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int {
	if opts.NumGPU != -1 {
		return opts.NumGPU
	}
	info := GetGPUInfo()
	if info.Driver == "CPU" {
		return 0
	}

	/*
		Calculate bytes per layer; this will roughly be the size of the model file divided by the number of layers.
		We can store the model weights and the kv cache in VRAM;
		to enable kv cache VRAM storage, add two additional layers to the number of layers retrieved from the model file.
	*/
	bytesPerLayer := uint64(fileSizeBytes / numLayer)

	// 75% of the absolute max number of layers we can fit in available VRAM: off-loading too many layers to the GPU can cause OOM errors
	layers := int(info.FreeMemory/bytesPerLayer) * 3 / 4
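	// Illustrative example (assumed numbers): a 4 GiB model with 32 layers gives
	// roughly 128 MiB per layer; with 6 GiB of free VRAM the absolute maximum is
	// 48 layers, and the 75% margin reduces that to 36, comfortably more than the
	// model's 32 layers.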
	log.Printf("%d MB VRAM available, loading up to %d %s GPU layers out of %d", info.FreeMemory/(1024*1024), layers, info.Driver, numLayer)
	return layers
}