Timothy Jaeryang Baek 4 mesi fa
parent
commit
9c337552e6

+ 25 - 2
backend/open_webui/models/messages.py

@@ -89,6 +89,8 @@ class Reactions(BaseModel):
 
 
 class MessageResponse(MessageModel):
+    latest_reply_at: Optional[int]
+    reply_count: int
     reactions: list[Reactions]
 
 
@@ -127,13 +129,34 @@ class MessageTable:
                 return None
 
             reactions = self.get_reactions_by_message_id(id)
+            replies = self.get_replies_by_message_id(id)
+
             return MessageResponse(
                 **{
                     **MessageModel.model_validate(message).model_dump(),
+                    "latest_reply_at": replies[0].created_at if replies else None,
+                    "reply_count": len(replies),
                     "reactions": reactions,
                 }
             )
 
+    def get_replies_by_message_id(self, id: str) -> list[MessageModel]:
+        with get_db() as db:
+            all_messages = (
+                db.query(Message)
+                .filter_by(parent_id=id)
+                .order_by(Message.created_at.desc())
+                .all()
+            )
+            return [MessageModel.model_validate(message) for message in all_messages]
+
+    def get_reply_user_ids_by_message_id(self, id: str) -> list[str]:
+        with get_db() as db:
+            return [
+                message.user_id
+                for message in db.query(Message).filter_by(parent_id=id).all()
+            ]
+
     def get_messages_by_channel_id(
         self, channel_id: str, skip: int = 0, limit: int = 50
     ) -> list[MessageModel]:
@@ -166,9 +189,9 @@ class MessageTable:
                 .all()
             )
 
