Przeglądaj źródła

enh: Actions `__webui__` flag support

Timothy J. Baek 8 miesięcy temu
rodzic
commit
8c2ba7f7ea

+ 6 - 0
backend/main.py

@@ -1026,6 +1026,10 @@ async def get_all_models():
                 function_module, _, _ = load_function_module_by_id(action_id)
                 webui_app.state.FUNCTIONS[action_id] = function_module
 
+            __webui__ = False
+            if hasattr(function_module, "__webui__"):
+                __webui__ = function_module.__webui__
+
             if hasattr(function_module, "actions"):
                 actions = function_module.actions
                 model["actions"].extend(
@@ -1039,6 +1043,7 @@ async def get_all_models():
                             "icon_url": _action.get(
                                 "icon_url", action.meta.manifest.get("icon_url", None)
                             ),
+                            **({"__webui__": __webui__} if __webui__ else {}),
                         }
                         for _action in actions
                     ]
@@ -1050,6 +1055,7 @@ async def get_all_models():
                         "name": action.name,
                         "description": action.meta.description,
                         "icon_url": action.meta.manifest.get("icon_url", None),
+                        **({"__webui__": __webui__} if __webui__ else {}),
                     }
                 )
 

+ 2 - 1
src/lib/components/chat/Chat.svelte

@@ -430,7 +430,7 @@
 		}
 	};
 
-	const chatActionHandler = async (chatId, actionId, modelId, responseMessageId) => {
+	const chatActionHandler = async (chatId, actionId, modelId, responseMessageId, event = null) => {
 		const res = await chatAction(localStorage.token, actionId, {
 			model: modelId,
 			messages: messages.map((m) => ({
@@ -440,6 +440,7 @@
 				info: m.info ? m.info : undefined,
 				timestamp: m.timestamp
 			})),
+			...(event ? { event: event } : {}),
 			chat_id: chatId,
 			session_id: $socket?.id,
 			id: responseMessageId

+ 7 - 1
src/lib/components/chat/Messages.svelte

@@ -342,7 +342,13 @@
 										{continueGeneration}
 										{regenerateResponse}
 										on:action={async (e) => {
-											await chatActionHandler(chatId, e.detail, message.model, message.id);
+											console.log('action', e);
+											if (typeof e.detail === 'string') {
+												await chatActionHandler(chatId, e.detail, message.model, message.id);
+											} else {
+												const { id, event } = e.detail;
+												await chatActionHandler(chatId, id, message.model, message.id, event);
+											}
 										}}
 										on:save={async (e) => {
 											console.log('save', e);

+ 4 - 1
src/lib/components/chat/Messages/RateComment.svelte

@@ -57,7 +57,10 @@
 		message.annotation.reason = selectedReason;
 		message.annotation.comment = comment;
 
-		dispatch('submit');
+		dispatch('submit', {
+			reason: selectedReason,
+			comment: comment
+		});
 
 		toast.success($i18n.t('Thanks for your feedback!'));
 		show = false;

+ 80 - 7
src/lib/components/chat/Messages/ResponseMessage.svelte

@@ -821,10 +821,24 @@
 												?.annotation?.rating ?? null) === 1
 												? 'bg-gray-100 dark:bg-gray-800'
 												: ''} dark:hover:text-white hover:text-black transition"
-											on:click={() => {
-												rateMessage(message.id, 1);
-												showRateComment = true;
+											on:click={async () => {
+												await rateMessage(message.id, 1);
+
+												(model?.actions ?? [])
+													.filter((action) => action?.__webui__ ?? false)
+													.forEach((action) => {
+														dispatch('action', {
+															id: action.id,
+															event: {
+																id: 'good-response',
+																data: {
+																	messageId: message.id
+																}
+															}
+														});
+													});
 
+												showRateComment = true;
 												window.setTimeout(() => {
 													document
 														.getElementById(`message-feedback-${message.id}`)
@@ -856,8 +870,23 @@
 												?.annotation?.rating ?? null) === -1
 												? 'bg-gray-100 dark:bg-gray-800'
 												: ''} dark:hover:text-white hover:text-black transition"
-											on:click={() => {
-												rateMessage(message.id, -1);
+											on:click={async () => {
+												await rateMessage(message.id, -1);
+
+												(model?.actions ?? [])
+													.filter((action) => action?.__webui__ ?? false)
+													.forEach((action) => {
+														dispatch('action', {
+															id: action.id,
+															event: {
+																id: 'bad-response',
+																data: {
+																	messageId: message.id
+																}
+															}
+														});
+													});
+
 												showRateComment = true;
 												window.setTimeout(() => {
 													document
@@ -891,6 +920,20 @@
 													: 'invisible group-hover:visible'} p-1.5 hover:bg-black/5 dark:hover:bg-white/5 rounded-lg dark:hover:text-white hover:text-black transition regenerate-response-button"
 												on:click={() => {
 													continueGeneration();
+
+													(model?.actions ?? [])
+														.filter((action) => action?.__webui__ ?? false)
+														.forEach((action) => {
+															dispatch('action', {
+																id: action.id,
+																event: {
+																	id: 'continue-response',
+																	data: {
+																		messageId: message.id
+																	}
+																}
+															});
+														});
 												}}
 											>
 												<svg
@@ -924,6 +967,20 @@
 												on:click={() => {
 													showRateComment = false;
 													regenerateResponse(message);
+
+													(model?.actions ?? [])
+														.filter((action) => action?.__webui__ ?? false)
+														.forEach((action) => {
+															dispatch('action', {
+																id: action.id,
+																event: {
+																	id: 'regenerate-response',
+																	data: {
+																		messageId: message.id
+																	}
+																}
+															});
+														});
 												}}
 											>
 												<svg
@@ -943,7 +1000,7 @@
 											</button>
 										</Tooltip>
 
-										{#each model?.actions ?? [] as action}
+										{#each (model?.actions ?? []).filter((action) => !(action?.__webui__ ?? false)) as action}
 											<Tooltip content={action.name} placement="bottom">
 												<button
 													type="button"
@@ -980,8 +1037,24 @@
 							messageId={message.id}
 							bind:show={showRateComment}
 							bind:message
-							on:submit={() => {
+							on:submit={(e) => {
 								updateChatMessages();
+
+								(model?.actions ?? [])
+									.filter((action) => action?.__webui__ ?? false)
+									.forEach((action) => {
+										dispatch('action', {
+											id: action.id,
+											event: {
+												id: 'rate-comment',
+												data: {
+													messageId: message.id,
+													comment: e.detail.comment,
+													reason: e.detail.reason
+												}
+											}
+										});
+									});
 							}}
 						/>
 					{/if}