Co-Authored-By: Rich Tong <1782087+richtong@users.noreply.github.com>
@@ -54,6 +54,8 @@ else:
DEVICE_TYPE = "cpu"
try:
+ import torch
+
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
DEVICE_TYPE = "mps"
except Exception: