accelerator_rocm.go 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. //go:build rocm
  2. package llm
  3. import (
  4. "bytes"
  5. "encoding/csv"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "log"
  10. "os"
  11. "os/exec"
  12. "path"
  13. "path/filepath"
  14. "strconv"
  15. "strings"
  16. )
  // errNoAccel is the sentinel error CheckVRAM reports when rocm-smi cannot be
// run or when no device with free VRAM could be found in its output.
var (
	errNoAccel = errors.New("rocm-smi command failed")
)
  18. // acceleratedRunner returns the runner for this accelerator given the provided buildPath string.
  19. func acceleratedRunner(buildPath string) []ModelRunner {
  20. return []ModelRunner{
  21. ModelRunner{
  22. Path: path.Join(buildPath, "rocm", "bin", "ollama-runner"),
  23. Accelerated: true,
  24. },
  25. }
  26. }
  27. // CheckVRAM returns the available VRAM in MiB on Linux machines with AMD GPUs
  28. func CheckVRAM() (int64, error) {
  29. rocmHome := os.Getenv("ROCM_PATH")
  30. if rocmHome == "" {
  31. rocmHome = os.Getenv("ROCM_HOME")
  32. }
  33. if rocmHome == "" {
  34. log.Println("warning: ROCM_PATH is not set. Trying a likely fallback path, but it is recommended to set this variable in the environment.")
  35. rocmHome = "/opt/rocm"
  36. }
  37. cmd := exec.Command(filepath.Join(rocmHome, "bin/rocm-smi"), "--showmeminfo", "VRAM", "--csv")
  38. var stdout bytes.Buffer
  39. cmd.Stdout = &stdout
  40. err := cmd.Run()
  41. if err != nil {
  42. return 0, errNoAccel
  43. }
  44. csvData := csv.NewReader(&stdout)
  45. // llama.cpp or ROCm don't seem to understand splitting the VRAM allocations across them properly, so try to find the biggest card instead :(. FIXME.
  46. totalBiggestCard := int64(0)
  47. bigCardName := ""
  48. for {
  49. record, err := csvData.Read()
  50. if err == io.EOF {
  51. break
  52. }
  53. if err != nil {
  54. return 0, fmt.Errorf("failed to parse available VRAM: %v", err)
  55. }
  56. if !strings.HasPrefix(record[0], "card") {
  57. continue
  58. }
  59. cardTotal, err := strconv.ParseInt(record[1], 10, 64)
  60. if err != nil {
  61. return 0, err
  62. }
  63. cardUsed, err := strconv.ParseInt(record[2], 10, 64)
  64. if err != nil {
  65. return 0, err
  66. }
  67. possible := (cardTotal - cardUsed)
  68. log.Printf("ROCm found %d MiB of available VRAM on device %q", possible/1024/1024, record[0])
  69. if possible > totalBiggestCard {
  70. totalBiggestCard = possible
  71. bigCardName = record[0]
  72. }
  73. }
  74. if totalBiggestCard == 0 {
  75. log.Printf("found ROCm GPU but failed to parse free VRAM!")
  76. return 0, errNoAccel
  77. }
  78. log.Printf("ROCm selecting device %q", bigCardName)
  79. return totalBiggestCard, nil
  80. }