main.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. import asyncio
  2. import socketio
  3. import logging
  4. import sys
  5. import time
  6. from open_webui.models.users import Users
  7. from open_webui.env import (
  8. ENABLE_WEBSOCKET_SUPPORT,
  9. WEBSOCKET_MANAGER,
  10. WEBSOCKET_REDIS_URL,
  11. )
  12. from open_webui.utils.auth import decode_token
  13. from open_webui.socket.utils import RedisDict, RedisLock
  14. from open_webui.env import (
  15. GLOBAL_LOG_LEVEL,
  16. SRC_LOG_LEVELS,
  17. )
# Send root-level log output to stdout at the application-wide GLOBAL_LOG_LEVEL.
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
# Module logger; its verbosity is taken from the "SOCKET" entry of SRC_LOG_LEVELS
# rather than from the global level.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["SOCKET"])
  21. if WEBSOCKET_MANAGER == "redis":
  22. mgr = socketio.AsyncRedisManager(WEBSOCKET_REDIS_URL)
  23. sio = socketio.AsyncServer(
  24. cors_allowed_origins=[],
  25. async_mode="asgi",
  26. transports=(["websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]),
  27. allow_upgrades=ENABLE_WEBSOCKET_SUPPORT,
  28. always_connect=True,
  29. client_manager=mgr,
  30. )
  31. else:
  32. sio = socketio.AsyncServer(
  33. cors_allowed_origins=[],
  34. async_mode="asgi",
  35. transports=(["websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]),
  36. allow_upgrades=ENABLE_WEBSOCKET_SUPPORT,
  37. always_connect=True,
  38. )
  39. # Timeout duration in seconds
  40. TIMEOUT_DURATION = 3
  41. # Dictionary to maintain the user pool
  42. run_cleanup = True
  43. if WEBSOCKET_MANAGER == "redis":
  44. log.debug("Using Redis to manage websockets.")
  45. SESSION_POOL = RedisDict("open-webui:session_pool", redis_url=WEBSOCKET_REDIS_URL)
  46. USER_POOL = RedisDict("open-webui:user_pool", redis_url=WEBSOCKET_REDIS_URL)
  47. USAGE_POOL = RedisDict("open-webui:usage_pool", redis_url=WEBSOCKET_REDIS_URL)
  48. clean_up_lock = RedisLock(
  49. redis_url=WEBSOCKET_REDIS_URL,
  50. lock_name="usage_cleanup_lock",
  51. timeout_secs=TIMEOUT_DURATION * 2,
  52. )
  53. run_cleanup = clean_up_lock.aquire_lock()
  54. renew_func = clean_up_lock.renew_lock
  55. release_func = clean_up_lock.release_lock
  56. else:
  57. SESSION_POOL = {}
  58. USER_POOL = {}
  59. USAGE_POOL = {}
  60. release_func = renew_func = lambda: True
  61. async def periodic_usage_pool_cleanup():
  62. if not run_cleanup:
  63. log.debug("Usage pool cleanup lock already exists. Not running it.")
  64. return
  65. log.debug("Running periodic_usage_pool_cleanup")
  66. try:
  67. while True:
  68. if not renew_func():
  69. log.error(f"Unable to renew cleanup lock. Exiting usage pool cleanup.")
  70. raise Exception("Unable to renew usage pool cleanup lock.")
  71. now = int(time.time())
  72. send_usage = False
  73. for model_id, connections in list(USAGE_POOL.items()):
  74. # Creating a list of sids to remove if they have timed out
  75. expired_sids = [
  76. sid
  77. for sid, details in connections.items()
  78. if now - details["updated_at"] > TIMEOUT_DURATION
  79. ]
  80. for sid in expired_sids:
  81. del connections[sid]
  82. if not connections:
  83. log.debug(f"Cleaning up model {model_id} from usage pool")
  84. del USAGE_POOL[model_id]
  85. else:
  86. USAGE_POOL[model_id] = connections
  87. send_usage = True
  88. if send_usage:
  89. # Emit updated usage information after cleaning
  90. await sio.emit("usage", {"models": get_models_in_use()})
  91. await asyncio.sleep(TIMEOUT_DURATION)
  92. finally:
  93. release_func()
  94. app = socketio.ASGIApp(
  95. sio,
  96. socketio_path="/ws/socket.io",
  97. )
  98. def get_models_in_use():
  99. # List models that are currently in use
  100. models_in_use = list(USAGE_POOL.keys())
  101. return models_in_use
  102. @sio.on("usage")
  103. async def usage(sid, data):
  104. model_id = data["model"]
  105. # Record the timestamp for the last update
  106. current_time = int(time.time())
  107. # Store the new usage data and task
  108. USAGE_POOL[model_id] = {
  109. **(USAGE_POOL[model_id] if model_id in USAGE_POOL else {}),
  110. sid: {"updated_at": current_time},
  111. }
  112. # Broadcast the usage data to all clients
  113. await sio.emit("usage", {"models": get_models_in_use()})
  114. @sio.event
  115. async def connect(sid, environ, auth):
  116. user = None
  117. if auth and "token" in auth:
  118. data = decode_token(auth["token"])
  119. if data is not None and "id" in data:
  120. user = Users.get_user_by_id(data["id"])
  121. if user:
  122. SESSION_POOL[sid] = user.id
  123. if user.id in USER_POOL:
  124. USER_POOL[user.id].append(sid)
  125. else:
  126. USER_POOL[user.id] = [sid]
  127. # print(f"user {user.name}({user.id}) connected with session ID {sid}")
  128. await sio.emit("user-count", {"count": len(USER_POOL.items())})
  129. await sio.emit("usage", {"models": get_models_in_use()})
  130. @sio.on("user-join")
  131. async def user_join(sid, data):
  132. # print("user-join", sid, data)
  133. auth = data["auth"] if "auth" in data else None
  134. if not auth or "token" not in auth:
  135. return
  136. data = decode_token(auth["token"])
  137. if data is None or "id" not in data:
  138. return
  139. user = Users.get_user_by_id(data["id"])
  140. if not user:
  141. return
  142. SESSION_POOL[sid] = user.id
  143. if user.id in USER_POOL:
  144. USER_POOL[user.id].append(sid)
  145. else:
  146. USER_POOL[user.id] = [sid]
  147. # print(f"user {user.name}({user.id}) connected with session ID {sid}")
  148. await sio.emit("user-count", {"count": len(USER_POOL.items())})
  149. @sio.on("user-count")
  150. async def user_count(sid):
  151. await sio.emit("user-count", {"count": len(USER_POOL.items())})
  152. @sio.on("chat")
  153. async def chat(sid, data):
  154. print("chat", sid, SESSION_POOL[sid], data)
  155. @sio.event
  156. async def disconnect(sid):
  157. if sid in SESSION_POOL:
  158. user_id = SESSION_POOL[sid]
  159. del SESSION_POOL[sid]
  160. USER_POOL[user_id] = [_sid for _sid in USER_POOL[user_id] if _sid != sid]
  161. if len(USER_POOL[user_id]) == 0:
  162. del USER_POOL[user_id]
  163. await sio.emit("user-count", {"count": len(USER_POOL)})
  164. else:
  165. pass
  166. # print(f"Unknown session ID {sid} disconnected")
  167. def get_event_emitter(request_info):
  168. async def __event_emitter__(event_data):
  169. user_id = request_info["user_id"]
  170. session_ids = USER_POOL.get(user_id, [])
  171. for session_id in session_ids:
  172. await sio.emit(
  173. "chat-events",
  174. {
  175. "chat_id": request_info["chat_id"],
  176. "message_id": request_info["message_id"],
  177. "data": event_data,
  178. },
  179. to=session_id,
  180. )
  181. return __event_emitter__
  182. def get_event_call(request_info):
  183. async def __event_call__(event_data):
  184. response = await sio.call(
  185. "chat-events",
  186. {
  187. "chat_id": request_info["chat_id"],
  188. "message_id": request_info["message_id"],
  189. "data": event_data,
  190. },
  191. to=request_info["session_id"],
  192. )
  193. return response
  194. return __event_call__