Explorar o código

feat: channels backend

Timothy Jaeryang Baek hai 4 meses
pai
achega
7c8de9e221

+ 3 - 0
backend/open_webui/main.py

@@ -58,6 +58,7 @@ from open_webui.routers import (
     pipelines,
     pipelines,
     tasks,
     tasks,
     auths,
     auths,
+    channels,
     chats,
     chats,
     folders,
     folders,
     configs,
     configs,
@@ -737,6 +738,8 @@ app.include_router(configs.router, prefix="/api/v1/configs", tags=["configs"])
 app.include_router(auths.router, prefix="/api/v1/auths", tags=["auths"])
 app.include_router(auths.router, prefix="/api/v1/auths", tags=["auths"])
 app.include_router(users.router, prefix="/api/v1/users", tags=["users"])
 app.include_router(users.router, prefix="/api/v1/users", tags=["users"])
 
 
+
+app.include_router(channels.router, prefix="/api/v1/channels", tags=["channels"])
 app.include_router(chats.router, prefix="/api/v1/chats", tags=["chats"])
 app.include_router(chats.router, prefix="/api/v1/chats", tags=["chats"])
 
 
 app.include_router(models.router, prefix="/api/v1/models", tags=["models"])
 app.include_router(models.router, prefix="/api/v1/models", tags=["models"])

+ 47 - 0
backend/open_webui/migrations/versions/57c599a3cb57_add_channel_table.py

@@ -0,0 +1,47 @@
+"""Add channel table
+
+Revision ID: 57c599a3cb57
+Revises: 922e7a387820
+Create Date: 2024-12-22 03:00:00.000000
+
+"""
+
+from alembic import op
+import sqlalchemy as sa
+
+revision = "57c599a3cb57"
+down_revision = "922e7a387820"
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    op.create_table(
+        "channel",
+        sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True),
+        sa.Column("user_id", sa.Text()),
+        sa.Column("name", sa.Text()),
+        sa.Column("data", sa.JSON(), nullable=True),
+        sa.Column("meta", sa.JSON(), nullable=True),
+        sa.Column("access_control", sa.JSON(), nullable=True),
+        sa.Column("created_at", sa.BigInteger(), nullable=True),
+        sa.Column("updated_at", sa.BigInteger(), nullable=True),
+    )
+
+    op.create_table(
+        "message",
+        sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True),
+        sa.Column("user_id", sa.Text()),
+        sa.Column("channel_id", sa.Text(), nullable=True),
+        sa.Column("content", sa.Text()),
+        sa.Column("data", sa.JSON(), nullable=True),
+        sa.Column("meta", sa.JSON(), nullable=True),
+        sa.Column("created_at", sa.BigInteger(), nullable=True),
+        sa.Column("updated_at", sa.BigInteger(), nullable=True),
+    )
+
+
+def downgrade():
+    op.drop_table("channel")
+
+    op.drop_table("message")

+ 115 - 0
backend/open_webui/models/channels.py

