浏览代码

enh: channel notification

Timothy Jaeryang Baek 4 月之前
父节点
当前提交
d701b69e05

+ 7 - 0
backend/open_webui/models/groups.py

@@ -146,6 +146,13 @@ class GroupTable:
         except Exception:
             return None
 
+    def get_group_user_ids_by_id(self, id: str) -> Optional[str]:
+        group = self.get_group_by_id(id)
+        if group:
+            return group.user_ids
+        else:
+            return None
+
     def update_group_by_id(
         self, id: str, form_data: GroupUpdateForm, overwrite: bool = False
     ) -> Optional[GroupModel]:

+ 27 - 11
backend/open_webui/models/users.py

@@ -154,13 +154,25 @@ class UsersTable:
         except Exception:
             return None
 
-    def get_users(self, skip: int = 0, limit: int = 50) -> list[UserModel]:
+    def get_users(
+        self, skip: Optional[int] = None, limit: Optional[int] = None
+    ) -> list[UserModel]:
         with get_db() as db:
-            users = (
-                db.query(User)
-                # .offset(skip).limit(limit)
-                .all()
-            )
+
+            query = db.query(User).order_by(User.created_at.desc())
+
+            if skip:
+                query = query.offset(skip)
+            if limit:
+                query = query.limit(limit)
+
+            users = query.all()
+
+            return [UserModel.model_validate(user) for user in users]
+
+    def get_users_by_user_ids(self, user_ids: list[str]) -> list[UserModel]:
+        with get_db() as db:
+            users = db.query(User).filter(User.id.in_(user_ids)).all()
             return [UserModel.model_validate(user) for user in users]
 
     def get_num_users(self) -> Optional[int]:
@@ -179,11 +191,15 @@ class UsersTable:
         try:
             with get_db() as db:
                 user = db.query(User).filter_by(id=id).first()
-                return (
-                    user.settings.get("ui", {})
-                    .get("notifications", {})
-                    .get("webhook_url", None)
-                )
+
+                if user.settings is None:
+                    return None
+                else:
+                    return (
+                        user.settings.get("ui", {})
+                        .get("notifications", {})
+                        .get("webhook_url", None)
+                    )
         except Exception:
             return None
 

+ 68 - 17
backend/open_webui/routers/channels.py

@@ -3,11 +3,11 @@ import logging
 from typing import Optional
 
 
-from fastapi import APIRouter, Depends, HTTPException, Request, status
+from fastapi import APIRouter, Depends, HTTPException, Request, status, BackgroundTasks
 from pydantic import BaseModel
 
 
-from open_webui.socket.main import sio
+from open_webui.socket.main import sio, SESSION_POOL
 from open_webui.models.users import Users, UserNameResponse
 
 from open_webui.models.channels import Channels, ChannelModel, ChannelForm
@@ -16,11 +16,12 @@ from open_webui.models.messages import Messages, MessageModel, MessageForm
 
 from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
 from open_webui.constants import ERROR_MESSAGES
-from open_webui.env import SRC_LOG_LEVELS
+from open_webui.env import SRC_LOG_LEVELS, WEBUI_URL
 
 
 from open_webui.utils.auth import get_admin_user, get_verified_user
