Browse Source

feat: add ENABLE_WEBSOCKET_SUPPORT to force socket.io to ignore websocket upgrades

Jun Siang Cheah 8 months ago
parent
commit
698976add0
2 changed files with 15 additions and 1 deletions
  1. 9 1
      backend/open_webui/apps/socket/main.py
  2. 6 0
      backend/open_webui/config.py

+ 9 - 1
backend/open_webui/apps/socket/main.py

@@ -2,9 +2,17 @@ import asyncio
 
 import socketio
 from open_webui.apps.webui.models.users import Users
+from open_webui.config import ENABLE_WEBSOCKET_SUPPORT
 from open_webui.utils.utils import decode_token
 
-sio = socketio.AsyncServer(cors_allowed_origins=[], async_mode="asgi")
+sio = socketio.AsyncServer(
+    cors_allowed_origins=[],
+    async_mode="asgi",
+    transports=(
+        ["polling", "websocket"] if ENABLE_WEBSOCKET_SUPPORT.value else ["polling"]
+    ),
+    allow_upgrades=ENABLE_WEBSOCKET_SUPPORT.value,
+)
 app = socketio.ASGIApp(sio, socketio_path="/ws/socket.io")
 
 # Dictionary to maintain the user pool

+ 6 - 0
backend/open_webui/config.py

@@ -810,6 +810,12 @@ ENABLE_MESSAGE_RATING = PersistentConfig(
     os.environ.get("ENABLE_MESSAGE_RATING", "True").lower() == "true",
 )
 
+ENABLE_WEBSOCKET_SUPPORT = PersistentConfig(
+    "ENABLE_WEBSOCKET_SUPPORT",
+    "ui.enable_websocket_support",
+    os.environ.get("ENABLE_WEBSOCKET_SUPPORT", "True").lower() == "true",
+)
+
 
 def validate_cors_origins(origins):
     for origin in origins: