amd_hip_windows.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. package gpu
  2. import (
  3. "fmt"
  4. "log/slog"
  5. "strconv"
  6. "syscall"
  7. "unsafe"
  8. "golang.org/x/sys/windows"
  9. )
  // Status codes returned by the HIP entry points wrapped in this file.
  const (
  	// hipSuccess is the status value that every wrapper below treats as a
  	// successful call.
  	hipSuccess = 0

  	// hipErrorNoDevice is the status that HipGetDeviceCount reports as
  	// "no devices found" rather than as a failure.
  	hipErrorNoDevice = 100
  )
  // hipDevicePropMinimal is the buffer handed by pointer to
  // hipGetDeviceProperties. Only the fields this package cares about are
  // named; the unusedN arrays are opaque filler that keeps the named fields
  // at their expected positions and gives the native call room to write.
  //
  // The field order and sizes define the memory layout seen by the native
  // library, so they must not be reordered or resized casually.
  type hipDevicePropMinimal struct {
  	// Name is the device name as a fixed-size byte buffer.
  	Name [256]byte
  	// unused1 is filler between Name and GcnArchName.
  	unused1 [140]byte
  	// GcnArchName holds the architecture string (gfx####).
  	GcnArchName [256]byte
  	// iGPU doesn't seem to actually report correctly.
  	// NOTE(review): Go's int is 64 bits wide on 64-bit targets; if the
  	// corresponding native field is a 32-bit C int, the width/alignment
  	// difference could explain the bad values — confirm against the HIP
  	// headers before changing, since the total struct size would shrink.
  	iGPU int
  	// unused2 is trailing filler after iGPU.
  	unused2 [128]byte
  }
  21. // Wrap the amdhip64.dll library for GPU discovery
  22. type HipLib struct {
  23. dll windows.Handle
  24. hipGetDeviceCount uintptr
  25. hipGetDeviceProperties uintptr
  26. hipMemGetInfo uintptr
  27. hipSetDevice uintptr
  28. hipDriverGetVersion uintptr
  29. }
  30. func NewHipLib() (*HipLib, error) {
  31. h, err := windows.LoadLibrary("amdhip64.dll")
  32. if err != nil {
  33. return nil, fmt.Errorf("unable to load amdhip64.dll: %w", err)
  34. }
  35. hl := &HipLib{}
  36. hl.dll = h
  37. hl.hipGetDeviceCount, err = windows.GetProcAddress(hl.dll, "hipGetDeviceCount")
  38. if err != nil {
  39. return nil, err
  40. }
  41. hl.hipGetDeviceProperties, err = windows.GetProcAddress(hl.dll, "hipGetDeviceProperties")
  42. if err != nil {
  43. return nil, err
  44. }
  45. hl.hipMemGetInfo, err = windows.GetProcAddress(hl.dll, "hipMemGetInfo")
  46. if err != nil {
  47. return nil, err
  48. }
  49. hl.hipSetDevice, err = windows.GetProcAddress(hl.dll, "hipSetDevice")
  50. if err != nil {
  51. return nil, err
  52. }
  53. hl.hipDriverGetVersion, err = windows.GetProcAddress(hl.dll, "hipDriverGetVersion")
  54. if err != nil {
  55. return nil, err
  56. }
  57. return hl, nil
  58. }
  59. // The hip library only evaluates the HIP_VISIBLE_DEVICES variable at startup
  60. // so we have to unload/reset the library after we do our initial discovery
  61. // to make sure our updates to that variable are processed by llama.cpp
  62. func (hl *HipLib) Release() {
  63. err := windows.FreeLibrary(hl.dll)
  64. if err != nil {
  65. slog.Warn(fmt.Sprintf("failed to unload amdhip64.dll: %s", err))
  66. }
  67. hl.dll = 0
  68. }
  69. func (hl *HipLib) AMDDriverVersion() (string, error) {
  70. if hl.dll == 0 {
  71. return "", fmt.Errorf("dll has been unloaded")
  72. }
  73. var version int
  74. status, _, err := syscall.SyscallN(hl.hipDriverGetVersion, uintptr(unsafe.Pointer(&version)))
  75. if status != hipSuccess {
  76. return "", fmt.Errorf("failed call to hipDriverGetVersion: %d %s", status, err)
  77. }
  78. return strconv.Itoa(version), nil
  79. }
  80. func (hl *HipLib) HipGetDeviceCount() int {
  81. if hl.dll == 0 {
  82. slog.Error("dll has been unloaded")
  83. return 0
  84. }
  85. var count int
  86. status, _, err := syscall.SyscallN(hl.hipGetDeviceCount, uintptr(unsafe.Pointer(&count)))
  87. if status == hipErrorNoDevice {
  88. slog.Info("AMD ROCm reports no devices found")
  89. return 0
  90. }
  91. if status != hipSuccess {
  92. slog.Warn(fmt.Sprintf("failed call to hipGetDeviceCount: %d %s", status, err))
  93. }
  94. return count
  95. }
  96. func (hl *HipLib) HipSetDevice(device int) error {
  97. if hl.dll == 0 {
  98. return fmt.Errorf("dll has been unloaded")
  99. }
  100. status, _, err := syscall.SyscallN(hl.hipSetDevice, uintptr(device))
  101. if status != hipSuccess {
  102. return fmt.Errorf("failed call to hipSetDevice: %d %s", status, err)
  103. }
  104. return nil
  105. }
  106. func (hl *HipLib) HipGetDeviceProperties(device int) (*hipDevicePropMinimal, error) {
  107. if hl.dll == 0 {
  108. return nil, fmt.Errorf("dll has been unloaded")
  109. }
  110. var props hipDevicePropMinimal
  111. status, _, err := syscall.SyscallN(hl.hipGetDeviceProperties, uintptr(unsafe.Pointer(&props)), uintptr(device))
  112. if status != hipSuccess {
  113. return nil, fmt.Errorf("failed call to hipGetDeviceProperties: %d %s", status, err)
  114. }
  115. return &props, nil
  116. }
  117. // free, total, err
  118. func (hl *HipLib) HipMemGetInfo() (uint64, uint64, error) {
  119. if hl.dll == 0 {
  120. return 0, 0, fmt.Errorf("dll has been unloaded")
  121. }
  122. var totalMemory uint64
  123. var freeMemory uint64
  124. status, _, err := syscall.SyscallN(hl.hipMemGetInfo, uintptr(unsafe.Pointer(&freeMemory)), uintptr(unsafe.Pointer(&totalMemory)))
  125. if status != hipSuccess {
  126. return 0, 0, fmt.Errorf("failed call to hipMemGetInfo: %d %s", status, err)
  127. }
  128. return freeMemory, totalMemory, nil
  129. }