@@ -54,6 +54,9 @@ else:
DEVICE_TYPE = "cpu"
+if torch.backends.mps.is_available() and torch.backends.mps.is_built():
+ DEVICE_TYPE = "mps"
+
####################################
# LOGGING