amd_hip_windows.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. package gpu
  2. import (
  3. "fmt"
  4. "log/slog"
  5. "syscall"
  6. "unsafe"
  7. "golang.org/x/sys/windows"
  8. )
  // Status codes returned by the HIP entry points that this file calls.
  const (
  	// hipSuccess is the status the wrappers in this file treat as success.
  	hipSuccess = 0
  	// hipErrorNoDevice is the status HipGetDeviceCount reports as
  	// "no devices found" rather than as a failure.
  	hipErrorNoDevice = 100
  )
  // hipDevicePropMinimal is a partial view of the device-properties structure
  // that hipGetDeviceProperties fills in. Only the fields this package reads
  // are named; the unusedN members are opaque padding that keeps the named
  // fields at the right byte offsets and keeps the buffer large enough for the
  // library to write into.
  //
  // iGPU must be a 32-bit integer. It mirrors a C int that immediately follows
  // GcnArchName (byte offset 652). Declaring it as a Go int makes it 8 bytes
  // wide with 8-byte alignment on 64-bit targets, which shifts it to offset 656
  // and makes it read the neighbouring members instead of the integrated flag.
  //
  // unused2 pads the struct to 792 bytes in total.
  // NOTE(review): the offsets and total size follow the legacy hipDeviceProp_t
  // declaration in the HIP headers — confirm against the headers that match
  // the driver in use before changing any size here.
  type hipDevicePropMinimal struct {
  	Name        [256]byte
  	unused1     [140]byte
  	GcnArchName [256]byte // gfx####
  	iGPU        int32     // C int "integrated" flag
  	unused2     [136]byte
  }
  20. // Wrap the amdhip64.dll library for GPU discovery
  21. type HipLib struct {
  22. dll windows.Handle
  23. hipGetDeviceCount uintptr
  24. hipGetDeviceProperties uintptr
  25. hipMemGetInfo uintptr
  26. hipSetDevice uintptr
  27. hipDriverGetVersion uintptr
  28. }
  29. func NewHipLib() (*HipLib, error) {
  30. // At runtime we depend on v6, so discover GPUs with the same library for a consistent set of GPUs
  31. h, err := windows.LoadLibrary("amdhip64_6.dll")
  32. if err != nil {
  33. return nil, fmt.Errorf("unable to load amdhip64_6.dll, please make sure to upgrade to the latest amd driver: %w", err)
  34. }
  35. hl := &HipLib{}
  36. hl.dll = h
  37. hl.hipGetDeviceCount, err = windows.GetProcAddress(hl.dll, "hipGetDeviceCount")
  38. if err != nil {
  39. return nil, err
  40. }
  41. hl.hipGetDeviceProperties, err = windows.GetProcAddress(hl.dll, "hipGetDeviceProperties")
  42. if err != nil {
  43. return nil, err
  44. }
  45. hl.hipMemGetInfo, err = windows.GetProcAddress(hl.dll, "hipMemGetInfo")
  46. if err != nil {
  47. return nil, err
  48. }
  49. hl.hipSetDevice, err = windows.GetProcAddress(hl.dll, "hipSetDevice")
  50. if err != nil {
  51. return nil, err
  52. }
  53. hl.hipDriverGetVersion, err = windows.GetProcAddress(hl.dll, "hipDriverGetVersion")
  54. if err != nil {
  55. return nil, err
  56. }
  57. return hl, nil
  58. }
  59. // The hip library only evaluates the HIP_VISIBLE_DEVICES variable at startup
  60. // so we have to unload/reset the library after we do our initial discovery
  61. // to make sure our updates to that variable are processed by llama.cpp
  62. func (hl *HipLib) Release() {
  63. err := windows.FreeLibrary(hl.dll)
  64. if err != nil {
  65. slog.Warn("failed to unload amdhip64.dll", "error", err)
  66. }
  67. hl.dll = 0
  68. }
  69. func (hl *HipLib) AMDDriverVersion() (driverMajor, driverMinor int, err error) {
  70. if hl.dll == 0 {
  71. return 0, 0, fmt.Errorf("dll has been unloaded")
  72. }
  73. var version int
  74. status, _, err := syscall.SyscallN(hl.hipDriverGetVersion, uintptr(unsafe.Pointer(&version)))
  75. if status != hipSuccess {
  76. return 0, 0, fmt.Errorf("failed call to hipDriverGetVersion: %d %s", status, err)
  77. }
  78. slog.Debug("hipDriverGetVersion", "version", version)
  79. driverMajor = version / 10000000
  80. driverMinor = (version - (driverMajor * 10000000)) / 100000
  81. return driverMajor, driverMinor, nil
  82. }
  83. func (hl *HipLib) HipGetDeviceCount() int {
  84. if hl.dll == 0 {
  85. slog.Error("dll has been unloaded")
  86. return 0
  87. }
  88. var count int
  89. status, _, err := syscall.SyscallN(hl.hipGetDeviceCount, uintptr(unsafe.Pointer(&count)))
  90. if status == hipErrorNoDevice {
  91. slog.Info("AMD ROCm reports no devices found")
  92. return 0
  93. }
  94. if status != hipSuccess {
  95. slog.Warn("failed call to hipGetDeviceCount", "status", status, "error", err)
  96. }
  97. return count
  98. }
  99. func (hl *HipLib) HipSetDevice(device int) error {
  100. if hl.dll == 0 {
  101. return fmt.Errorf("dll has been unloaded")
  102. }
  103. status, _, err := syscall.SyscallN(hl.hipSetDevice, uintptr(device))
  104. if status != hipSuccess {
  105. return fmt.Errorf("failed call to hipSetDevice: %d %s", status, err)
  106. }
  107. return nil
  108. }
  109. func (hl *HipLib) HipGetDeviceProperties(device int) (*hipDevicePropMinimal, error) {
  110. if hl.dll == 0 {
  111. return nil, fmt.Errorf("dll has been unloaded")
  112. }
  113. var props hipDevicePropMinimal
  114. status, _, err := syscall.SyscallN(hl.hipGetDeviceProperties, uintptr(unsafe.Pointer(&props)), uintptr(device))
  115. if status != hipSuccess {
  116. return nil, fmt.Errorf("failed call to hipGetDeviceProperties: %d %s", status, err)
  117. }
  118. return &props, nil
  119. }
  120. // free, total, err
  121. func (hl *HipLib) HipMemGetInfo() (uint64, uint64, error) {
  122. if hl.dll == 0 {
  123. return 0, 0, fmt.Errorf("dll has been unloaded")
  124. }
  125. var totalMemory uint64
  126. var freeMemory uint64
  127. status, _, err := syscall.SyscallN(hl.hipMemGetInfo, uintptr(unsafe.Pointer(&freeMemory)), uintptr(unsafe.Pointer(&totalMemory)))
  128. if status != hipSuccess {
  129. return 0, 0, fmt.Errorf("failed call to hipMemGetInfo: %d %s", status, err)
  130. }
  131. return freeMemory, totalMemory, nil
  132. }