Bläddra i källkod

feat: __event_call__ support

Timothy J. Baek 10 månader sedan
förälder
incheckning
1b7ff1c5df

+ 32 - 4
backend/main.py

@@ -302,6 +302,7 @@ async def get_function_call_response(
     user,
     user,
     model,
     model,
     __event_emitter__=None,
     __event_emitter__=None,
+    __event_call__=None,
 ):
 ):
     tool = Tools.get_tool_by_id(tool_id)
     tool = Tools.get_tool_by_id(tool_id)
     tools_specs = json.dumps(tool.specs, indent=2)
     tools_specs = json.dumps(tool.specs, indent=2)
@@ -445,6 +446,13 @@ async def get_function_call_response(
                             "__event_emitter__": __event_emitter__,
                             "__event_emitter__": __event_emitter__,
                         }
                         }
 
 
+                    if "__event_call__" in sig.parameters:
+                        # Call the function with the '__event_call__' parameter included
+                        params = {
+                            **params,
+                            "__event_call__": __event_call__,
+                        }
+
                     if inspect.iscoroutinefunction(function):
                     if inspect.iscoroutinefunction(function):
                         function_result = await function(**params)
                         function_result = await function(**params)
                     else:
                     else:
@@ -468,7 +476,9 @@ async def get_function_call_response(
     return None, None, False
     return None, None, False
 
 
 
 
-async def chat_completion_functions_handler(body, model, user, __event_emitter__):
+async def chat_completion_functions_handler(
+    body, model, user, __event_emitter__, __event_call__
+):
     skip_files = None
     skip_files = None
 
 
     filter_ids = get_filter_function_ids(model)
     filter_ids = get_filter_function_ids(model)
@@ -534,12 +544,19 @@ async def chat_completion_functions_handler(body, model, user, __event_emitter__
                             **params,
                             **params,
                             "__model__": model,
                             "__model__": model,
                         }
                         }
+
                     if "__event_emitter__" in sig.parameters:
                     if "__event_emitter__" in sig.parameters:
                         params = {
                         params = {
                             **params,
                             **params,
                             "__event_emitter__": __event_emitter__,
                             "__event_emitter__": __event_emitter__,
                         }
                         }
 
 
+                    if "__event_call__" in sig.parameters:
+                        params = {
+                            **params,
+                            "__event_call__": __event_call__,
+                        }
+
                     if inspect.iscoroutinefunction(inlet):
                     if inspect.iscoroutinefunction(inlet):
                         body = await inlet(**params)
                         body = await inlet(**params)
                     else:
                     else:
@@ -556,7 +573,9 @@ async def chat_completion_functions_handler(body, model, user, __event_emitter__
     return body, {}
     return body, {}
 
 
 
 
-async def chat_completion_tools_handler(body, model, user, __event_emitter__):
+async def chat_completion_tools_handler(
+    body, model, user, __event_emitter__, __event_call__
+):
     skip_files = None
     skip_files = None
 
 
     contexts = []
     contexts = []
@@ -579,6 +598,7 @@ async def chat_completion_tools_handler(body, model, user, __event_emitter__):
                     user=user,
                     user=user,
                     model=model,
                     model=model,
                     __event_emitter__=__event_emitter__,
                     __event_emitter__=__event_emitter__,
+                    __event_call__=__event_call__,
                 )
                 )
 
 
                 print(file_handler)
                 print(file_handler)
@@ -676,6 +696,14 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                     to=session_id,
                     to=session_id,
                 )
                 )
 
 
+            async def __event_call__(data):
+                response = await sio.call(
+                    "chat-events",
+                    {"chat_id": chat_id, "message_id": message_id, "data": data},
+                    to=session_id,
+                )
+                return response
+
             # Initialize data_items to store additional data to be sent to the client
             # Initialize data_items to store additional data to be sent to the client
             data_items = []
             data_items = []
 
 
@@ -685,7 +713,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
 
 
             try:
             try:
                 body, flags = await chat_completion_functions_handler(
                 body, flags = await chat_completion_functions_handler(
-                    body, model, user, __event_emitter__
+                    body, model, user, __event_emitter__, __event_call__
                 )
                 )
             except Exception as e:
             except Exception as e:
                 return JSONResponse(
                 return JSONResponse(
@@ -695,7 +723,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
 
 
             try:
             try:
                 body, flags = await chat_completion_tools_handler(
                 body, flags = await chat_completion_tools_handler(
-                    body, model, user, __event_emitter__
+                    body, model, user, __event_emitter__, __event_call__
                 )
                 )
 
 
                 contexts.extend(flags.get("contexts", []))
                 contexts.extend(flags.get("contexts", []))

+ 27 - 3
src/lib/components/chat/Chat.svelte

@@ -61,6 +61,7 @@
 	import CallOverlay from './MessageInput/CallOverlay.svelte';
 	import CallOverlay from './MessageInput/CallOverlay.svelte';
 	import { error } from '@sveltejs/kit';
 	import { error } from '@sveltejs/kit';
 	import ChatControls from './ChatControls.svelte';
 	import ChatControls from './ChatControls.svelte';
+	import EventConfirmDialog from '../common/ConfirmDialog.svelte';
 
 
 	const i18n: Writable<i18nType> = getContext('i18n');
 	const i18n: Writable<i18nType> = getContext('i18n');
 
 
@@ -74,6 +75,11 @@
 	let processing = '';
 	let processing = '';
 	let messagesContainerElement: HTMLDivElement;
 	let messagesContainerElement: HTMLDivElement;
 
 
+	let showEventConfirmation = false;
+	let eventConfirmationTitle = '';
+	let eventConfirmationMessage = '';
+	let eventCallback = null;
+
 	let showModelSelector = true;
 	let showModelSelector = true;
 
 
 	let selectedModels = [''];
 	let selectedModels = [''];
@@ -129,7 +135,7 @@
 		})();
 		})();
 	}
 	}
 
 
