瀏覽代碼

refac: threads

Timothy Jaeryang Baek 4 月之前
父節點
當前提交
584e9e6da5

+ 23 - 1
backend/open_webui/models/messages.py

@@ -140,7 +140,7 @@ class MessageTable:
         with get_db() as db:
             all_messages = (
                 db.query(Message)
-                .filter_by(channel_id=channel_id)
+                .filter_by(channel_id=channel_id, parent_id=None)
                 .order_by(Message.created_at.desc())
                 .offset(skip)
                 .limit(limit)
@@ -148,6 +148,28 @@ class MessageTable:
             )
             return [MessageModel.model_validate(message) for message in all_messages]
 
+    def get_messages_by_parent_id(
+        self, channel_id: str, parent_id: str, skip: int = 0, limit: int = 50
+    ) -> list[MessageModel]:
+        with get_db() as db:
+            message = db.get(Message, parent_id)
+
+            if not message:
+                return []
+
+            all_messages = (
+                db.query(Message)
+                .filter_by(channel_id=channel_id, parent_id=parent_id)
+                .order_by(Message.created_at.desc())
+                .offset(skip)
+                .limit(limit)
+                .all()
+            )
+
+            return [MessageModel.model_validate(message)] + [
+                MessageModel.model_validate(message) for message in all_messages
+            ]
+
     def update_message_by_id(
         self, id: str, form_data: MessageForm
     ) -> Optional[MessageModel]:

+ 50 - 0
backend/open_webui/routers/channels.py

@@ -275,6 +275,56 @@ async def post_new_message(
         )
 
 
+############################
+# GetChannelThreadMessages
+############################
+
+
+@router.get(
+    "/{id}/messages/{message_id}/thread", response_model=list[MessageUserResponse]
+)
+async def get_channel_thread_messages(
+    id: str,
+    message_id: str,
+    skip: int = 0,
+    limit: int = 50,
+    user=Depends(get_verified_user),
+):
+    channel = Channels.get_channel_by_id(id)
+    if not channel:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
+        )
+
+    if user.role != "admin" and not has_access(
+        user.id, type="read", access_control=channel.access_control
+    ):
+        raise HTTPException(
+            status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
+        )
+
+    message_list = Messages.get_messages_by_parent_id(id, message_id, skip, limit)
+    users = {}
+
+    messages = []
+    for message in message_list:
+        if message.user_id not in users:
+            user = Users.get_user_by_id(message.user_id)
+            users[message.user_id] = user
+
+        messages.append(
+            MessageUserResponse(
+                **{
+                    **message.model_dump(),
+                    "reactions": Messages.get_reactions_by_message_id(message.id),
+                    "user": UserNameResponse(**users[message.user_id].model_dump()),
+                }
+            )
+        )
+
+    return messages
+
+
 ############################
 # UpdateMessageById
 ############################

+ 1 - 0
backend/open_webui/socket/main.py

@@ -237,6 +237,7 @@ async def channel_events(sid, data):
             "channel-events",
             {
                 "channel_id": data["channel_id"],
+                "message_id": data.get("message_id", None),
                 "data": event_data,
                 "user": UserNameResponse(**SESSION_POOL[sid]).model_dump(),
             },

+ 42 - 0
src/lib/apis/channels/index.ts

@@ -1,4 +1,5 @@
 import { WEBUI_API_BASE_URL } from '$lib/constants';
