Kaynağa Gözat

enh: socket full redis support

Timothy J. Baek 7 ay önce
ebeveyn
işleme
5f84145a2d

+ 69 - 62
backend/open_webui/apps/socket/main.py

@@ -1,6 +1,7 @@
 import asyncio
-
 import socketio
+import time
+
 from open_webui.apps.webui.models.users import Users
 from open_webui.env import (
     ENABLE_WEBSOCKET_SUPPORT,
@@ -8,6 +9,7 @@ from open_webui.env import (
     WEBSOCKET_REDIS_URL,
 )
 from open_webui.utils.utils import decode_token
+from open_webui.apps.socket.utils import RedisDict
 
 
 if WEBSOCKET_MANAGER == "redis":
@@ -38,13 +40,72 @@ app = socketio.ASGIApp(sio, socketio_path="/ws/socket.io")
 
 # Dictionary to maintain the user pool
 
-SESSION_POOL = {}
-USER_POOL = {}
-USAGE_POOL = {}
+# Choose the backing store for the shared pools: Redis hashes stored under
+# fixed names when the websocket manager is "redis", plain in-process dicts
+# otherwise.
+#   SESSION_POOL: session id -> user id
+#   USER_POOL:    user id -> list of session ids
+#   USAGE_POOL:   model id -> {session id: {"updated_at": unix seconds}}
+if WEBSOCKET_MANAGER != "redis":
+    SESSION_POOL, USER_POOL, USAGE_POOL = {}, {}, {}
+else:
+    SESSION_POOL = RedisDict("open-webui:session_pool", redis_url=WEBSOCKET_REDIS_URL)
+    USER_POOL = RedisDict("open-webui:user_pool", redis_url=WEBSOCKET_REDIS_URL)
+    USAGE_POOL = RedisDict("open-webui:usage_pool", redis_url=WEBSOCKET_REDIS_URL)
+
+
 # Timeout duration in seconds
 TIMEOUT_DURATION = 3
 
 
+async def periodic_usage_pool_cleanup():
+    """Evict timed-out sessions from ``USAGE_POOL`` in an endless loop.
+
+    Every ``TIMEOUT_DURATION`` seconds each model's entry is scanned and
+    sessions whose ``updated_at`` is older than ``TIMEOUT_DURATION`` seconds
+    are dropped. A model with no sessions left is removed from the pool.
+    If the pool held anything at the start of the pass, the resulting list
+    of models in use is broadcast once on the ``usage`` event.
+
+    The coroutine never returns; run it as a background task.
+    """
+    while True:
+        now = int(time.time())
+        print("Cleaning up usage pool", now)
+
+        # Work on a snapshot so entries can be removed while iterating.
+        snapshot = list(USAGE_POOL.items())
+        for model_id, connections in snapshot:
+            active = {
+                sid: details
+                for sid, details in connections.items()
+                if now - details["updated_at"] <= TIMEOUT_DURATION
+            }
+
+            if not active:
+                try:
+                    del USAGE_POOL[model_id]
+                except KeyError:
+                    # Already removed since the snapshot was taken (a
+                    # RedisDict raises KeyError when the field is gone);
+                    # this must not terminate the cleanup loop.
+                    pass
+            elif len(active) != len(connections):
+                # Write back only when something expired, so an untouched
+                # entry is not overwritten with stale snapshot data.
+                USAGE_POOL[model_id] = active
+
+        # Emit updated usage information once per pass rather than once per
+        # model; an empty pool stays silent, as before.
+        if snapshot:
+            await sio.emit("usage", {"models": get_models_in_use()})
+
+        await asyncio.sleep(TIMEOUT_DURATION)
+
+
+# Start the cleanup task when your app starts.
+# NOTE: asyncio.create_task() requires a running event loop, so this module
+# has to be imported from inside one. The task is bound to a module-level name
+# because the event loop only keeps weak references to tasks; without a strong
+# reference the task could be garbage-collected while it is still running.
+usage_pool_cleanup_task = asyncio.create_task(periodic_usage_pool_cleanup())
+
+
+def get_models_in_use():
+    """Return a list of the model ids that currently have an entry in ``USAGE_POOL``."""
+    return [*USAGE_POOL.keys()]
+
+
+@sio.on("usage")
+async def usage(sid, data):
+    """Mark session ``sid`` as currently using ``data["model"]`` and broadcast the models in use.
+
+    The session's entry under the model is stamped with the current Unix
+    time in whole seconds; any other sessions already recorded for the model
+    are kept.
+
+    Raises:
+        KeyError: if ``data`` has no ``"model"`` key.
+    """
+    model_id = data["model"]
+    # Record the timestamp for the last update
+    current_time = int(time.time())
+
+    # Merge this session into the model's existing entry. ``get`` reads the
+    # entry in a single lookup, so it cannot disappear between a membership
+    # check and the read when the pool is shared through Redis.
+    USAGE_POOL[model_id] = {
+        **USAGE_POOL.get(model_id, {}),
+        sid: {"updated_at": current_time},
+    }
+
+    # Broadcast the usage data to all clients
+    await sio.emit("usage", {"models": get_models_in_use()})
+
+
 @sio.event
 async def connect(sid, environ, auth):
     user = None
@@ -62,8 +123,7 @@ async def connect(sid, environ, auth):
                 USER_POOL[user.id] = [sid]
 
             # print(f"user {user.name}({user.id}) connected with session ID {sid}")
-
-            await sio.emit("user-count", {"count": len(set(USER_POOL))})
+            await sio.emit("user-count", {"count": len(USER_POOL.items())})
             await sio.emit("usage", {"models": get_models_in_use()})
 
 
@@ -91,65 +151,12 @@ async def user_join(sid, data):
 
     # print(f"user {user.name}({user.id}) connected with session ID {sid}")
 
-    await sio.emit("user-count", {"count": len(set(USER_POOL))})
+    await sio.emit("user-count", {"count": len(USER_POOL.items())})
 
 
 @sio.on("user-count")
 async def user_count(sid):
-    await sio.emit("user-count", {"count": len(set(USER_POOL))})
-
-
-def get_models_in_use():
-    # Aggregate all models in use
-    models_in_use = []
-    for model_id, data in USAGE_POOL.items():
-        models_in_use.append(model_id)
-
-    return models_in_use
-
-
-@sio.on("usage")
-async def usage(sid, data):
-    model_id = data["model"]
-
-    # Cancel previous callback if there is one
-    if model_id in USAGE_POOL:
-        USAGE_POOL[model_id]["callback"].cancel()
-
-    # Store the new usage data and task
-
-    if model_id in USAGE_POOL:
-        USAGE_POOL[model_id]["sids"].append(sid)
-        USAGE_POOL[model_id]["sids"] = list(set(USAGE_POOL[model_id]["sids"]))
-
-    else:
-        USAGE_POOL[model_id] = {"sids": [sid]}
-
-    # Schedule a task to remove the usage data after TIMEOUT_DURATION
-    USAGE_POOL[model_id]["callback"] = asyncio.create_task(
-        remove_after_timeout(sid, model_id)
-    )
-
-    # Broadcast the usage data to all clients
-    await sio.emit("usage", {"models": get_models_in_use()})
-
-
-async def remove_after_timeout(sid, model_id):
-    try:
-        await asyncio.sleep(TIMEOUT_DURATION)
-        if model_id in USAGE_POOL:
-            # print(USAGE_POOL[model_id]["sids"])
-            USAGE_POOL[model_id]["sids"].remove(sid)
-            USAGE_POOL[model_id]["sids"] = list(set(USAGE_POOL[model_id]["sids"]))
-
-            if len(USAGE_POOL[model_id]["sids"]) == 0:
-                del USAGE_POOL[model_id]
-
-            # Broadcast the usage data to all clients
-            await sio.emit("usage", {"models": get_models_in_use()})
-    except asyncio.CancelledError:
-        # Task was cancelled due to new 'usage' event
-        pass
+    await sio.emit("user-count", {"count": len(USER_POOL.items())})
 
 
 @sio.event
@@ -158,7 +165,7 @@ async def disconnect(sid):
         user_id = SESSION_POOL[sid]
         del SESSION_POOL[sid]
 
-        USER_POOL[user_id].remove(sid)
+        USER_POOL[user_id] = [_sid for _sid in USER_POOL[user_id] if _sid != sid]
 
         if len(USER_POOL[user_id]) == 0:
             del USER_POOL[user_id]

+ 59 - 0
backend/open_webui/apps/socket/utils.py

@@ -0,0 +1,59 @@
+import json
+import redis
+
+
+class RedisDict:
+    """Dict-like wrapper around a single Redis hash.
+
+    Each item is one hash field: the key is the field name and the value is
+    stored as a JSON string. Values must therefore be JSON-serializable and
+    are returned as whatever ``json.loads`` produces (a fresh copy on every
+    read, so mutating a returned value does not change what is stored).
+    The client is created with ``decode_responses=True``, so keys read back
+    from Redis are ``str``.
+    """
+
+    def __init__(self, name, redis_url):
+        """Bind to the hash called ``name`` on the server at ``redis_url``."""
+        self.name = name
+        self.redis = redis.Redis.from_url(redis_url, decode_responses=True)
+
+    def __setitem__(self, key, value):
+        """Store ``value`` (JSON-encoded) under ``key``, replacing any previous value."""
+        serialized_value = json.dumps(value)
+        self.redis.hset(self.name, key, serialized_value)
+
+    def __getitem__(self, key):
+        """Return the decoded value for ``key``; raise ``KeyError`` if it is absent."""
+        value = self.redis.hget(self.name, key)
+        if value is None:
+            raise KeyError(key)
+        return json.loads(value)
+
+    def __delitem__(self, key):
+        """Remove ``key``; raise ``KeyError`` if it was not present."""
+        result = self.redis.hdel(self.name, key)
+        if result == 0:
+            raise KeyError(key)
+
+    def __contains__(self, key):
+        """Return whether ``key`` is a field of the hash."""
+        return self.redis.hexists(self.name, key)
+
+    def __len__(self):
+        """Return the number of fields in the hash."""
+        return self.redis.hlen(self.name)
+
+    def __iter__(self):
+        """Iterate over the keys, like a ``dict``.
+
+        Without this, iteration would fall back to calling
+        ``__getitem__(0)``, which raises ``KeyError`` instead of yielding keys.
+        """
+        return iter(self.keys())
+
+    def keys(self):
+        """Return a list of all keys."""
+        return self.redis.hkeys(self.name)
+
+    def values(self):
+        """Return a list of all decoded values."""
+        return [json.loads(v) for v in self.redis.hvals(self.name)]
+
+    def items(self):
+        """Return a list of ``(key, decoded value)`` pairs."""
+        return [(k, json.loads(v)) for k, v in self.redis.hgetall(self.name).items()]
+
+    def get(self, key, default=None):
+        """Return the value for ``key``, or ``default`` if it is absent."""
+        try:
+            return self[key]
+        except KeyError:
+            return default
+
+    def clear(self):
+        """Delete the whole hash."""
+        self.redis.delete(self.name)
+
+    def update(self, other=None, **kwargs):
+        """Set several items from a mapping or iterable of pairs, then from ``kwargs``."""
+        if other is not None:
+            for k, v in other.items() if hasattr(other, "items") else other:
+                self[k] = v
+        for k, v in kwargs.items():
+            self[k] = v
+
+    def setdefault(self, key, default=None):
+        """Store ``default`` under ``key`` only if ``key`` is absent; return the stored value.
+
+        HSETNX makes the check-and-set a single atomic Redis command, so two
+        concurrent callers cannot overwrite each other's value.
+        """
+        self.redis.hsetnx(self.name, key, json.dumps(default))
+        return self[key]