@@ -0,0 +1,115 @@
+import json
+import time
+import uuid
+from typing import Optional
+
+from open_webui.internal.db import Base, get_db
+from open_webui.models.tags import TagModel, Tag, Tags
+
+
+from pydantic import BaseModel, ConfigDict
+from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
+from sqlalchemy import or_, func, select, and_, text
+from sqlalchemy.sql import exists
+
+####################
+# Channel DB Schema
+####################
+
+
+class Channel(Base):
+    __tablename__ = "channel"
+
+    id = Column(Text, primary_key=True)
+    user_id = Column(Text)
+
+    name = Column(Text)
+    data = Column(JSON, nullable=True)
+    meta = Column(JSON, nullable=True)
+    access_control = Column(JSON, nullable=True)
+
+    created_at = Column(BigInteger)
+    updated_at = Column(BigInteger)
+
+
+class ChannelModel(BaseModel):
+    model_config = ConfigDict(from_attributes=True)
+
+    id: str
+    user_id: str
+
+    name: str
+    data: Optional[dict] = None
+    meta: Optional[dict] = None
+    access_control: Optional[dict] = None
+
+    created_at: int  # timestamp in epoch
+    updated_at: int  # timestamp in epoch
+
+
+####################
+# Forms
+####################
+
+
+class ChannelForm(BaseModel):
+    name: str
+    data: Optional[dict] = None
+    meta: Optional[dict] = None
+    access_control: Optional[dict] = None
+
+
+class ChannelTable:
+    def insert_new_channel(
+        self, form_data: ChannelForm, user_id: str
+    ) -> Optional[ChannelModel]:
+        with get_db() as db:
+            new_channel = Channel(
+                **{
+                    **form_data.model_dump(),
+                    "id": str(uuid.uuid4()),
+                    "user_id": user_id,
+                    "created_at": int(time.time()),
+                    "updated_at": int(time.time()),
+                }
+            )
+
+            db.add(new_channel)
+            db.commit()
+            return new_channel
+
+    def get_channels(self) -> list[ChannelModel]:
+        with get_db() as db:
+            channels = db.query(Channel).all()
+            return [ChannelModel.model_validate(channel) for channel in channels]
+
+    def get_channel_by_id(self, id: str) -> Optional[ChannelModel]:
+        with get_db() as db:
+            channel = db.query(Channel).filter(Channel.id == id).first()
+            return ChannelModel.model_validate(channel) if channel else None
+
+    def update_channel_by_id(
+        self, id: str, form_data: ChannelForm
+    ) -> Optional[ChannelModel]:
+        with get_db() as db:
+            channel = db.query(Channel).filter(Channel.id == id).first()
+            if not channel:
+                return None
+
+            channel.name = form_data.name
+            channel.data = form_data.data
+            channel.meta = form_data.meta
+            channel.access_control = form_data.access_control
+            channel.updated_at = int(time.time())
+
+            db.commit()
+            return ChannelModel.model_validate(channel) if channel else None
+
+    def delete_channel_by_id(self, id: str):
+        with get_db() as db:
+            db.query(Channel).filter(Channel.id == id).delete()
+            db.commit()
+            return True
+
+
+Channels = ChannelTable()

+ 139 - 0
backend/open_webui/models/messages.py

@@ -0,0 +1,139 @@
+import json
+import time
+import uuid
+from typing import Optional
+
+from open_webui.internal.db import Base, get_db
+from open_webui.models.tags import TagModel, Tag, Tags
+
+
+from pydantic import BaseModel, ConfigDict
+from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
+from sqlalchemy import or_, func, select, and_, text
+from sqlalchemy.sql import exists
+
+####################
+# Message DB Schema
+####################
+
+
+class Message(Base):
+    __tablename__ = "message"
+    id = Column(Text, primary_key=True)
+
+    user_id = Column(Text)
+    channel_id = Column(Text, nullable=True)
+
+    content = Column(Text)
+    data = Column(JSON, nullable=True)
+    meta = Column(JSON, nullable=True)
+
+    created_at = Column(BigInteger)
+    updated_at = Column(BigInteger)
+
+
+class MessageModel(BaseModel):
+    model_config = ConfigDict(from_attributes=True)
+
+    id: str
+    user_id: str
+    channel_id: Optional[str] = None
+
+    content: str
+    data: Optional[dict] = None
+    meta: Optional[dict] = None
+
+    created_at: int  # timestamp in epoch
+    updated_at: int  # timestamp in epoch
+
+
+####################
+# Forms
+####################
+
+
+class MessageForm(BaseModel):
+    content: str
+    data: Optional[dict] = None
+    meta: Optional[dict] = None
+
+
+class MessageTable:
+    def insert_new_message(
+        self, form_data: MessageForm, channel_id: str, user_id: str
+    ) -> Optional[MessageModel]:
+        with get_db() as db:
+            id = str(uuid.uuid4())
+            message = MessageModel(
+                **{
+                    "id": id,
+                    "user_id": user_id,
+                    "channel_id": channel_id,
+                    "content": form_data.content,
+                    "data": form_data.data,
+                    "meta": form_data.meta,
+                    "created_at": int(time.time()),
+                    "updated_at": int(time.time()),
+                }
+            )
+
+            result = Message(**message.model_dump())
+            db.add(result)
+            db.commit()
+            db.refresh(result)
+            return MessageModel.model_validate(result) if result else None
+
+    def get_message_by_id(self, id: str) -> Optional[MessageModel]:
+        with get_db() as db:
+            message = db.get(Message, id)
+            return MessageModel.model_validate(message) if message else None
+
+    def get_messages_by_channel_id(
+        self, channel_id: str, skip: int = 0, limit: int = 50
+    ) -> list[MessageModel]:
+        with get_db() as db:
+            all_messages = (
+                db.query(Message)
+                .filter_by(channel_id=channel_id)
+                .order_by(Message.updated_at.desc())
+                .limit(limit)
+                .offset(skip)
+                .all()
+            )
+            return [MessageModel.model_validate(message) for message in all_messages]
+
+    def get_messages_by_user_id(
+        self, user_id: str, skip: int = 0, limit: int = 50
+    ) -> list[MessageModel]:
+        with get_db() as db:
+            all_messages = (
+                db.query(Message)
+                .filter_by(user_id=user_id)
+                .order_by(Message.updated_at.desc())
+                .limit(limit)
+                .offset(skip)
+                .all()
+            )
+            return [MessageModel.model_validate(message) for message in all_messages]
+
+    def update_message_by_id(
+        self, id: str, form_data: MessageForm
+    ) -> Optional[MessageModel]:
+        with get_db() as db:
+            message = db.get(Message, id)
+            message.content = form_data.content
+            message.data = form_data.data
+            message.meta = form_data.meta
+            message.updated_at = int(time.time())
+            db.commit()
+            db.refresh(message)
+            return MessageModel.model_validate(message) if message else None
+
+    def delete_message_by_id(self, id: str) -> bool:
+        with get_db() as db:
+            db.query(Message).filter_by(id=id).delete()
+            db.commit()
+            return True
+
+
+Messages = MessageTable()

