amd.go 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. package gpu
  2. import (
  3. "bufio"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "log/slog"
  8. "os"
  9. "path/filepath"
  10. "strconv"
  11. "strings"
  12. )
  13. // TODO - windows vs. non-windows vs darwin
  14. // Discovery logic for AMD/ROCm GPUs
  15. const (
  16. DriverVersionFile = "/sys/module/amdgpu/version"
  17. GPUPropertiesFileGlob = "/sys/class/kfd/kfd/topology/nodes/*/properties"
  18. // TODO probably break these down per GPU to make the logic simpler
  19. GPUTotalMemoryFileGlob = "/sys/class/kfd/kfd/topology/nodes/*/mem_banks/*/properties" // size_in_bytes line
  20. GPUUsedMemoryFileGlob = "/sys/class/kfd/kfd/topology/nodes/*/mem_banks/*/used_memory"
  21. )
  22. func AMDDetected() bool {
  23. // Some driver versions (older?) don't have a version file, so just lookup the parent dir
  24. sysfsDir := filepath.Dir(DriverVersionFile)
  25. _, err := os.Stat(sysfsDir)
  26. if errors.Is(err, os.ErrNotExist) {
  27. slog.Debug("amd driver not detected " + sysfsDir)
  28. return false
  29. } else if err != nil {
  30. slog.Debug(fmt.Sprintf("error looking up amd driver %s %s", sysfsDir, err))
  31. return false
  32. }
  33. return true
  34. }
  35. func AMDDriverVersion() (string, error) {
  36. _, err := os.Stat(DriverVersionFile)
  37. if err != nil {
  38. return "", fmt.Errorf("amdgpu file stat error: %s %w", DriverVersionFile, err)
  39. }
  40. fp, err := os.Open(DriverVersionFile)
  41. if err != nil {
  42. return "", err
  43. }
  44. defer fp.Close()
  45. verString, err := io.ReadAll(fp)
  46. if err != nil {
  47. return "", err
  48. }
  49. return strings.TrimSpace(string(verString)), nil
  50. }
  51. func AMDGFXVersions() []Version {
  52. res := []Version{}
  53. matches, _ := filepath.Glob(GPUPropertiesFileGlob)
  54. for _, match := range matches {
  55. fp, err := os.Open(match)
  56. if err != nil {
  57. slog.Debug(fmt.Sprintf("failed to open sysfs node file %s: %s", match, err))
  58. continue
  59. }
  60. defer fp.Close()
  61. scanner := bufio.NewScanner(fp)
  62. // optionally, resize scanner's capacity for lines over 64K, see next example
  63. for scanner.Scan() {
  64. line := strings.TrimSpace(scanner.Text())
  65. if strings.HasPrefix(line, "gfx_target_version") {
  66. ver := strings.Fields(line)
  67. if len(ver) != 2 || len(ver[1]) < 5 {
  68. slog.Debug("malformed " + line)
  69. continue
  70. }
  71. l := len(ver[1])
  72. patch, err1 := strconv.ParseUint(ver[1][l-2:l], 10, 32)
  73. minor, err2 := strconv.ParseUint(ver[1][l-4:l-2], 10, 32)
  74. major, err3 := strconv.ParseUint(ver[1][:l-4], 10, 32)
  75. if err1 != nil || err2 != nil || err3 != nil {
  76. slog.Debug("malformed int " + line)
  77. continue
  78. }
  79. res = append(res, Version{
  80. Major: uint(major),
  81. Minor: uint(minor),
  82. Patch: uint(patch),
  83. })
  84. }
  85. }
  86. }
  87. return res
  88. }
  89. func (v Version) ToGFXString() string {
  90. return fmt.Sprintf("gfx%d%d%d", v.Major, v.Minor, v.Patch)
  91. }