-	const chatEventHandler = async (event) => {
+	const chatEventHandler = async (event, cb) => {
 		if (event.chat_id === $chatId) {
 		if (event.chat_id === $chatId) {
 			await tick();
 			await tick();
 			console.log(event);
 			console.log(event);
@@ -139,17 +145,23 @@
 			const data = event?.data?.data ?? null;
 			const data = event?.data?.data ?? null;
 
 
 			if (type === 'status') {
 			if (type === 'status') {
-				if (message.statusHistory) {
+				if (message?.statusHistory) {
 					message.statusHistory.push(data);
 					message.statusHistory.push(data);
 				} else {
 				} else {
 					message.statusHistory = [data];
 					message.statusHistory = [data];
 				}
 				}
 			} else if (type === 'citation') {
 			} else if (type === 'citation') {
-				if (message.citations) {
+				if (message?.citations) {
 					message.citations.push(data);
 					message.citations.push(data);
 				} else {
 				} else {
 					message.citations = [data];
 					message.citations = [data];
 				}
 				}
+			} else if (type === 'confirmation') {
+				eventCallback = cb;
+				showEventConfirmation = true;
+
+				eventConfirmationTitle = data.title;
+				eventConfirmationMessage = data.message;
 			} else {
 			} else {
 				console.log('Unknown message type', data);
 				console.log('Unknown message type', data);
 			}
 			}
@@ -1392,6 +1404,18 @@
 
 
 <audio id="audioElement" src="" style="display: none;" />
 <audio id="audioElement" src="" style="display: none;" />
 
 
+<EventConfirmDialog
+	bind:show={showEventConfirmation}
+	title={eventConfirmationTitle}
+	message={eventConfirmationMessage}
+	on:confirm={(e) => {
+		eventCallback(true);
+	}}
+	on:cancel={() => {
+		eventCallback(false);
+	}}
+/>
+
 {#if $showCallOverlay}
 {#if $showCallOverlay}
 	<CallOverlay
 	<CallOverlay
 		{submitPrompt}
 		{submitPrompt}

+ 15 - 4
src/lib/components/common/ConfirmDialog.svelte

@@ -7,8 +7,8 @@
 
 
 	const dispatch = createEventDispatcher();
 	const dispatch = createEventDispatcher();
 
 
-	export let title = $i18n.t('Confirm your action');
-	export let message = $i18n.t('This action cannot be undone. Do you wish to continue?');
+	export let title = '';
+	export let message = '';
 
 
 	export let cancelLabel = $i18n.t('Cancel');
 	export let cancelLabel = $i18n.t('Cancel');
 	export let confirmLabel = $i18n.t('Confirm');
 	export let confirmLabel = $i18n.t('Confirm');
@@ -58,11 +58,21 @@
 			}}
 			}}
 		>
 		>
 			<div class="px-[1.75rem] py-6">
 			<div class="px-[1.75rem] py-6">
-				<div class=" text-lg font-semibold dark:text-gray-200 mb-2.5">{title}</div>
+				<div class=" text-lg font-semibold dark:text-gray-200 mb-2.5">
+					{#if title !== ''}
+						{title}
+					{:else}
+						{$i18n.t('Confirm your action')}
+					{/if}
+				</div>
 
 
 				<slot>
 				<slot>
 					<div class=" text-sm text-gray-500">
 					<div class=" text-sm text-gray-500">
-						{message}
+						{#if message !== ''}
+							{message}
+						{:else}
+							{$i18n.t('This action cannot be undone. Do you wish to continue?')}
+						{/if}
 					</div>
 					</div>
 				</slot>
 				</slot>
 
 
@@ -71,6 +81,7 @@
 						class="bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-white font-medium w-full py-2.5 rounded-lg transition"
 						class="bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-white font-medium w-full py-2.5 rounded-lg transition"
 						on:click={() => {
 						on:click={() => {
 							show = false;
 							show = false;
+							dispatch('cancel');
 						}}
 						}}
 						type="button"
 						type="button"
 					>
 					>