amd.go 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. package gpu
  2. import (
  3. "bufio"
  4. "fmt"
  5. "io"
  6. "log/slog"
  7. "os"
  8. "path/filepath"
  9. "strconv"
  10. "strings"
  11. )
  // Discovery logic for AMD/ROCm GPUs.
  //
  // TODO: handle windows vs. non-windows vs. darwin; every path below is a
  // Linux sysfs / KFD location.

  // DriverVersionFile is the sysfs file AMDDriverVersion reads to obtain the
  // amdgpu module version string.
  const DriverVersionFile = "/sys/module/amdgpu/version"

  // KFD topology paths. Each "*" in a glob stands for one topology node
  // (and, for the memory globs, one memory bank of that node).
  //
  // TODO probably break these down per GPU to make the logic simpler
  const (
  	// GPUPropertiesFileGlob matches the properties file of every topology node.
  	GPUPropertiesFileGlob = "/sys/class/kfd/kfd/topology/nodes/*/properties"

  	// GPUTotalMemoryFileGlob matches each memory bank's properties file; the
  	// bank size is on its size_in_bytes line.
  	GPUTotalMemoryFileGlob = "/sys/class/kfd/kfd/topology/nodes/*/mem_banks/*/properties"

  	// GPUUsedMemoryFileGlob matches each memory bank's used_memory file.
  	GPUUsedMemoryFileGlob = "/sys/class/kfd/kfd/topology/nodes/*/mem_banks/*/used_memory"
  )
  21. func AMDDetected() bool {
  22. _, err := AMDDriverVersion()
  23. return err == nil
  24. }
  25. func AMDDriverVersion() (string, error) {
  26. _, err := os.Stat(DriverVersionFile)
  27. if err != nil {
  28. return "", err
  29. }
  30. fp, err := os.Open(DriverVersionFile)
  31. if err != nil {
  32. return "", err
  33. }
  34. defer fp.Close()
  35. verString, err := io.ReadAll(fp)
  36. if err != nil {
  37. return "", err
  38. }
  39. return strings.TrimSpace(string(verString)), nil
  40. }
  41. func AMDGFXVersions() []Version {
  42. res := []Version{}
  43. matches, _ := filepath.Glob(GPUPropertiesFileGlob)
  44. for _, match := range matches {
  45. fp, err := os.Open(match)
  46. if err != nil {
  47. slog.Debug(fmt.Sprintf("failed to open sysfs node file %s: %s", match, err))
  48. continue
  49. }
  50. defer fp.Close()
  51. scanner := bufio.NewScanner(fp)
  52. // optionally, resize scanner's capacity for lines over 64K, see next example
  53. for scanner.Scan() {
  54. line := strings.TrimSpace(scanner.Text())
  55. if strings.HasPrefix(line, "gfx_target_version") {
  56. ver := strings.Fields(line)
  57. if len(ver) != 2 || len(ver[1]) < 5 {
  58. slog.Debug("malformed " + line)
  59. continue
  60. }
  61. l := len(ver[1])
  62. patch, err1 := strconv.ParseUint(ver[1][l-2:l], 10, 32)
  63. minor, err2 := strconv.ParseUint(ver[1][l-4:l-2], 10, 32)
  64. major, err3 := strconv.ParseUint(ver[1][:l-4], 10, 32)
  65. if err1 != nil || err2 != nil || err3 != nil {
  66. slog.Debug("malformed int " + line)
  67. continue
  68. }
  69. res = append(res, Version{
  70. Major: uint(major),
  71. Minor: uint(minor),
  72. Patch: uint(patch),
  73. })
  74. }
  75. }
  76. }
  77. return res
  78. }
  79. func (v Version) ToGFXString() string {
  80. return fmt.Sprintf("gfx%d%d%d", v.Major, v.Minor, v.Patch)
  81. }