Browse Source

Merge pull request #7887 from jk-f5/disablepolling

Disable Polling Transport When WebSockets Are Enabled and Implement Cleanup Locking Mechanism
Timothy Jaeryang Baek 4 months ago
parent
commit
0523ebcc5e

+ 2 - 0
backend/open_webui/main.py

@@ -267,6 +267,7 @@ from open_webui.env import (
     WEBUI_SESSION_COOKIE_SECURE,
     WEBUI_SESSION_COOKIE_SECURE,
     WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
     WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
     WEBUI_AUTH_TRUSTED_NAME_HEADER,
     WEBUI_AUTH_TRUSTED_NAME_HEADER,
+    ENABLE_WEBSOCKET_SUPPORT,
     BYPASS_MODEL_ACCESS_CONTROL,
     BYPASS_MODEL_ACCESS_CONTROL,
     RESET_CONFIG_ON_START,
     RESET_CONFIG_ON_START,
     OFFLINE_MODE,
     OFFLINE_MODE,
@@ -947,6 +948,7 @@ async def get_app_config(request: Request):
             "enable_api_key": app.state.config.ENABLE_API_KEY,
             "enable_api_key": app.state.config.ENABLE_API_KEY,
             "enable_signup": app.state.config.ENABLE_SIGNUP,
             "enable_signup": app.state.config.ENABLE_SIGNUP,
             "enable_login_form": app.state.config.ENABLE_LOGIN_FORM,
             "enable_login_form": app.state.config.ENABLE_LOGIN_FORM,
+            "enable_websocket": ENABLE_WEBSOCKET_SUPPORT,
             **(
             **(
                 {
                 {
                     "enable_web_search": app.state.config.ENABLE_RAG_WEB_SEARCH,
                     "enable_web_search": app.state.config.ENABLE_RAG_WEB_SEARCH,

+ 56 - 34
backend/open_webui/socket/main.py

@@ -11,7 +11,7 @@ from open_webui.env import (
     WEBSOCKET_REDIS_URL,
     WEBSOCKET_REDIS_URL,
 )
 )
 from open_webui.utils.auth import decode_token
 from open_webui.utils.auth import decode_token
-from open_webui.socket.utils import RedisDict
+from open_webui.socket.utils import RedisDict, RedisLock
 
 
 from open_webui.env import (
 from open_webui.env import (
     GLOBAL_LOG_LEVEL,
     GLOBAL_LOG_LEVEL,
@@ -29,9 +29,7 @@ if WEBSOCKET_MANAGER == "redis":
     sio = socketio.AsyncServer(
     sio = socketio.AsyncServer(
         cors_allowed_origins=[],
         cors_allowed_origins=[],
         async_mode="asgi",
         async_mode="asgi",
-        transports=(
-            ["polling", "websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]
-        ),
+        transports=(["websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]),
         allow_upgrades=ENABLE_WEBSOCKET_SUPPORT,
         allow_upgrades=ENABLE_WEBSOCKET_SUPPORT,
         always_connect=True,
         always_connect=True,
         client_manager=mgr,
         client_manager=mgr,
@@ -40,54 +38,78 @@ else:
     sio = socketio.AsyncServer(
     sio = socketio.AsyncServer(
         cors_allowed_origins=[],
         cors_allowed_origins=[],
         async_mode="asgi",
         async_mode="asgi",
-        transports=(
-            ["polling", "websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]
-        ),
+        transports=(["websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]),
         allow_upgrades=ENABLE_WEBSOCKET_SUPPORT,
         allow_upgrades=ENABLE_WEBSOCKET_SUPPORT,
         always_connect=True,
         always_connect=True,
     )
     )
 
 
 
 
+# Timeout duration in seconds
+TIMEOUT_DURATION = 3
+
 # Dictionary to maintain the user pool
 # Dictionary to maintain the user pool
 
 
+run_cleanup = True
 if WEBSOCKET_MANAGER == "redis":
 if WEBSOCKET_MANAGER == "redis":
+    log.debug("Using Redis to manage websockets.")
     SESSION_POOL = RedisDict("open-webui:session_pool", redis_url=WEBSOCKET_REDIS_URL)
     SESSION_POOL = RedisDict("open-webui:session_pool", redis_url=WEBSOCKET_REDIS_URL)
     USER_POOL = RedisDict("open-webui:user_pool", redis_url=WEBSOCKET_REDIS_URL)
     USER_POOL = RedisDict("open-webui:user_pool", redis_url=WEBSOCKET_REDIS_URL)
     USAGE_POOL = RedisDict("open-webui:usage_pool", redis_url=WEBSOCKET_REDIS_URL)
     USAGE_POOL = RedisDict("open-webui:usage_pool", redis_url=WEBSOCKET_REDIS_URL)
+
+    clean_up_lock = RedisLock(
+        redis_url=WEBSOCKET_REDIS_URL,
+        lock_name="usage_cleanup_lock",
+        timeout_secs=TIMEOUT_DURATION * 2,
+    )
+    run_cleanup = clean_up_lock.aquire_lock()
+    renew_func = clean_up_lock.renew_lock
+    release_func = clean_up_lock.release_lock
 else:
 else:
     SESSION_POOL = {}
     SESSION_POOL = {}
     USER_POOL = {}
     USER_POOL = {}
     USAGE_POOL = {}
     USAGE_POOL = {}
-
-
-# Timeout duration in seconds
-TIMEOUT_DURATION = 3
+    release_func = renew_func = lambda: True
 
 
 
 
 async def periodic_usage_pool_cleanup():
 async def periodic_usage_pool_cleanup():
-    while True:
-        now = int(time.time())
-        for model_id, connections in list(USAGE_POOL.items()):
-            # Creating a list of sids to remove if they have timed out
-            expired_sids = [
-                sid
-                for sid, details in connections.items()
-                if now - details["updated_at"] > TIMEOUT_DURATION
-            ]
-
-            for sid in expired_sids:
-                del connections[sid]
-
-            if not connections:
-                log.debug(f"Cleaning up model {model_id} from usage pool")
-                del USAGE_POOL[model_id]
-            else:
-                USAGE_POOL[model_id] = connections
-
-            # Emit updated usage information after cleaning
-            await sio.emit("usage", {"models": get_models_in_use()})
-
-        await asyncio.sleep(TIMEOUT_DURATION)
+    if not run_cleanup:
+        log.debug("Usage pool cleanup lock already exists. Not running it.")
+        return
+    log.debug("Running periodic_usage_pool_cleanup")
+    try:
+        while True:
+            if not renew_func():
+                log.error(f"Unable to renew cleanup lock. Exiting usage pool cleanup.")
+                raise Exception("Unable to renew usage pool cleanup lock.")
+
+            now = int(time.time())
+            send_usage = False
+            for model_id, connections in list(USAGE_POOL.items()):
+                # Creating a list of sids to remove if they have timed out
+                expired_sids = [
+                    sid
+                    for sid, details in connections.items()
+                    if now - details["updated_at"] > TIMEOUT_DURATION
+                ]
+
+                for sid in expired_sids:
+                    del connections[sid]
+
+                if not connections:
+                    log.debug(f"Cleaning up model {model_id} from usage pool")
+                    del USAGE_POOL[model_id]
+                else:
+                    USAGE_POOL[model_id] = connections
+
+                send_usage = True
+
+            if send_usage:
+                # Emit updated usage information after cleaning
+                await sio.emit("usage", {"models": get_models_in_use()})
+
+            await asyncio.sleep(TIMEOUT_DURATION)
+    finally:
+        release_func()
 
 
 
 
 app = socketio.ASGIApp(
 app = socketio.ASGIApp(

+ 28 - 0
backend/open_webui/socket/utils.py

@@ -1,5 +1,33 @@
 import json
 import json
 import redis
 import redis
+import uuid
+
+
+class RedisLock:
+    def __init__(self, redis_url, lock_name, timeout_secs):
+        self.lock_name = lock_name
+        self.lock_id = str(uuid.uuid4())
+        self.timeout_secs = timeout_secs
+        self.lock_obtained = False
+        self.redis = redis.Redis.from_url(redis_url, decode_responses=True)
+
+    def aquire_lock(self):
+        # nx=True will only set this key if it _hasn't_ already been set
+        self.lock_obtained = self.redis.set(
+            self.lock_name, self.lock_id, nx=True, ex=self.timeout_secs
+        )
+        return self.lock_obtained
+
+    def renew_lock(self):
+        # xx=True will only set this key if it _has_ already been set
+        return self.redis.set(
+            self.lock_name, self.lock_id, xx=True, ex=self.timeout_secs
+        )
+
+    def release_lock(self):
+        lock_value = self.redis.get(self.lock_name)
+        if lock_value and lock_value.decode("utf-8") == self.lock_id:
+            self.redis.delete(self.lock_name)
 
 
 
 
 class RedisDict:
 class RedisDict:

+ 3 - 2
src/routes/+layout.svelte

@@ -38,13 +38,14 @@
 	let loaded = false;
 	let loaded = false;
 	const BREAKPOINT = 768;
 	const BREAKPOINT = 768;
 
 
-	const setupSocket = () => {
+	const setupSocket = (enableWebsocket) => {
 		const _socket = io(`${WEBUI_BASE_URL}` || undefined, {
 		const _socket = io(`${WEBUI_BASE_URL}` || undefined, {
 			reconnection: true,
 			reconnection: true,
 			reconnectionDelay: 1000,
 			reconnectionDelay: 1000,
 			reconnectionDelayMax: 5000,
 			reconnectionDelayMax: 5000,
 			randomizationFactor: 0.5,
 			randomizationFactor: 0.5,
 			path: '/ws/socket.io',
 			path: '/ws/socket.io',
+			transports: enableWebsocket ? ['websocket'] : ['polling', 'websocket'],
 			auth: { token: localStorage.token }
 			auth: { token: localStorage.token }
 		});
 		});
 
 
@@ -126,7 +127,7 @@
 			await WEBUI_NAME.set(backendConfig.name);
 			await WEBUI_NAME.set(backendConfig.name);
 
 
 			if ($config) {
 			if ($config) {
-				setupSocket();
+				setupSocket($config.features?.enable_websocket ?? true);
 
 
 				if (localStorage.token) {
 				if (localStorage.token) {
 					// Get Session User Info
 					// Get Session User Info