Jelajahi Sumber

feat: channel socket integration

Timothy Jaeryang Baek 4 bulan lalu
induk
melakukan
f1d21fc59a

+ 12 - 2
backend/open_webui/models/channels.py

@@ -4,8 +4,7 @@ import uuid
 from typing import Optional
 
 from open_webui.internal.db import Base, get_db
-from open_webui.models.tags import TagModel, Tag, Tags
-
+from open_webui.utils.access_control import has_access
 
 from pydantic import BaseModel, ConfigDict
 from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
@@ -85,6 +84,17 @@ class ChannelTable:
             channels = db.query(Channel).all()
             return [ChannelModel.model_validate(channel) for channel in channels]
 
+    def get_channels_by_user_id(
+        self, user_id: str, permission: str = "read"
+    ) -> list[ChannelModel]:
+        channels = self.get_channels()
+        return [
+            channel
+            for channel in channels
+            if channel.user_id == user_id
+            or has_access(user_id, permission, channel.access_control)
+        ]
+
     def get_channel_by_id(self, id: str) -> Optional[ChannelModel]:
         with get_db() as db:
             channel = db.query(Channel).filter(Channel.id == id).first()

+ 2 - 2
backend/open_webui/models/messages.py

@@ -95,7 +95,7 @@ class MessageTable:
             all_messages = (
                 db.query(Message)
                 .filter_by(channel_id=channel_id)
-                .order_by(Message.updated_at.desc())
+                .order_by(Message.updated_at.asc())
                 .limit(limit)
                 .offset(skip)
                 .all()
@@ -109,7 +109,7 @@ class MessageTable:
             all_messages = (
                 db.query(Message)
                 .filter_by(user_id=user_id)
-                .order_by(Message.updated_at.desc())
+                .order_by(Message.updated_at.asc())
                 .limit(limit)
                 .offset(skip)
                 .all()

+ 22 - 6
backend/open_webui/routers/channels.py

@@ -2,6 +2,12 @@ import json
 import logging
 from typing import Optional
 
+
+from fastapi import APIRouter, Depends, HTTPException, Request, status
+from pydantic import BaseModel
+
+
+from open_webui.socket.main import sio
 from open_webui.models.channels import Channels, ChannelModel, ChannelForm
 from open_webui.models.messages import Messages, MessageModel, MessageForm
 
@@ -9,12 +15,10 @@ from open_webui.models.messages import Messages, MessageModel, MessageForm
 from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
 from open_webui.constants import ERROR_MESSAGES
 from open_webui.env import SRC_LOG_LEVELS
-from fastapi import APIRouter, Depends, HTTPException, Request, status
-from pydantic import BaseModel
 
 
 from open_webui.utils.auth import get_admin_user, get_verified_user
-from open_webui.utils.access_control import has_permission
+from open_webui.utils.access_control import has_access
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -53,7 +57,7 @@ async def create_new_channel(form_data: ChannelForm, user=Depends(get_admin_user
 ############################
 
 
-@router.post("/{id}/messages", response_model=list[MessageModel])
+@router.get("/{id}/messages", response_model=list[MessageModel])
 async def get_channel_messages(id: str, page: int = 1, user=Depends(get_verified_user)):
     channel = Channels.get_channel_by_id(id)
     if not channel:
@@ -61,7 +65,7 @@ async def get_channel_messages(id: str, page: int = 1, user=Depends(get_verified
             status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
         )
 
-    if not has_permission(channel.access_control, user):
+    if not has_access(user.id, type="read", access_control=channel.access_control):
         raise HTTPException(
             status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
         )
@@ -87,13 +91,25 @@ async def post_new_message(
             status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
         )
 
-    if not has_permission(channel.access_control, user):
+    if not has_access(user.id, type="read", access_control=channel.access_control):
         raise HTTPException(
             status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
         )
 
     try:
         message = Messages.insert_new_message(form_data, channel.id, user.id)
+
+        if message:
+            await sio.emit(
+                "channel-events",
+                {
+                    "channel_id": channel.id,
+                    "message_id": message.id,
+                    "data": {"message": message.model_dump()},
+                },
+                to=f"channel:{channel.id}",
+            )
+
         return MessageModel(**message.model_dump())
     except Exception as e:
         log.exception(e)

+ 7 - 1
backend/open_webui/socket/main.py

@@ -5,6 +5,7 @@ import sys
 import time
 
 from open_webui.models.users import Users
+from open_webui.models.channels import Channels
 from open_webui.env import (
     ENABLE_WEBSOCKET_SUPPORT,
     WEBSOCKET_MANAGER,
@@ -162,7 +163,6 @@ async def connect(sid, environ, auth):
 
 @sio.on("user-join")
 async def user_join(sid, data):
-    # print("user-join", sid, data)
 
     auth = data["auth"] if "auth" in data else None
     if not auth or "token" not in auth:
@@ -182,6 +182,12 @@ async def user_join(sid, data):
     else:
         USER_POOL[user.id] = [sid]
 
+    # Join all the channels
+    channels = Channels.get_channels_by_user_id(user.id)
+    log.debug(f"{channels=}")
+    for channel in channels:
+        await sio.enter_room(sid, f"channel:{channel.id}")
+
     # print(f"user {user.name}({user.id}) connected with session ID {sid}")
 
     await sio.emit("user-count", {"count": len(USER_POOL.items())})

+ 71 - 0
src/lib/apis/channels/index.ts

@@ -69,3 +69,74 @@ export const getChannels = async (token: string = '') => {
 
 	return res;
 };
+
+
+export const getChannelMessages = async (token: string = '', channel_id: string, page: number = 1) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/channels/${channel_id}/messages?page=${page}`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err.detail;
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+}
+
+type MessageForm = {
+	content: string;
+	data?: object;
+    meta?: object;
+
+}
+
+export const sendMessage = async (token: string = '', channel_id: string, message: MessageForm) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/channels/${channel_id}/messages/post`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({ ...message })
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err.detail;
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+}

+ 92 - 1
src/lib/components/channel/Channel.svelte

@@ -1,5 +1,96 @@
 <script lang="ts">
+	import { getChannelMessages, sendMessage } from '$lib/apis/channels';
+	import { toast } from 'svelte-sonner';
+	import MessageInput from './MessageInput.svelte';
+	import Messages from './Messages.svelte';
+	import { socket } from '$lib/stores';
+	import { onDestroy, onMount, tick } from 'svelte';
+
 	export let id = '';
+
+	let scrollEnd = true;
+	let messagesContainerElement = null;
+
+	let top = false;
+	let page = 1;
+
+	let messages = null;
+
+	$: if (id) {
+		initHandler();
+	}
+
+	const initHandler = async () => {
+		top = false;
+		page = 1;
+		messages = null;
+
+		messages = await getChannelMessages(localStorage.token, id, page);
+
+		if (messages.length < 50) {
+			top = true;
+		}
+	};
+
+	const channelEventHandler = async (data) => {
+		console.log(data);
+	};
+
+	const submitHandler = async ({ content }) => {
+		if (!content) {
+			return;
+		}
+
+		const res = await sendMessage(localStorage.token, id, { content: content }).catch((error) => {
+			toast.error(error);
+			return null;
+		});
+
+		if (res) {
+			messagesContainerElement.scrollTop = messagesContainerElement.scrollHeight;
+		}
+	};
+
+	onMount(() => {
+		$socket?.on('channel-events', channelEventHandler);
+	});
+
+	onDestroy(() => {
+		$socket?.off('channel-events', channelEventHandler);
+	});
 </script>
 
-{id}
+<div class="h-full md:max-w-[calc(100%-260px)] w-full max-w-full flex flex-col">
+	<div
+		class=" pb-2.5 flex flex-col justify-between w-full flex-auto overflow-auto h-0 max-w-full z-10 scrollbar-hidden"
+		id="messages-container"
+		bind:this={messagesContainerElement}
+		on:scroll={(e) => {
+			scrollEnd =
+				messagesContainerElement.scrollHeight - messagesContainerElement.scrollTop <=
+				messagesContainerElement.clientHeight + 5;
+		}}
+	>
+		{#key id}
+			<Messages
+				{messages}
+				onLoad={async () => {
+					page += 1;
+
+					const newMessages = await getChannelMessages(localStorage.token, id, page);
+
+					if (newMessages.length === 0) {
+						top = true;
+						return;
+					}
+
+					messages = [...newMessages, ...messages];
+				}}
+			/>
+		{/key}
+	</div>
+
+	<div class=" pb-[1rem]">
+		<MessageInput onSubmit={submitHandler} />
+	</div>
+</div>

+ 266 - 0
src/lib/components/channel/MessageInput.svelte

@@ -0,0 +1,266 @@
+<script lang="ts">
+	import { toast } from 'svelte-sonner';
+	import { tick, getContext } from 'svelte';
+
+	const i18n = getContext('i18n');
+
+	import { mobile, settings } from '$lib/stores';
+
+	import Tooltip from '../common/Tooltip.svelte';
+	import RichTextInput from '../common/RichTextInput.svelte';
+	import VoiceRecording from '../chat/MessageInput/VoiceRecording.svelte';
+
+	export let placeholder = $i18n.t('Send a Message');
+	export let transparentBackground = false;
+
+	let recording = false;
+
+	let content = '';
+
+	export let onSubmit: Function;
+
+	let submitHandler = async () => {
+		if (content === '') {
+			return;
+		}
+
+		onSubmit({
+			content
+		});
+
+		content = '';
+		await tick();
+
+		const chatInputElement = document.getElementById('chat-input');
+		chatInputElement?.focus();
+	};
+</script>
+
+<div class="{transparentBackground ? 'bg-transparent' : 'bg-white dark:bg-gray-900'} ">
+	<div class="max-w-6xl px-2.5 mx-auto inset-x-0">
+		<div class="">
+			{#if recording}
+				<VoiceRecording
+					bind:recording
+					on:cancel={async () => {
+						recording = false;
+
+						await tick();
+						document.getElementById('chat-input')?.focus();
+					}}
+					on:confirm={async (e) => {
+						const { text, filename } = e.detail;
+						content = `${content}${text} `;
+						recording = false;
+
+						await tick();
+						document.getElementById('chat-input')?.focus();
+					}}
+				/>
+			{:else}
+				<form
+					class="w-full flex gap-1.5"
+					on:submit|preventDefault={() => {
+						submitHandler();
+					}}
+				>
+					<div
+						class="flex-1 flex flex-col relative w-full rounded-3xl px-1 bg-gray-50 dark:bg-gray-400/5 dark:text-gray-100"
+						dir={$settings?.chatDirection ?? 'LTR'}
+					>
+						<div class=" flex">
+							<div class="ml-1 self-end mb-1.5 flex space-x-1">
+								<button
+									class="bg-transparent hover:bg-white/80 text-gray-800 dark:text-white dark:hover:bg-gray-800 transition rounded-full p-2 outline-none focus:outline-none"
+									type="button"
+									aria-label="More"
+								>
+									<svg
+										xmlns="http://www.w3.org/2000/svg"
+										viewBox="0 0 20 20"
+										fill="currentColor"
+										class="size-5"
+									>
+										<path
+											d="M10.75 4.75a.75.75 0 0 0-1.5 0v4.5h-4.5a.75.75 0 0 0 0 1.5h4.5v4.5a.75.75 0 0 0 1.5 0v-4.5h4.5a.75.75 0 0 0 0-1.5h-4.5v-4.5Z"
+										/>
+									</svg>
+								</button>
+							</div>
+
+							{#if $settings?.richTextInput ?? true}
+								<div
+									class="scrollbar-hidden text-left bg-transparent dark:text-gray-100 outline-none w-full py-2.5 px-1 rounded-xl resize-none h-fit max-h-80 overflow-auto"
+								>
+									<RichTextInput
+										bind:value={content}
+										id="chat-input"
+										messageInput={true}
+										shiftEnter={!$mobile ||
+											!(
+												'ontouchstart' in window ||
+												navigator.maxTouchPoints > 0 ||
+												navigator.msMaxTouchPoints > 0
+											)}
+										{placeholder}
+										largeTextAsFile={$settings?.largeTextAsFile ?? false}
+										on:keydown={async (e) => {
+											e = e.detail.event;
+											const isCtrlPressed = e.ctrlKey || e.metaKey; // metaKey is for Cmd key on Mac
+											if (
+												!$mobile ||
+												!(
+													'ontouchstart' in window ||
+													navigator.maxTouchPoints > 0 ||
+													navigator.msMaxTouchPoints > 0
+												)
+											) {
+												// Prevent Enter key from creating a new line
+												// Uses keyCode '13' for Enter key for chinese/japanese keyboards
+												if (e.keyCode === 13 && !e.shiftKey) {
+													e.preventDefault();
+												}
+
+												// Submit the content when Enter key is pressed
+												if (content !== '' && e.keyCode === 13 && !e.shiftKey) {
+													submitHandler();
+												}
+											}
+
+											if (e.key === 'Escape') {
+												console.log('Escape');
+											}
+										}}
+										on:paste={async (e) => {
+											e = e.detail.event;
+											console.log(e);
+										}}
+									/>
+								</div>
+							{:else}
+								<textarea
+									id="chat-input"
+									class="scrollbar-hidden bg-transparent dark:text-gray-100 outline-none w-full py-3 px-1 rounded-xl resize-none h-[48px]"
+									{placeholder}
+									bind:value={content}
+									on:keydown={async (e) => {
+										e = e.detail.event;
+										const isCtrlPressed = e.ctrlKey || e.metaKey; // metaKey is for Cmd key on Mac
+										if (
+											!$mobile ||
+											!(
+												'ontouchstart' in window ||
+												navigator.maxTouchPoints > 0 ||
+												navigator.msMaxTouchPoints > 0
+											)
+										) {
+											// Prevent Enter key from creating a new line
+											// Uses keyCode '13' for Enter key for chinese/japanese keyboards
+											if (e.keyCode === 13 && !e.shiftKey) {
+												e.preventDefault();
+											}
+
+											// Submit the content when Enter key is pressed
+											if (content !== '' && e.keyCode === 13 && !e.shiftKey) {
+												submitHandler();
+											}
+										}
+
+										if (e.key === 'Escape') {
+											console.log('Escape');
+										}
+									}}
+									rows="1"
+									on:input={async (e) => {
+										e.target.style.height = '';
+										e.target.style.height = Math.min(e.target.scrollHeight, 320) + 'px';
+									}}
+									on:focus={async (e) => {
+										e.target.style.height = '';
+										e.target.style.height = Math.min(e.target.scrollHeight, 320) + 'px';
+									}}
+								/>
+							{/if}
+
+							<div class="self-end mb-1.5 flex space-x-1 mr-1">
+								{#if content === ''}
+									<Tooltip content={$i18n.t('Record voice')}>
+										<button
+											id="voice-input-button"
+											class=" text-gray-600 dark:text-gray-300 hover:text-gray-700 dark:hover:text-gray-200 transition rounded-full p-1.5 mr-0.5 self-center"
+											type="button"
+											on:click={async () => {
+												try {
+													let stream = await navigator.mediaDevices
+														.getUserMedia({ audio: true })
+														.catch(function (err) {
+															toast.error(
+																$i18n.t(`Permission denied when accessing microphone: {{error}}`, {
+																	error: err
+																})
+															);
+															return null;
+														});
+
+													if (stream) {
+														recording = true;
+														const tracks = stream.getTracks();
+														tracks.forEach((track) => track.stop());
+													}
+													stream = null;
+												} catch {
+													toast.error($i18n.t('Permission denied when accessing microphone'));
+												}
+											}}
+											aria-label="Voice Input"
+										>
+											<svg
+												xmlns="http://www.w3.org/2000/svg"
+												viewBox="0 0 20 20"
+												fill="currentColor"
+												class="w-5 h-5 translate-y-[0.5px]"
+											>
+												<path d="M7 4a3 3 0 016 0v6a3 3 0 11-6 0V4z" />
+												<path
+													d="M5.5 9.643a.75.75 0 00-1.5 0V10c0 3.06 2.29 5.585 5.25 5.954V17.5h-1.5a.75.75 0 000 1.5h4.5a.75.75 0 000-1.5h-1.5v-1.546A6.001 6.001 0 0016 10v-.357a.75.75 0 00-1.5 0V10a4.5 4.5 0 01-9 0v-.357z"
+												/>
+											</svg>
+										</button>
+									</Tooltip>
+								{/if}
+
+								<div class=" flex items-center">
+									<div class=" flex items-center">
+										<Tooltip content={$i18n.t('Send message')}>
+											<button
+												id="send-message-button"
+												class="{content !== ''
+													? 'bg-black text-white hover:bg-gray-900 dark:bg-white dark:text-black dark:hover:bg-gray-100 '
+													: 'text-white bg-gray-200 dark:text-gray-900 dark:bg-gray-700 disabled'} transition rounded-full p-1.5 self-center"
+												type="submit"
+												disabled={content === ''}
+											>
+												<svg
+													xmlns="http://www.w3.org/2000/svg"
+													viewBox="0 0 16 16"
+													fill="currentColor"
+													class="size-6"
+												>
+													<path
+														fill-rule="evenodd"
+														d="M8 14a.75.75 0 0 1-.75-.75V4.56L4.03 7.78a.75.75 0 0 1-1.06-1.06l4.5-4.5a.75.75 0 0 1 1.06 0l4.5 4.5a.75.75 0 0 1-1.06 1.06L8.75 4.56v8.69A.75.75 0 0 1 8 14Z"
+														clip-rule="evenodd"
+													/>
+												</svg>
+											</button>
+										</Tooltip>
+									</div>
+								</div>
+							</div>
+						</div>
+					</div>
+				</form>
+			{/if}
+		</div>
+	</div>
+</div>

+ 14 - 13
src/lib/components/channel/Messages.svelte

@@ -9,14 +9,15 @@
 	import Message from './Messages/Message.svelte';
 	import Loader from '../common/Loader.svelte';
 	import Spinner from '../common/Spinner.svelte';
+	import { getChannelMessages } from '$lib/apis/channels';
 
 	const i18n = getContext('i18n');
 
-	export let channelId;
+	export let messages = [];
+	export let top = false;
 
-	let messages = null;
+	export let onLoad: Function = () => {};
 
-	let messagesCount = 50;
 	let messagesLoading = false;
 
 	const loadMoreMessages = async () => {
@@ -25,19 +26,19 @@
 		element.scrollTop = element.scrollTop + 100;
 
 		messagesLoading = true;
-		messagesCount += 50;
 
-		await tick();
+		await onLoad();
 
+		await tick();
 		messagesLoading = false;
 	};
 </script>
 
-<div class="h-full flex pt-8">
-	<div class="w-full pt-2">
-		{#key channelId}
+{#if messages}
+	<div class="h-full w-full flex-1 flex">
+		<div class="w-full pt-2">
 			<div class="w-full">
-				{#if messages.at(0)?.parentId !== null}
+				{#if !top}
 					<Loader
 						on:visible={(e) => {
 							console.log('visible');
@@ -54,10 +55,10 @@
 				{/if}
 
 				{#each messages as message, messageIdx (message.id)}
-					<Message {channelId} id={message.id} content={message.content} />
+					<Message {message} />
 				{/each}
 			</div>
-			<div class="pb-12" />
-		{/key}
+			<div class="pb-6" />
+		</div>
 	</div>
-</div>
+{/if}

+ 13 - 0
src/lib/components/channel/Messages/Message.svelte

@@ -0,0 +1,13 @@
+<script lang="ts">
+	import Markdown from '$lib/components/chat/Messages/Markdown.svelte';
+
+	export let message;
+</script>
+
+{#if message}
+	<div>
+		<div>
+			<Markdown id={message.id} content={message.content} />
+		</div>
+	</div>
+{/if}

+ 1 - 1
src/lib/components/chat/MessageInput.svelte

@@ -49,7 +49,7 @@
 
 	export let autoScroll = false;
 
-	export let atSelectedModel: Model | undefined;
+	export let atSelectedModel: Model | undefined = undefined;
 	export let selectedModels: [''];
 
 	let selectedModelIds = [];

+ 1 - 0
src/lib/components/common/Image.svelte

@@ -19,6 +19,7 @@
 	on:click={() => {
 		showImagePreview = true;
 	}}
+	type="button"
 >
 	<img src={_src} {alt} class={imageClassName} draggable="false" data-cy="image" />
 </button>

+ 3 - 2
src/lib/components/layout/Sidebar.svelte

@@ -17,7 +17,8 @@
 		scrollPaginationEnabled,
 		currentChatPage,
 		temporaryChatEnabled,
-		channels
+		channels,
+		socket
 	} from '$lib/stores';
 	import { onMount, getContext, tick, onDestroy } from 'svelte';
 
@@ -151,7 +152,7 @@
 	};
 
 	const initChannels = async () => {
-		channels.set(await getChannels(localStorage.token));
+		await channels.set(await getChannels(localStorage.token));
 	};
 
 	const initChatList = async () => {

+ 5 - 3
src/routes/+layout.svelte

@@ -38,7 +38,7 @@
 	let loaded = false;
 	const BREAKPOINT = 768;
 
-	const setupSocket = (enableWebsocket) => {
+	const setupSocket = async (enableWebsocket) => {
 		const _socket = io(`${WEBUI_BASE_URL}` || undefined, {
 			reconnection: true,
 			reconnectionDelay: 1000,
@@ -49,7 +49,7 @@
 			auth: { token: localStorage.token }
 		});
 
-		socket.set(_socket);
+		await socket.set(_socket);
 
 		_socket.on('connect_error', (err) => {
 			console.log('connect_error', err);
@@ -127,7 +127,7 @@
 			await WEBUI_NAME.set(backendConfig.name);
 
 			if ($config) {
-				setupSocket($config.features?.enable_websocket ?? true);
+				await setupSocket($config.features?.enable_websocket ?? true);
 
 				if (localStorage.token) {
 					// Get Session User Info
@@ -138,6 +138,8 @@
 
 					if (sessionUser) {
 						// Save Session User to Store
+						$socket.emit('user-join', { auth: { token: sessionUser.token } });
+
 						await user.set(sessionUser);
 						await config.set(await getBackendConfig());
 					} else {