+ 102 - 0
backend/open_webui/routers/channels.py

@@ -0,0 +1,102 @@
+import json
+import logging
+from typing import Optional
+
+from open_webui.models.channels import Channels, ChannelModel, ChannelForm
+from open_webui.models.messages import Messages, MessageModel, MessageForm
+
+
+from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
+from open_webui.constants import ERROR_MESSAGES
+from open_webui.env import SRC_LOG_LEVELS
+from fastapi import APIRouter, Depends, HTTPException, Request, status
+from pydantic import BaseModel
+
+
+from open_webui.utils.auth import get_admin_user, get_verified_user
+from open_webui.utils.access_control import has_permission
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MODELS"])
+
+router = APIRouter()
+
+############################
+# GetChatList
+############################
+
+
+@router.get("/", response_model=list[ChannelModel])
+async def get_channels(user=Depends(get_verified_user)):
+    return Channels.get_channels()
+
+
+############################
+# CreateNewChannel
+############################
+
+
+@router.post("/create", response_model=Optional[ChannelModel])
+async def create_new_channel(form_data: ChannelForm, user=Depends(get_admin_user)):
+    try:
+        channel = Channels.insert_new_channel(form_data, user.id)
+        return ChannelModel(**channel.model_dump())
+    except Exception as e:
+        log.exception(e)
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
+        )
+
+
+############################
+# GetChannelMessages
+############################
+
+
+@router.post("/{id}/messages", response_model=list[MessageModel])
+async def get_channel_messages(id: str, page: int = 1, user=Depends(get_verified_user)):
+    channel = Channels.get_channel_by_id(id)
+    if not channel:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
+        )
+
+    if not has_permission(channel.access_control, user):
+        raise HTTPException(
+            status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
+        )
+
+    limit = 50
+    skip = (page - 1) * limit
+
+    return Messages.get_messages_by_channel_id(id, skip, limit)
+
+
+############################
+# PostNewMessage
+############################
+
+
+@router.post("/{id}/messages/post", response_model=Optional[MessageModel])
+async def post_new_message(
+    id: str, form_data: MessageForm, user=Depends(get_verified_user)
+):
+    channel = Channels.get_channel_by_id(id)
+    if not channel:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
+        )
+
+    if not has_permission(channel.access_control, user):
+        raise HTTPException(
+            status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
+        )
+
+    try:
+        message = Messages.insert_new_message(form_data, channel.id, user.id)
+        return MessageModel(**message.model_dump())
+    except Exception as e:
+        log.exception(e)
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
+        )