main.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. import asyncio
  2. import socketio
  3. import logging
  4. import sys
  5. import time
  6. from open_webui.models.users import Users
  7. from open_webui.models.channels import Channels
  8. from open_webui.env import (
  9. ENABLE_WEBSOCKET_SUPPORT,
  10. WEBSOCKET_MANAGER,
  11. WEBSOCKET_REDIS_URL,
  12. )
  13. from open_webui.utils.auth import decode_token
  14. from open_webui.socket.utils import RedisDict, RedisLock
  15. from open_webui.env import (
  16. GLOBAL_LOG_LEVEL,
  17. SRC_LOG_LEVELS,
  18. )
  19. logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
  20. log = logging.getLogger(__name__)
  21. log.setLevel(SRC_LOG_LEVELS["SOCKET"])
  22. if WEBSOCKET_MANAGER == "redis":
  23. mgr = socketio.AsyncRedisManager(WEBSOCKET_REDIS_URL)
  24. sio = socketio.AsyncServer(
  25. cors_allowed_origins=[],
  26. async_mode="asgi",
  27. transports=(["websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]),
  28. allow_upgrades=ENABLE_WEBSOCKET_SUPPORT,
  29. always_connect=True,
  30. client_manager=mgr,
  31. )
  32. else:
  33. sio = socketio.AsyncServer(
  34. cors_allowed_origins=[],
  35. async_mode="asgi",
  36. transports=(["websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]),
  37. allow_upgrades=ENABLE_WEBSOCKET_SUPPORT,
  38. always_connect=True,
  39. )
  40. # Timeout duration in seconds
  41. TIMEOUT_DURATION = 3
  42. # Dictionary to maintain the user pool
  43. if WEBSOCKET_MANAGER == "redis":
  44. log.debug("Using Redis to manage websockets.")
  45. SESSION_POOL = RedisDict("open-webui:session_pool", redis_url=WEBSOCKET_REDIS_URL)
  46. USER_POOL = RedisDict("open-webui:user_pool", redis_url=WEBSOCKET_REDIS_URL)
  47. USAGE_POOL = RedisDict("open-webui:usage_pool", redis_url=WEBSOCKET_REDIS_URL)
  48. clean_up_lock = RedisLock(
  49. redis_url=WEBSOCKET_REDIS_URL,
  50. lock_name="usage_cleanup_lock",
  51. timeout_secs=TIMEOUT_DURATION * 2,
  52. )
  53. aquire_func = clean_up_lock.aquire_lock
  54. renew_func = clean_up_lock.renew_lock
  55. release_func = clean_up_lock.release_lock
  56. else:
  57. SESSION_POOL = {}
  58. USER_POOL = {}
  59. USAGE_POOL = {}
  60. aquire_func = release_func = renew_func = lambda: True
  61. async def periodic_usage_pool_cleanup():
  62. if not aquire_func():
  63. log.debug("Usage pool cleanup lock already exists. Not running it.")
  64. return
  65. log.debug("Running periodic_usage_pool_cleanup")
  66. try:
  67. while True:
  68. if not renew_func():
  69. log.error(f"Unable to renew cleanup lock. Exiting usage pool cleanup.")
  70. raise Exception("Unable to renew usage pool cleanup lock.")
  71. now = int(time.time())
  72. send_usage = False
  73. for model_id, connections in list(USAGE_POOL.items()):
  74. # Creating a list of sids to remove if they have timed out
  75. expired_sids = [
  76. sid
  77. for sid, details in connections.items()
  78. if now - details["updated_at"] > TIMEOUT_DURATION
  79. ]
  80. for sid in expired_sids:
  81. del connections[sid]
  82. if not connections:
  83. log.debug(f"Cleaning up model {model_id} from usage pool")
  84. del USAGE_POOL[model_id]
  85. else:
  86. USAGE_POOL[model_id] = connections
  87. send_usage = True
  88. if send_usage:
  89. # Emit updated usage information after cleaning
  90. await sio.emit("usage", {"models": get_models_in_use()})
  91. await asyncio.sleep(TIMEOUT_DURATION)
  92. finally:
  93. release_func()
  94. app = socketio.ASGIApp(
  95. sio,
  96. socketio_path="/ws/socket.io",
  97. )
  98. def get_models_in_use():
  99. # List models that are currently in use
  100. models_in_use = list(USAGE_POOL.keys())
  101. return models_in_use
  102. @sio.on("usage")
  103. async def usage(sid, data):
  104. model_id = data["model"]
  105. # Record the timestamp for the last update
  106. current_time = int(time.time())
  107. # Store the new usage data and task
  108. USAGE_POOL[model_id] = {
  109. **(USAGE_POOL[model_id] if model_id in USAGE_POOL else {}),
  110. sid: {"updated_at": current_time},
  111. }
  112. # Broadcast the usage data to all clients
  113. await sio.emit("usage", {"models": get_models_in_use()})
  114. @sio.event
  115. async def connect(sid, environ, auth):
  116. user = None
  117. if auth and "token" in auth:
  118. data = decode_token(auth["token"])
  119. if data is not None and "id" in data:
  120. user = Users.get_user_by_id(data["id"])
  121. if user:
  122. SESSION_POOL[sid] = user.id
  123. if user.id in USER_POOL:
  124. USER_POOL[user.id] = USER_POOL[user.id] + [sid]
  125. else:
  126. USER_POOL[user.id] = [sid]
  127. # print(f"user {user.name}({user.id}) connected with session ID {sid}")
  128. await sio.emit("user-count", {"count": len(USER_POOL.items())})
  129. await sio.emit("usage", {"models": get_models_in_use()})
  130. @sio.on("user-join")
  131. async def user_join(sid, data):
  132. auth = data["auth"] if "auth" in data else None
  133. if not auth or "token" not in auth:
  134. return
  135. data = decode_token(auth["token"])
  136. if data is None or "id" not in data:
  137. return
  138. user = Users.get_user_by_id(data["id"])
  139. if not user:
  140. return
  141. SESSION_POOL[sid] = user.id
  142. if user.id in USER_POOL:
  143. USER_POOL[user.id] = USER_POOL[user.id] + [sid]
  144. else:
  145. USER_POOL[user.id] = [sid]
  146. # Join all the channels
  147. channels = Channels.get_channels_by_user_id(user.id)
  148. log.debug(f"{channels=}")
  149. for channel in channels:
  150. await sio.enter_room(sid, f"channel:{channel.id}")
  151. # print(f"user {user.name}({user.id}) connected with session ID {sid}")
  152. await sio.emit("user-count", {"count": len(USER_POOL.items())})
  153. @sio.on("user-count")
  154. async def user_count(sid):
  155. await sio.emit("user-count", {"count": len(USER_POOL.items())})
  156. @sio.on("chat")
  157. async def chat(sid, data):
  158. print("chat", sid, SESSION_POOL[sid], data)
  159. @sio.event
  160. async def disconnect(sid):
  161. if sid in SESSION_POOL:
  162. user_id = SESSION_POOL[sid]
  163. del SESSION_POOL[sid]
  164. USER_POOL[user_id] = [_sid for _sid in USER_POOL[user_id] if _sid != sid]
  165. if len(USER_POOL[user_id]) == 0:
  166. del USER_POOL[user_id]
  167. await sio.emit("user-count", {"count": len(USER_POOL)})
  168. else:
  169. pass
  170. # print(f"Unknown session ID {sid} disconnected")
  171. def get_event_emitter(request_info):
  172. async def __event_emitter__(event_data):
  173. user_id = request_info["user_id"]
  174. session_ids = list(
  175. set(USER_POOL.get(user_id, []) + [request_info["session_id"]])
  176. )
  177. for session_id in session_ids:
  178. await sio.emit(
  179. "chat-events",
  180. {
  181. "chat_id": request_info["chat_id"],
  182. "message_id": request_info["message_id"],
  183. "data": event_data,
  184. },
  185. to=session_id,
  186. )
  187. return __event_emitter__
  188. def get_event_call(request_info):
  189. async def __event_call__(event_data):
  190. response = await sio.call(
  191. "chat-events",
  192. {
  193. "chat_id": request_info["chat_id"],
  194. "message_id": request_info["message_id"],
  195. "data": event_data,
  196. },
  197. to=request_info["session_id"],
  198. )
  199. return response
  200. return __event_call__
  201. def get_user_id_from_session_pool(sid):
  202. return SESSION_POOL.get(sid)