-from open_webui.utils.access_control import has_access
+from open_webui.utils.access_control import has_access, get_users_with_access
+from open_webui.utils.webhook import post_webhook
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -180,9 +181,39 @@ async def get_channel_messages(
 ############################
 
 
+async def send_notification(channel, message, active_user_ids):
+
+    print(f"Sending notification to {channel=}, {message=}, {active_user_ids=}")
+    users = get_users_with_access("read", channel.access_control)
+
+    for user in users:
+        if user.id in active_user_ids:
+            continue
+        else:
+            if user.settings:
+                webhook_url = user.settings.ui.get("notifications", {}).get(
+                    "webhook_url", None
+                )
+
+                if webhook_url:
+                    post_webhook(
+                        webhook_url,
+                        f"#{channel.name} - {WEBUI_URL}/c/{channel.id}\n\n{message.content}",
+                        {
+                            "action": "channel",
+                            "message": message.content,
+                            "title": channel.name,
+                            "url": f"{WEBUI_URL}/c/{channel.id}",
+                        },
+                    )
+
+
 @router.post("/{id}/messages/post", response_model=Optional[MessageModel])
 async def post_new_message(
-    id: str, form_data: MessageForm, user=Depends(get_verified_user)
+    id: str,
+    form_data: MessageForm,
+    background_tasks: BackgroundTasks,
+    user=Depends(get_verified_user),
 ):
     channel = Channels.get_channel_by_id(id)
     if not channel:
@@ -201,24 +232,44 @@ async def post_new_message(
         message = Messages.insert_new_message(form_data, channel.id, user.id)
 
         if message:
-            await sio.emit(
-                "channel-events",
-                {
-                    "channel_id": channel.id,
-                    "message_id": message.id,
+            event_data = {
+                "channel_id": channel.id,
+                "message_id": message.id,
+                "data": {
+                    "type": "message",
                     "data": {
-                        "type": "message",
-                        "data": {
-                            **message.model_dump(),
-                            "user": UserNameResponse(**user.model_dump()).model_dump(),
-                        },
+                        **message.model_dump(),
+                        "user": UserNameResponse(**user.model_dump()).model_dump(),
                     },
-                    "user": UserNameResponse(**user.model_dump()).model_dump(),
-                    "channel": channel.model_dump(),
                 },
+                "user": UserNameResponse(**user.model_dump()).model_dump(),
+                "channel": channel.model_dump(),
+            }
+
+            await sio.emit(
+                "channel-events",
+                event_data,
                 to=f"channel:{channel.id}",
             )
 
+            active_session_ids = sio.manager.get_participants(
+                namespace="/",
+                room=f"channel:{channel.id}",
+            )
+
+            active_user_ids = list(
+                set(
+                    [
+                        SESSION_POOL.get(session_id[0])
+                        for session_id in active_session_ids
+                    ]
+                )
+            )
+
+            background_tasks.add_task(
+                send_notification, channel, message, active_user_ids
+            )
+
         return MessageModel(**message.model_dump())
     except Exception as e:
         log.exception(e)

+ 22 - 0
backend/open_webui/utils/access_control.py

@@ -1,4 +1,5 @@
 from typing import Optional, Union, List, Dict, Any
+from open_webui.models.users import Users, UserModel
 from open_webui.models.groups import Groups
 import json
 
@@ -93,3 +94,24 @@ def has_access(
     return user_id in permitted_user_ids or any(
         group_id in permitted_group_ids for group_id in user_group_ids
     )
+
+
+# Get all users with access to a resource
+def get_users_with_access(
+    type: str = "write", access_control: Optional[dict] = None
+) -> List[UserModel]:
+    if access_control is None:
+        return Users.get_users()
+
+    permission_access = access_control.get(type, {})
+    permitted_group_ids = permission_access.get("group_ids", [])
+    permitted_user_ids = permission_access.get("user_ids", [])
+
+    user_ids_with_access = set(permitted_user_ids)
+
+    for group_id in permitted_group_ids:
+        group_user_ids = Groups.get_group_user_ids_by_id(group_id)
+        if group_user_ids:
+            user_ids_with_access.update(group_user_ids)
+
+    return Users.get_users_by_user_ids(list(user_ids_with_access))

+ 1 - 1
backend/open_webui/utils/webhook.py

@@ -21,7 +21,7 @@ def post_webhook(url: str, message: str, event_data: dict) -> bool:
         elif "https://discord.com/api/webhooks" in url:
             payload["content"] = (
                 message
-                if len(message) > 2000
+                if len(message) < 2000
                 else f"{message[: 2000 - 20]}... (truncated)"
             )
         # Microsoft Teams Webhooks