amd_hip_windows.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. package gpu
  2. import (
  3. "errors"
  4. "fmt"
  5. "log/slog"
  6. "syscall"
  7. "unsafe"
  8. "golang.org/x/sys/windows"
  9. )
  // HIP runtime status codes (hipError_t values) that this file checks for.
const (
	// hipSuccess is the status returned by a HIP call that completed normally.
	hipSuccess = 0

	// hipErrorNoDevice is the status hipGetDeviceCount reports when it finds
	// no usable device; it is treated as "zero devices", not as a failure.
	hipErrorNoDevice = 100
)
  // hipDevicePropMinimal is a partial mirror of HIP's hipDeviceProp_t and is
// used as the out-parameter of hipGetDeviceProperties. Only Name, GcnArchName
// and iGPU are read from Go; the unusedN arrays are opaque padding that keep
// those fields at the offsets the library writes them to.
//
// Assumed layout (legacy hipDeviceProp_t as filled in by the unversioned
// hipGetDeviceProperties export — re-verify against hip_runtime_api.h when
// changing the DLL version):
//
//	offset   0  char name[256]
//	offset 396  char gcnArchName[256]
//	offset 652  int  integrated
//	size   792  the library writes the whole struct, so this must not shrink
type hipDevicePropMinimal struct {
	Name        [256]byte
	unused1     [140]byte
	GcnArchName [256]byte // gfx####

	// iGPU mirrors the C `int integrated` field, so it must be exactly 32 bits.
	// A Go int is 8 bytes and 8-byte aligned on 64-bit targets, which would
	// place it at offset 656 and make it read the two C fields that follow
	// `integrated` instead.
	iGPU int32

	// unused2 pads the struct to the full 792 bytes the library writes.
	unused2 [136]byte
}
  21. // Wrap the amdhip64.dll library for GPU discovery
  22. type HipLib struct {
  23. dll windows.Handle
  24. hipGetDeviceCount uintptr
  25. hipGetDeviceProperties uintptr
  26. hipMemGetInfo uintptr
  27. hipSetDevice uintptr
  28. hipDriverGetVersion uintptr
  29. }
  30. func NewHipLib() (*HipLib, error) {
  31. // At runtime we depend on v6, so discover GPUs with the same library for a consistent set of GPUs
  32. h, err := windows.LoadLibrary("amdhip64_6.dll")
  33. if err != nil {
  34. return nil, fmt.Errorf("unable to load amdhip64_6.dll, please make sure to upgrade to the latest amd driver: %w", err)
  35. }
  36. hl := &HipLib{}
  37. hl.dll = h
  38. hl.hipGetDeviceCount, err = windows.GetProcAddress(hl.dll, "hipGetDeviceCount")
  39. if err != nil {
  40. return nil, err
  41. }
  42. hl.hipGetDeviceProperties, err = windows.GetProcAddress(hl.dll, "hipGetDeviceProperties")
  43. if err != nil {
  44. return nil, err
  45. }
  46. hl.hipMemGetInfo, err = windows.GetProcAddress(hl.dll, "hipMemGetInfo")
  47. if err != nil {
  48. return nil, err
  49. }
  50. hl.hipSetDevice, err = windows.GetProcAddress(hl.dll, "hipSetDevice")
  51. if err != nil {
  52. return nil, err
  53. }
  54. hl.hipDriverGetVersion, err = windows.GetProcAddress(hl.dll, "hipDriverGetVersion")
  55. if err != nil {
  56. return nil, err
  57. }
  58. return hl, nil
  59. }
  60. // The hip library only evaluates the HIP_VISIBLE_DEVICES variable at startup
  61. // so we have to unload/reset the library after we do our initial discovery
  62. // to make sure our updates to that variable are processed by llama.cpp
  63. func (hl *HipLib) Release() {
  64. err := windows.FreeLibrary(hl.dll)
  65. if err != nil {
  66. slog.Warn("failed to unload amdhip64.dll", "error", err)
  67. }
  68. hl.dll = 0
  69. }
  70. func (hl *HipLib) AMDDriverVersion() (driverMajor, driverMinor int, err error) {
  71. if hl.dll == 0 {
  72. return 0, 0, errors.New("dll has been unloaded")
  73. }
  74. var version int
  75. status, _, err := syscall.SyscallN(hl.hipDriverGetVersion, uintptr(unsafe.Pointer(&version)))
  76. if status != hipSuccess {
  77. return 0, 0, fmt.Errorf("failed call to hipDriverGetVersion: %d %s", status, err)
  78. }
  79. slog.Debug("hipDriverGetVersion", "version", version)
  80. driverMajor = version / 10000000
  81. driverMinor = (version - (driverMajor * 10000000)) / 100000
  82. return driverMajor, driverMinor, nil
  83. }
  84. func (hl *HipLib) HipGetDeviceCount() int {
  85. if hl.dll == 0 {
  86. slog.Error("dll has been unloaded")
  87. return 0
  88. }
  89. var count int
  90. status, _, err := syscall.SyscallN(hl.hipGetDeviceCount, uintptr(unsafe.Pointer(&count)))
  91. if status == hipErrorNoDevice {
  92. slog.Info("AMD ROCm reports no devices found")
  93. return 0
  94. }
  95. if status != hipSuccess {
  96. slog.Warn("failed call to hipGetDeviceCount", "status", status, "error", err)
  97. }
  98. return count
  99. }
  100. func (hl *HipLib) HipSetDevice(device int) error {
  101. if hl.dll == 0 {
  102. return errors.New("dll has been unloaded")
  103. }
  104. status, _, err := syscall.SyscallN(hl.hipSetDevice, uintptr(device))
  105. if status != hipSuccess {
  106. return fmt.Errorf("failed call to hipSetDevice: %d %s", status, err)
  107. }
  108. return nil
  109. }
  110. func (hl *HipLib) HipGetDeviceProperties(device int) (*hipDevicePropMinimal, error) {
  111. if hl.dll == 0 {
  112. return nil, errors.New("dll has been unloaded")
  113. }
  114. var props hipDevicePropMinimal
  115. status, _, err := syscall.SyscallN(hl.hipGetDeviceProperties, uintptr(unsafe.Pointer(&props)), uintptr(device))
  116. if status != hipSuccess {
  117. return nil, fmt.Errorf("failed call to hipGetDeviceProperties: %d %s", status, err)
  118. }
  119. return &props, nil
  120. }
  121. // free, total, err
  122. func (hl *HipLib) HipMemGetInfo() (uint64, uint64, error) {
  123. if hl.dll == 0 {
  124. return 0, 0, errors.New("dll has been unloaded")
  125. }
  126. var totalMemory uint64
  127. var freeMemory uint64
  128. status, _, err := syscall.SyscallN(hl.hipMemGetInfo, uintptr(unsafe.Pointer(&freeMemory)), uintptr(unsafe.Pointer(&totalMemory)))
  129. if status != hipSuccess {
  130. return 0, 0, fmt.Errorf("failed call to hipMemGetInfo: %d %s", status, err)
  131. }
  132. return freeMemory, totalMemory, nil
  133. }