+import { t } from 'i18next';
 
 type ChannelForm = {
 	name: string;
@@ -207,6 +208,47 @@ export const getChannelMessages = async (
 	return res;
 };
 
+
+export const getChannelThreadMessages = async (
+	token: string = '',
+	channel_id: string,
+	message_id: string,
+	skip: number = 0,
+	limit: number = 50
+) => {
+	let error = null;
+
+	const res = await fetch(
+		`${WEBUI_API_BASE_URL}/channels/${channel_id}/messages/${message_id}/thread?skip=${skip}&limit=${limit}`,
+		{
+			method: 'GET',
+			headers: {
+				Accept: 'application/json',
+				'Content-Type': 'application/json',
+				authorization: `Bearer ${token}`
+			}
+		}
+	)
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err.detail;
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+}
+
 type MessageForm = {
 	content: string;
 	data?: object;

+ 3 - 2
src/lib/components/channel/Channel.svelte

@@ -13,7 +13,7 @@
 	import Navbar from './Navbar.svelte';
 	import Drawer from '../common/Drawer.svelte';
 	import EllipsisVertical from '../icons/EllipsisVertical.svelte';
-	import Thread from './Messages/Thread.svelte';
+	import Thread from './Thread.svelte';
 
 	export let id = '';
 
@@ -147,6 +147,7 @@
 	const onChange = async () => {
 		$socket?.emit('channel-events', {
 			channel_id: id,
+			message_id: null,
 			data: {
 				type: 'typing',
 				data: {
@@ -276,7 +277,7 @@
 				</div>
 			</PaneResizer>
 
-			<Pane defaultSize={50} minSize={20} class="h-full w-full">
+			<Pane defaultSize={50} minSize={30} class="h-full w-full">
 				<div class="h-full w-full shadow-xl">
 					<Thread
 						{threadId}

+ 3 - 3
src/lib/components/channel/MessageInput.svelte

@@ -37,7 +37,7 @@
 	export let onSubmit: Function;
 	export let onChange: Function;
 	export let scrollEnd = true;
-	export let scrollToBottom: Function;
+	export let scrollToBottom: Function = () => {};
 
 	const screenCaptureHandler = async () => {
 		try {
@@ -313,7 +313,7 @@
 		filesInputElement.value = '';
 	}}
 />
-<div class="{transparentBackground ? 'bg-transparent' : 'bg-white dark:bg-gray-900'} ">
+<div class="bg-transparent">
 	<div
 		class="{($settings?.widescreenMode ?? null)
 			? 'max-w-full'
@@ -392,7 +392,7 @@
 					}}
 				>
 					<div
-						class="flex-1 flex flex-col relative w-full rounded-3xl px-1 bg-gray-50 dark:bg-gray-400/5 dark:text-gray-100"
+						class="flex-1 flex flex-col relative w-full rounded-3xl px-1 bg-gray-600/5 dark:bg-gray-400/5 dark:text-gray-100"
 						dir={$settings?.chatDirection ?? 'LTR'}
 					>
 						{#if files.length > 0}

+ 5 - 2
src/lib/components/channel/Messages.svelte

@@ -20,9 +20,11 @@
 
 	const i18n = getContext('i18n');
 
+	export let id = null;
 	export let channel = null;
 	export let messages = [];
 	export let top = false;
+	export let thread = false;
 
 	export let onLoad: Function = () => {};
 	export let onThread: Function = () => {};
@@ -60,7 +62,7 @@
 					<div class=" ">Loading...</div>
 				</div>
 			</Loader>
-		{:else}
+		{:else if !thread}
 			<div
 				class="px-5
 			
@@ -89,9 +91,10 @@
 			</div>
 		{/if}
 
-		{#each messageList as message, messageIdx (message.id)}
+		{#each messageList as message, messageIdx (id ? `${id}-${message.id}` : message.id)}
 			<Message
 				{message}
+				{thread}
 				showUserProfile={messageIdx === 0 ||
 					messageList.at(messageIdx - 1)?.user_id !== message.user_id}
 				onDelete={() => {

+ 2 - 0
src/lib/components/channel/Messages/Message/ReactionPicker.svelte

@@ -56,6 +56,8 @@
 	};
 </script>
 
+<!-- TODO: Rendering Optimisation, This works but it's slow af -->
+
 <DropdownMenu.Root
 	bind:open={show}
 	closeFocus={false}

+ 0 - 28
src/lib/components/channel/Messages/Thread.svelte

@@ -1,28 +0,0 @@
-<script lang="ts">
-	import XMark from '$lib/components/icons/XMark.svelte';
-
-	export let threadId = null;
-	export let channel = null;
-
-	export let onClose = () => {};
-</script>
-
-<div class="flex flex-col w-full h-full bg-gray-50 dark:bg-gray-850 px-3.5 py-3">
-	<div class="flex items-center justify-between">
-		<div class=" font-medium text-lg">Thread</div>
-
-		<div>
-			<button
-				class="text-gray-500 hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-300 p-2"
-				on:click={() => {
-					onClose();
-				}}
-			>
-				<XMark />
-			</button>
-		</div>
-	</div>
-	{threadId}
-
-	{channel}
-</div>

+ 108 - 0
src/lib/components/channel/Thread.svelte

@@ -0,0 +1,108 @@
+<script lang="ts">
+	import { goto } from '$app/navigation';
+
+	import { socket } from '$lib/stores';
+
+	import { getChannelThreadMessages } from '$lib/apis/channels';
+
+	import XMark from '$lib/components/icons/XMark.svelte';
+	import MessageInput from './MessageInput.svelte';
+	import Messages from './Messages.svelte';
+	import { onMount } from 'svelte';
+
+	export let threadId = null;
+	export let channel = null;
+
+	export let onClose = () => {};
+
+	let messages = null;
+	let top = false;
+
+	let typingUsers = [];
+	let typingUsersTimeout = {};
+
+	$: if (threadId) {
+		initHandler();
+	}
+
+	const initHandler = async () => {
+		messages = null;
+		top = false;
+
+		typingUsers = [];
+		typingUsersTimeout = {};
+
+		if (channel) {
+			messages = await getChannelThreadMessages(localStorage.token, channel.id, threadId);
+
+			if (messages.length < 50) {
+				top = true;
+			}
+		} else {
+			goto('/');
+		}
+	};
+
+	const submitHandler = async (message) => {
+		// if (message) {
+		// 	await sendMessage(localStorage.token, channel.id, message, threadId);
+		// }
+	};
+
+	const onChange = async () => {
+		$socket?.emit('channel-events', {
+			channel_id: channel.id,
+			message_id: threadId,
+			data: {
+				type: 'typing',
+				data: {
+					typing: true
+				}
+			}
+		});
+	};
+</script>
+
+{#if channel}
+	<div class="flex flex-col w-full h-full bg-gray-50 dark:bg-gray-850">
+		<div class="flex items-center justify-between mb-2 px-3.5 py-3">
+			<div class=" font-medium text-lg">Thread</div>
+
+			<div>
+				<button
+					class="text-gray-500 hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-300 p-2"
+					on:click={() => {
+						onClose();
+					}}
+				>
+					<XMark />
+				</button>
+			</div>
+		</div>
+
+		<Messages
+			id={threadId}
+			{channel}
+			{messages}
+			{top}
+			thread={true}
+			onLoad={async () => {
+				const newMessages = await getChannelThreadMessages(
+					localStorage.token,
+					channel.id,
+					threadId,
+					messages.length
+				);
+
+				messages = [...messages, ...newMessages];
+
+				if (newMessages.length < 50) {
+					top = true;
+					return;
+				}
+			}}
+		/>
+
+		<MessageInput {typingUsers} {onChange} onSubmit={submitHandler} />
+	</div>
+{/if}

+ 1 - 1
src/lib/components/chat/MessageInput.svelte

@@ -544,7 +544,7 @@
 							}}
 						>
 							<div
-								class="flex-1 flex flex-col relative w-full rounded-3xl px-1 bg-gray-50 dark:bg-gray-400/5 dark:text-gray-100"
+								class="flex-1 flex flex-col relative w-full rounded-3xl px-1 bg-gray-600/5 dark:bg-gray-400/5 dark:text-gray-100"
 								dir={$settings?.chatDirection ?? 'LTR'}
 							>
 								{#if files.length > 0}