cuda_common.go 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. //go:build linux || windows
  2. package gpu
  3. import (
  4. "log/slog"
  5. "os"
  6. "regexp"
  7. "runtime"
  8. "strconv"
  9. "strings"
  10. )
  // CudaTegra is the value of the JETSON_JETPACK environment variable, read
  // once when the package is initialized. Jetson devices have this variable
  // factory set to the installed JetPack version in "x.y.z" form; it is empty
  // when the variable is not set. It is included to drive logic for reducing
  // Ollama-allocated overhead on L4T/Jetson devices.
  var CudaTegra = os.Getenv("JETSON_JETPACK")
  14. func cudaGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
  15. ids := []string{}
  16. for _, info := range gpuInfo {
  17. if info.Library != "cuda" {
  18. // TODO shouldn't happen if things are wired correctly...
  19. slog.Debug("cudaGetVisibleDevicesEnv skipping over non-cuda device", "library", info.Library)
  20. continue
  21. }
  22. ids = append(ids, info.ID)
  23. }
  24. return "CUDA_VISIBLE_DEVICES", strings.Join(ids, ",")
  25. }
  26. func cudaVariant(gpuInfo CudaGPUInfo) string {
  27. if runtime.GOARCH == "arm64" && runtime.GOOS == "linux" {
  28. if CudaTegra != "" {
  29. ver := strings.Split(CudaTegra, ".")
  30. if len(ver) > 0 {
  31. return "jetpack" + ver[0]
  32. }
  33. } else if data, err := os.ReadFile("/etc/nv_tegra_release"); err == nil {
  34. r := regexp.MustCompile(` R(\d+) `)
  35. m := r.FindSubmatch(data)
  36. if len(m) != 2 {
  37. slog.Info("Unexpected format for /etc/nv_tegra_release. Set JETSON_JETPACK to select version")
  38. } else {
  39. if l4t, err := strconv.Atoi(string(m[1])); err == nil {
  40. // Note: mapping from L4t -> JP is inconsistent (can't just subtract 30)
  41. // https://developer.nvidia.com/embedded/jetpack-archive
  42. switch l4t {
  43. case 35:
  44. return "jetpack5"
  45. case 36:
  46. return "jetpack6"
  47. default:
  48. slog.Info("unsupported L4T version", "nv_tegra_release", string(data))
  49. }
  50. }
  51. }
  52. }
  53. }
  54. if gpuInfo.computeMajor < 6 || gpuInfo.DriverMajor < 12 {
  55. return "v11"
  56. }
  57. return "v12"
  58. }