-            return [MessageModel.model_validate(message)] + [
+            return [
                 MessageModel.model_validate(message) for message in all_messages
-            ]
+            ] + [MessageModel.model_validate(message)]
 
     def update_message_by_id(
         self, id: str, form_data: MessageForm

+ 101 - 13
backend/open_webui/routers/channels.py

@@ -169,10 +169,15 @@ async def get_channel_messages(
             user = Users.get_user_by_id(message.user_id)
             users[message.user_id] = user
 
+        replies = Messages.get_replies_by_message_id(message.id)
+        latest_reply_at = replies[0].created_at if replies else None
+
         messages.append(
             MessageUserResponse(
                 **{
                     **message.model_dump(),
+                    "reply_count": len(replies),
+                    "latest_reply_at": latest_reply_at,
                     "reactions": Messages.get_reactions_by_message_id(message.id),
                     "user": UserNameResponse(**users[message.user_id].model_dump()),
                 }
@@ -242,10 +247,17 @@ async def post_new_message(
                 "message_id": message.id,
                 "data": {
                     "type": "message",
-                    "data": {
-                        **message.model_dump(),
-                        "user": UserNameResponse(**user.model_dump()).model_dump(),
-                    },
+                    "data": MessageUserResponse(
+                        **{
+                            **message.model_dump(),
+                            "reply_count": 0,
+                            "latest_reply_at": None,
+                            "reactions": Messages.get_reactions_by_message_id(
+                                message.id
+                            ),
+                            "user": UserNameResponse(**user.model_dump()),
+                        }
+                    ).model_dump(),
                 },
                 "user": UserNameResponse(**user.model_dump()).model_dump(),
                 "channel": channel.model_dump(),
@@ -257,6 +269,35 @@ async def post_new_message(
                 to=f"channel:{channel.id}",
             )
 
+            if message.parent_id:
+                # If this message is a reply, emit to the parent message as well
+                parent_message = Messages.get_message_by_id(message.parent_id)
+
+                if parent_message:
+                    await sio.emit(
+                        "channel-events",
+                        {
+                            "channel_id": channel.id,
+                            "message_id": parent_message.id,
+                            "data": {
+                                "type": "message:reply",
+                                "data": MessageUserResponse(
+                                    **{
+                                        **parent_message.model_dump(),
+                                        "user": UserNameResponse(
+                                            **Users.get_user_by_id(
+                                                parent_message.user_id
+                                            ).model_dump()
+                                        ),
+                                    }
+                                ).model_dump(),
+                            },
+                            "user": UserNameResponse(**user.model_dump()).model_dump(),
+                            "channel": channel.model_dump(),
+                        },
+                        to=f"channel:{channel.id}",
+                    )
+
             active_user_ids = get_user_ids_from_room(f"channel:{channel.id}")
 
             background_tasks.add_task(
@@ -275,6 +316,49 @@ async def post_new_message(
         )
 
 
+############################
+# GetChannelMessage
+############################
+
+
+@router.get("/{id}/messages/{message_id}", response_model=Optional[MessageUserResponse])
+async def get_channel_message(
+    id: str, message_id: str, user=Depends(get_verified_user)
+):
+    channel = Channels.get_channel_by_id(id)
+    if not channel:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
+        )
+
+    if user.role != "admin" and not has_access(
+        user.id, type="read", access_control=channel.access_control
+    ):
+        raise HTTPException(
+            status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
+        )
+
+    message = Messages.get_message_by_id(message_id)
+    if not message:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
+        )
+
+    if message.channel_id != id:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
+        )
+
+    return MessageUserResponse(
+        **{
+            **message.model_dump(),
+            "user": UserNameResponse(
+                **Users.get_user_by_id(message.user_id).model_dump()
+            ),
+        }
+    )
+
+
 ############################
 # GetChannelThreadMessages
 ############################
@@ -316,6 +400,8 @@ async def get_channel_thread_messages(
             MessageUserResponse(
                 **{
                     **message.model_dump(),
+                    "reply_count": 0,
+                    "latest_reply_at": None,
                     "reactions": Messages.get_reactions_by_message_id(message.id),
                     "user": UserNameResponse(**users[message.user_id].model_dump()),
                 }
@@ -372,10 +458,14 @@ async def update_message_by_id(
                     "message_id": message.id,
                     "data": {
                         "type": "message:update",
-                        "data": {
-                            **message.model_dump(),
-                            "user": UserNameResponse(**user.model_dump()).model_dump(),
-                        },
+                        "data": MessageUserResponse(
+                            **{
+                                **message.model_dump(),
+                                "user": UserNameResponse(
+                                    **user.model_dump()
+                                ).model_dump(),
+                            }
+                        ).model_dump(),
                     },
                     "user": UserNameResponse(**user.model_dump()).model_dump(),
                     "channel": channel.model_dump(),
@@ -430,18 +520,17 @@ async def add_reaction_to_message(
 
     try:
         Messages.add_reaction_to_message(message_id, user.id, form_data.name)
-
         message = Messages.get_message_by_id(message_id)
+
         await sio.emit(
             "channel-events",
             {
                 "channel_id": channel.id,
                 "message_id": message.id,
                 "data": {
-                    "type": "message:reaction",
+                    "type": "message:reaction:add",
                     "data": {
                         **message.model_dump(),
-                        "user": UserNameResponse(**user.model_dump()).model_dump(),
                         "name": form_data.name,
                     },
                 },
@@ -505,10 +594,9 @@ async def remove_reaction_by_id_and_user_id_and_name(
                 "channel_id": channel.id,
                 "message_id": message.id,
                 "data": {
-                    "type": "message:reaction",
+                    "type": "message:reaction:remove",
                     "data": {
                         **message.model_dump(),
-                        "user": UserNameResponse(**user.model_dump()).model_dump(),
                         "name": form_data.name,
                     },
                 },

+ 1 - 0
src/lib/apis/channels/index.ts

@@ -250,6 +250,7 @@ export const getChannelThreadMessages = async (
 }
 
 type MessageForm = {
+	parent_id?: string;
 	content: string;
 	data?: object;
 	meta?: object;

+ 10 - 8
src/lib/components/channel/Channel.svelte

@@ -74,15 +74,17 @@
 			const data = event?.data?.data ?? null;
 
 			if (type === 'message') {
-				messages = [data, ...messages];
+				if ((data?.parent_id ?? null) === null) {
+					messages = [data, ...messages];
 
-				if (typingUsers.find((user) => user.id === event.user.id)) {
-					typingUsers = typingUsers.filter((user) => user.id !== event.user.id);
-				}
+					if (typingUsers.find((user) => user.id === event.user.id)) {
+						typingUsers = typingUsers.filter((user) => user.id !== event.user.id);
+					}
 
-				await tick();
-				if (scrollEnd) {
-					messagesContainerElement.scrollTop = messagesContainerElement.scrollHeight;
+					await tick();
+					if (scrollEnd) {
+						messagesContainerElement.scrollTop = messagesContainerElement.scrollHeight;
+					}
 				}
 			} else if (type === 'message:update') {
 				const idx = messages.findIndex((message) => message.id === data.id);
@@ -92,7 +94,7 @@
 				}
 			} else if (type === 'message:delete') {
 				messages = messages.filter((message) => message.id !== data.id);
-			} else if (type === 'message:reaction') {
+			} else if (type.includes('message:reaction')) {
 				const idx = messages.findIndex((message) => message.id === data.id);
 				if (idx !== -1) {
 					messages[idx] = data;

+ 24 - 0
src/lib/components/channel/Messages/Message.svelte

@@ -29,6 +29,7 @@
 	import ChatBubbleOvalEllipsis from '$lib/components/icons/ChatBubbleOvalEllipsis.svelte';
 	import FaceSmile from '$lib/components/icons/FaceSmile.svelte';
 	import ReactionPicker from './Message/ReactionPicker.svelte';
+	import ChevronRight from '$lib/components/icons/ChevronRight.svelte';
 
 	export let message;
 	export let showUserProfile = true;
@@ -324,6 +325,29 @@
 							</div>
 						</div>
 					{/if}
+
+					{#if message.reply_count > 0}
+						<div class="flex items-center gap-1.5 -mt-0.5 mb-1.5">
+							<button
+								class="flex items-center text-xs py-1 text-gray-500 dark:text-gray-400 hover:text-gray-700 dark:hover:text-gray-300 transition"
+								on:click={() => {
+									onThread(message.id);
+								}}
+							>
+								<span class="font-medium mr-1">
+									{$i18n.t('{{COUNT}} Replies', { COUNT: message.reply_count })}</span
+								><span>
+									{' - '}{$i18n.t('Last reply')}
+									{dayjs.unix(message.latest_reply_at / 1000000000).fromNow()}</span
+								>
+
+								<span class="ml-1">
+									<ChevronRight className="size-2.5" strokeWidth="3" />
+								</span>
+								<!-- {$i18n.t('View Replies')} -->
+							</button>
+						</div>
+					{/if}
 				{/if}
 			</div>
 		</div>

+ 48 - 46
src/lib/components/channel/Messages/Message/ReactionPicker.svelte

@@ -75,7 +75,7 @@
 
 	<slot name="content">
 		<DropdownMenu.Content
-			class="max-w-full  w-80  bg-gray-50 dark:bg-gray-850 rounded-lg z-50 shadow-lg text-white"
+			class="max-w-full  w-80  bg-gray-50 dark:bg-gray-850 rounded-lg z-50 shadow-lg dark:text-white"
 			sideOffset={8}
 			{side}
 			{align}
@@ -90,54 +90,56 @@
 				/>
 			</div>
 			<div class=" w-full flex justify-start h-96 overflow-y-auto px-3 pb-3 text-sm">
-				<div>
-					{#if Object.keys(emojis).length === 0}
-						<div class="text-center text-xs text-gray-500 dark:text-gray-400">No results</div>
-					{:else}
-						{#each Object.keys(emojiGroups) as group}
-							{@const groupEmojis = emojiGroups[group].filter((emoji) => emojis[emoji])}
-							{#if groupEmojis.length > 0}
-								<div class="flex flex-col">
-									<div class="text-xs font-medium mb-2 text-gray-500 dark:text-gray-400">
-										{group}
-									</div>
-
-									<div class="flex mb-2 flex-wrap gap-1">
-										{#each groupEmojis as emoji (emoji)}
-											<Tooltip
-												content={(typeof emojiShortCodes[emoji] === 'string'
-													? [emojiShortCodes[emoji]]
-													: emojiShortCodes[emoji]
-												)
-													.map((code) => `:${code}:`)
-													.join(', ')}
-												placement="top"
-											>
-												<button
-													class="p-1.5 rounded-lg cursor-pointer hover:bg-gray-200 dark:hover:bg-gray-700 transition"
-													on:click={() => {
-														typeof emojiShortCodes[emoji] === 'string'
-															? onSubmit(emojiShortCodes[emoji])
-															: onSubmit(emojiShortCodes[emoji][0]);
+				{#if show}
+					<div>
+						{#if Object.keys(emojis).length === 0}
+							<div class="text-center text-xs text-gray-500 dark:text-gray-400">No results</div>
+						{:else}
+							{#each Object.keys(emojiGroups) as group}
+								{@const groupEmojis = emojiGroups[group].filter((emoji) => emojis[emoji])}
+								{#if groupEmojis.length > 0}
+									<div class="flex flex-col">
+										<div class="text-xs font-medium mb-2 text-gray-500 dark:text-gray-400">
+											{group}
+										</div>
 
-														show = false;
-													}}
+										<div class="flex mb-2 flex-wrap gap-1">
+											{#each groupEmojis as emoji (emoji)}
+												<Tooltip
+													content={(typeof emojiShortCodes[emoji] === 'string'
+														? [emojiShortCodes[emoji]]
+														: emojiShortCodes[emoji]
+													)
+														.map((code) => `:${code}:`)
+														.join(', ')}
+													placement="top"
 												>
-													<img
-														src="/assets/emojis/{emoji.toLowerCase()}.svg"
-														alt={emoji}
-														class="size-5"
-														loading="lazy"
-													/>
-												</button>
-											</Tooltip>
-										{/each}
+													<button
+														class="p-1.5 rounded-lg cursor-pointer hover:bg-gray-200 dark:hover:bg-gray-700 transition"
+														on:click={() => {
+															typeof emojiShortCodes[emoji] === 'string'
+																? onSubmit(emojiShortCodes[emoji])
+																: onSubmit(emojiShortCodes[emoji][0]);
+
+															show = false;
+														}}
+													>
+														<img
+															src="/assets/emojis/{emoji.toLowerCase()}.svg"
+															alt={emoji}
+															class="size-5"
+															loading="lazy"
+														/>
+													</button>
+												</Tooltip>
+											{/each}
+										</div>
 									</div>
-								</div>
-							{/if}
-						{/each}
-					{/if}
-				</div>
+								{/if}
+							{/each}
+						{/if}
+					</div>
+				{/if}
 			</div>
 		</DropdownMenu.Content>
 	</slot>

+ 15 - 5
src/lib/components/channel/Thread.svelte

@@ -3,12 +3,13 @@
 
 	import { socket } from '$lib/stores';
 
-	import { getChannelThreadMessages } from '$lib/apis/channels';
+	import { getChannelThreadMessages, sendMessage } from '$lib/apis/channels';
 
 	import XMark from '$lib/components/icons/XMark.svelte';
 	import MessageInput from './MessageInput.svelte';
 	import Messages from './Messages.svelte';
 	import { onMount } from 'svelte';
+	import { toast } from 'svelte-sonner';
 
 	export let threadId = null;
 	export let channel = null;
@@ -43,10 +44,19 @@
 		}
 	};
 
-	const submitHandler = async (message) => {
-		// if (message) {
-		// 	await sendMessage(localStorage.token, channel.id, message, threadId);
-		// }
+	const submitHandler = async ({ content, data }) => {
+		if (!content) {
+			return;
+		}
+
+		const res = await sendMessage(localStorage.token, channel.id, {
+			parent_id: threadId,
+			content: content,
+			data: data
+		}).catch((error) => {
+			toast.error(error);
+			return null;
+		});
 	};
 
 	const onChange = async () => {