Timothy J. Baek 9 months ago
parent
commit
0ef27bfc5e
2 changed files with 52 additions and 36 deletions
  1. 31 0
      backend/apps/socket/main.py
  2. 21 36
      backend/main.py

+ 31 - 0
backend/apps/socket/main.py

@@ -137,3 +137,34 @@ async def disconnect(sid):
         await sio.emit("user-count", {"count": len(USER_POOL)})
     else:
         print(f"Unknown session ID {sid} disconnected")
+
+
+async def get_event_emitter(request_info):
+    async def __event_emitter__(event_data):
+        await sio.emit(
+            "chat-events",
+            {
+                "chat_id": request_info["chat_id"],
+                "message_id": request_info["id"],
+                "data": event_data,
+            },
+            to=request_info["session_id"],
+        )
+
+    return __event_emitter__
+
+
+async def get_event_call(request_info):
+    async def __event_call__(event_data):
+        response = await sio.call(
+            "chat-events",
+            {
+                "chat_id": request_info["chat_id"],
+                "message_id": request_info["id"],
+                "data": event_data,
+            },
+            to=request_info["session_id"],
+        )
+        return response
+
+    return __event_call__

+ 21 - 36
backend/main.py

@@ -29,7 +29,7 @@ from starlette.middleware.sessions import SessionMiddleware
 from starlette.responses import StreamingResponse, Response, RedirectResponse
 
 
-from apps.socket.main import sio, app as socket_app
+from apps.socket.main import sio, app as socket_app, get_event_emitter, get_event_call
 from apps.ollama.main import (
     app as ollama_app,
     get_all_models as get_ollama_models,
@@ -632,24 +632,12 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                 message_id = body["id"]
                 del body["id"]
 
-            async def __event_emitter__(data):
-                await sio.emit(
-                    "chat-events",
-                    {
-                        "chat_id": chat_id,
-                        "message_id": message_id,
-                        "data": data,
-                    },
-                    to=session_id,
-                )
-
-            async def __event_call__(data):
-                response = await sio.call(
-                    "chat-events",
-                    {"chat_id": chat_id, "message_id": message_id, "data": data},
-                    to=session_id,
-                )
-                return response
+            __event_emitter__ = await get_event_emitter(
+                {"chat_id": chat_id, "message_id": message_id, "session_id": session_id}
+            )
+            __event_call__ = await get_event_call(
+                {"chat_id": chat_id, "message_id": message_id, "session_id": session_id}
+            )
 
             # Initialize data_items to store additional data to be sent to the client
             data_items = []
@@ -1107,24 +1095,21 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
             else:
                 pass
 
-    async def __event_emitter__(event_data):
-        await sio.emit(
-            "chat-events",
-            {
-                "chat_id": data["chat_id"],
-                "message_id": data["id"],
-                "data": event_data,
-            },
-            to=data["session_id"],
-        )
+    __event_emitter__ = await get_event_emitter(
+        {
+            "chat_id": data["chat_id"],
+            "message_id": data["id"],
+            "session_id": data["session_id"],
+        }
+    )
 
-    async def __event_call__(event_data):
-        response = await sio.call(
-            "chat-events",
-            {"chat_id": data["chat_id"], "message_id": data["id"], "data": event_data},
-            to=data["session_id"],
-        )
-        return response
+    __event_call__ = await get_event_call(
+        {
+            "chat_id": data["chat_id"],
+            "message_id": data["id"],
+            "session_id": data["session_id"],
+        }
+    )
 
     def get_priority(function_id):
         function = Functions.get_function_by_id(function_id)