Pārlūkot izejas kodu

feat: __event_emitter__

Timothy J. Baek 10 mēneši atpakaļ
vecāks
revīzija
a07051f51b
2 mainītis faili ar 54 papildinājumiem un 9 dzēšanām
  1. 46 9
      backend/main.py
  2. 8 0
      src/lib/components/chat/Chat.svelte

+ 46 - 9
backend/main.py

@@ -33,7 +33,7 @@ from starlette.middleware.sessions import SessionMiddleware
 from starlette.responses import StreamingResponse, Response, RedirectResponse
 
 
-from apps.socket.main import app as socket_app
+from apps.socket.main import sio, app as socket_app
 from apps.ollama.main import (
     app as ollama_app,
     OpenAIChatCompletionForm,
@@ -277,7 +277,14 @@ def get_filter_function_ids(model):
 
 
 async def get_function_call_response(
-    messages, files, tool_id, template, task_model_id, user, model
+    messages,
+    files,
+    tool_id,
+    template,
+    task_model_id,
+    user,
+    model,
+    __event_emitter__=None,
 ):
     tool = Tools.get_tool_by_id(tool_id)
     tools_specs = json.dumps(tool.specs, indent=2)
@@ -414,6 +421,13 @@ async def get_function_call_response(
                             "__id__": tool_id,
                         }
 
+                    if "__event_emitter__" in sig.parameters:
+                        # Call the function with the '__event_emitter__' parameter included
+                        params = {
+                            **params,
+                            "__event_emitter__": model,
+                        }
+
                     if inspect.iscoroutinefunction(function):
                         function_result = await function(**params)
                     else:
@@ -437,7 +451,7 @@ async def get_function_call_response(
     return None, None, False
 
 
-async def chat_completion_functions_handler(body, model, user):
+async def chat_completion_functions_handler(body, model, user, __event_emitter__):
     skip_files = None
 
     filter_ids = get_filter_function_ids(model)
@@ -503,6 +517,11 @@ async def chat_completion_functions_handler(body, model, user):
                             **params,
                             "__model__": model,
                         }
+                    if "__event_emitter__" in sig.parameters:
+                        params = {
+                            **params,
+                            "__event_emitter__": __event_emitter__,
+                        }
 
                     if inspect.iscoroutinefunction(inlet):
                         body = await inlet(**params)
@@ -520,7 +539,7 @@ async def chat_completion_functions_handler(body, model, user):
     return body, {}
 
 
-async def chat_completion_tools_handler(body, model, user):
+async def chat_completion_tools_handler(body, model, user, __event_emitter__):
     skip_files = None
 
     contexts = []
@@ -542,6 +561,7 @@ async def chat_completion_tools_handler(body, model, user):
                     task_model_id=task_model_id,
                     user=user,
                     model=model,
+                    __event_emitter__=__event_emitter__,
                 )
 
                 print(file_handler)
@@ -614,7 +634,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                     content={"detail": str(e)},
                 )
 
-            # Extract chat_id and message_id from the request body
+            # Extract session_id, chat_id and message_id from the request body
+            session_id = None
+            if "session_id" in body:
+                session_id = body["session_id"]
+                del body["session_id"]
             chat_id = None
             if "chat_id" in body:
                 chat_id = body["chat_id"]
@@ -624,6 +648,17 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                 message_id = body["id"]
                 del body["id"]
 
+            async def __event_emitter__(data):
+                await sio.emit(
+                    "chat-events",
+                    {
+                        "chat_id": chat_id,
+                        "message_id": message_id,
+                        "data": data,
+                    },
+                    to=session_id,
+                )
+
             # Initialize data_items to store additional data to be sent to the client
             data_items = []
 
@@ -631,10 +666,10 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             contexts = []
             citations = []
 
-            print(body)
-
             try:
-                body, flags = await chat_completion_functions_handler(body, model, user)
+                body, flags = await chat_completion_functions_handler(
+                    body, model, user, __event_emitter__
+                )
             except Exception as e:
                 return JSONResponse(
                     status_code=status.HTTP_400_BAD_REQUEST,
@@ -642,7 +677,9 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                 )
 
             try:
-                body, flags = await chat_completion_tools_handler(body, model, user)
+                body, flags = await chat_completion_tools_handler(
+                    body, model, user, __event_emitter__
+                )
 
                 contexts.extend(flags.get("contexts", []))
                 citations.extend(flags.get("citations", []))

+ 8 - 0
src/lib/components/chat/Chat.svelte

@@ -163,6 +163,10 @@
 		};
 		window.addEventListener('message', onMessageHandler);
 
+		$socket.on('chat-events', async (data) => {
+			console.log(data);
+		});
+
 		if (!$chatId) {
 			chatId.subscribe(async (value) => {
 				if (!value) {
@@ -177,6 +181,8 @@
 
 		return () => {
 			window.removeEventListener('message', onMessageHandler);
+
+			$socket.off('chat-events');
 		};
 	});
 
@@ -683,6 +689,7 @@
 			keep_alive: $settings.keepAlive ?? undefined,
 			tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
 			files: files.length > 0 ? files : undefined,
+			session_id: $socket?.id,
 			chat_id: $chatId,
 			id: responseMessageId
 		});
@@ -984,6 +991,7 @@
 					max_tokens: $settings?.params?.max_tokens ?? undefined,
 					tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
 					files: files.length > 0 ? files : undefined,
+					session_id: $socket?.id,
 					chat_id: $chatId,
 					id: responseMessageId
 				},