Browse Source

Merge pull request #5400 from thiswillbeyourgithub/fix_fallback_cuda

fix: if cuda is not available fallback to cpu
Timothy Jaeryang Baek 7 tháng trước cách đây
mục cha
commit
2f9f568dd9
2 tập tin đã thay đổi với 29 bổ sung1 xóa
  1. 13 0
      backend/open_webui/__init__.py
  2. 16 1
      backend/open_webui/env.py

+ 13 - 0
backend/open_webui/__init__.py

@@ -39,6 +39,19 @@ def serve(
                 "/usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib",
             ]
         )
+        try:
+            import torch
+
+            assert torch.cuda.is_available(), "CUDA not available"
+            typer.echo("CUDA seems to be working")
+        except Exception as e:
+            typer.echo(
+                "Error when testing CUDA but USE_CUDA_DOCKER is true. "
+                "Resetting USE_CUDA_DOCKER to false and removing "
+                f"LD_LIBRARY_PATH modifications: {e}"
+            )
+            os.environ["USE_CUDA_DOCKER"] = "false"
+            os.environ["LD_LIBRARY_PATH"] = ":".join(LD_LIBRARY_PATH)
     import open_webui.main  # we need set environment variables before importing main
 
     uvicorn.run(open_webui.main.app, host=host, port=port, forwarded_allow_ips="*")

+ 16 - 1
backend/open_webui/env.py

@@ -36,7 +36,19 @@ except ImportError:
 USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false")
 
 if USE_CUDA.lower() == "true":
-    DEVICE_TYPE = "cuda"
+    try:
+        import torch
+
+        assert torch.cuda.is_available(), "CUDA not available"
+        DEVICE_TYPE = "cuda"
+    except Exception as e:
+        cuda_error = (
+            "Error when testing CUDA but USE_CUDA_DOCKER is true. "
+            f"Resetting USE_CUDA_DOCKER to false: {e}"
+        )
+        os.environ["USE_CUDA_DOCKER"] = "false"
+        USE_CUDA = "false"
+        DEVICE_TYPE = "cpu"
 else:
     DEVICE_TYPE = "cpu"
 
@@ -56,6 +68,9 @@ else:
 log = logging.getLogger(__name__)
 log.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}")
 
+if "cuda_error" in locals():
+    log.exception(cuda_error)
+
 log_sources = [
     "AUDIO",
     "COMFYUI",