|
@@ -1,6 +1,7 @@
|
|
|
import asyncio
|
|
|
-
|
|
|
import socketio
|
|
|
+import time
|
|
|
+
|
|
|
from open_webui.apps.webui.models.users import Users
|
|
|
from open_webui.env import (
|
|
|
ENABLE_WEBSOCKET_SUPPORT,
|
|
@@ -8,6 +9,7 @@ from open_webui.env import (
|
|
|
WEBSOCKET_REDIS_URL,
|
|
|
)
|
|
|
from open_webui.utils.utils import decode_token
|
|
|
+from open_webui.apps.socket.utils import RedisDict
|
|
|
|
|
|
|
|
|
if WEBSOCKET_MANAGER == "redis":
|
|
@@ -38,13 +40,72 @@ app = socketio.ASGIApp(sio, socketio_path="/ws/socket.io")
|
|
|
|
|
|
# Dictionary to maintain the user pool
|
|
|
|
|
|
-SESSION_POOL = {}
|
|
|
-USER_POOL = {}
|
|
|
-USAGE_POOL = {}
|
|
|
+if WEBSOCKET_MANAGER == "redis":
|
|
|
+ SESSION_POOL = RedisDict("open-webui:session_pool", redis_url=WEBSOCKET_REDIS_URL)
|
|
|
+ USER_POOL = RedisDict("open-webui:user_pool", redis_url=WEBSOCKET_REDIS_URL)
|
|
|
+ USAGE_POOL = RedisDict("open-webui:usage_pool", redis_url=WEBSOCKET_REDIS_URL)
|
|
|
+else:
|
|
|
+ SESSION_POOL = {}
|
|
|
+ USER_POOL = {}
|
|
|
+ USAGE_POOL = {}
|
|
|
+
|
|
|
+
|
|
|
# Timeout duration in seconds
|
|
|
TIMEOUT_DURATION = 3
|
|
|
|
|
|
|
|
|
+async def periodic_usage_pool_cleanup():
|
|
|
+ while True:
|
|
|
+ now = int(time.time())
|
|
|
+ print("Cleaning up usage pool", now)
|
|
|
+ for model_id, connections in list(USAGE_POOL.items()):
|
|
|
+ # Creating a list of sids to remove if they have timed out
|
|
|
+ expired_sids = [
|
|
|
+ sid
|
|
|
+ for sid, details in connections.items()
|
|
|
+ if now - details["updated_at"] > TIMEOUT_DURATION
|
|
|
+ ]
|
|
|
+
|
|
|
+ for sid in expired_sids:
|
|
|
+ del connections[sid]
|
|
|
+
|
|
|
+ if not connections:
|
|
|
+ del USAGE_POOL[model_id]
|
|
|
+ else:
|
|
|
+ USAGE_POOL[model_id] = connections
|
|
|
+
|
|
|
+ # Emit updated usage information after cleaning
|
|
|
+ await sio.emit("usage", {"models": get_models_in_use()})
|
|
|
+
|
|
|
+ await asyncio.sleep(TIMEOUT_DURATION)
|
|
|
+
|
|
|
+
|
|
|
+# Start the cleanup task when your app starts
|
|
|
+asyncio.create_task(periodic_usage_pool_cleanup())
|
|
|
+
|
|
|
+
|
|
|
+def get_models_in_use():
|
|
|
+ # List models that are currently in use
|
|
|
+ models_in_use = list(USAGE_POOL.keys())
|
|
|
+ return models_in_use
|
|
|
+
|
|
|
+
|
|
|
+@sio.on("usage")
|
|
|
+async def usage(sid, data):
|
|
|
+ model_id = data["model"]
|
|
|
+ # Record the timestamp for the last update
|
|
|
+ current_time = int(time.time())
|
|
|
+
|
|
|
+ # Store the new usage data and task
|
|
|
+ USAGE_POOL[model_id] = {
|
|
|
+ **(USAGE_POOL[model_id] if model_id in USAGE_POOL else {}),
|
|
|
+ sid: {"updated_at": current_time},
|
|
|
+ }
|
|
|
+
|
|
|
+ # Broadcast the usage data to all clients
|
|
|
+ await sio.emit("usage", {"models": get_models_in_use()})
|
|
|
+
|
|
|
+
|
|
|
@sio.event
|
|
|
async def connect(sid, environ, auth):
|
|
|
user = None
|
|
@@ -62,8 +123,7 @@ async def connect(sid, environ, auth):
|
|
|
USER_POOL[user.id] = [sid]
|
|
|
|
|
|
# print(f"user {user.name}({user.id}) connected with session ID {sid}")
|
|
|
-
|
|
|
- await sio.emit("user-count", {"count": len(set(USER_POOL))})
|
|
|
+ await sio.emit("user-count", {"count": len(USER_POOL.items())})
|
|
|
await sio.emit("usage", {"models": get_models_in_use()})
|
|
|
|
|
|
|
|
@@ -91,65 +151,12 @@ async def user_join(sid, data):
|
|
|
|
|
|
# print(f"user {user.name}({user.id}) connected with session ID {sid}")
|
|
|
|
|
|
- await sio.emit("user-count", {"count": len(set(USER_POOL))})
|
|
|
+ await sio.emit("user-count", {"count": len(USER_POOL.items())})
|
|
|
|
|
|
|
|
|
@sio.on("user-count")
|
|
|
async def user_count(sid):
|
|
|
- await sio.emit("user-count", {"count": len(set(USER_POOL))})
|
|
|
-
|
|
|
-
|
|
|
-def get_models_in_use():
|
|
|
- # Aggregate all models in use
|
|
|
- models_in_use = []
|
|
|
- for model_id, data in USAGE_POOL.items():
|
|
|
- models_in_use.append(model_id)
|
|
|
-
|
|
|
- return models_in_use
|
|
|
-
|
|
|
-
|
|
|
-@sio.on("usage")
|
|
|
-async def usage(sid, data):
|
|
|
- model_id = data["model"]
|
|
|
-
|
|
|
- # Cancel previous callback if there is one
|
|
|
- if model_id in USAGE_POOL:
|
|
|
- USAGE_POOL[model_id]["callback"].cancel()
|
|
|
-
|
|
|
- # Store the new usage data and task
|
|
|
-
|
|
|
- if model_id in USAGE_POOL:
|
|
|
- USAGE_POOL[model_id]["sids"].append(sid)
|
|
|
- USAGE_POOL[model_id]["sids"] = list(set(USAGE_POOL[model_id]["sids"]))
|
|
|
-
|
|
|
- else:
|
|
|
- USAGE_POOL[model_id] = {"sids": [sid]}
|
|
|
-
|
|
|
- # Schedule a task to remove the usage data after TIMEOUT_DURATION
|
|
|
- USAGE_POOL[model_id]["callback"] = asyncio.create_task(
|
|
|
- remove_after_timeout(sid, model_id)
|
|
|
- )
|
|
|
-
|
|
|
- # Broadcast the usage data to all clients
|
|
|
- await sio.emit("usage", {"models": get_models_in_use()})
|
|
|
-
|
|
|
-
|
|
|
-async def remove_after_timeout(sid, model_id):
|
|
|
- try:
|
|
|
- await asyncio.sleep(TIMEOUT_DURATION)
|
|
|
- if model_id in USAGE_POOL:
|
|
|
- # print(USAGE_POOL[model_id]["sids"])
|
|
|
- USAGE_POOL[model_id]["sids"].remove(sid)
|
|
|
- USAGE_POOL[model_id]["sids"] = list(set(USAGE_POOL[model_id]["sids"]))
|
|
|
-
|
|
|
- if len(USAGE_POOL[model_id]["sids"]) == 0:
|
|
|
- del USAGE_POOL[model_id]
|
|
|
-
|
|
|
- # Broadcast the usage data to all clients
|
|
|
- await sio.emit("usage", {"models": get_models_in_use()})
|
|
|
- except asyncio.CancelledError:
|
|
|
- # Task was cancelled due to new 'usage' event
|
|
|
- pass
|
|
|
+ await sio.emit("user-count", {"count": len(USER_POOL.items())})
|
|
|
|
|
|
|
|
|
@sio.event
|
|
@@ -158,7 +165,7 @@ async def disconnect(sid):
|
|
|
user_id = SESSION_POOL[sid]
|
|
|
del SESSION_POOL[sid]
|
|
|
|
|
|
- USER_POOL[user_id].remove(sid)
|
|
|
+ USER_POOL[user_id] = [_sid for _sid in USER_POOL[user_id] if _sid != sid]
|
|
|
|
|
|
if len(USER_POOL[user_id]) == 0:
|
|
|
del USER_POOL[user_id]
|