浏览代码

feat: groups backend

Timothy Jaeryang Baek 5 月之前
父节点
当前提交
659f3dac44

+ 11 - 4
backend/open_webui/apps/webui/main.py

@@ -12,6 +12,7 @@ from open_webui.apps.webui.routers import (
     chats,
     folders,
     configs,
+    groups,
     files,
     functions,
     memories,
@@ -85,7 +86,11 @@ from open_webui.utils.payload import (
 
 from open_webui.utils.tools import get_tools
 
-app = FastAPI(docs_url="/docs" if ENV == "dev" else None, openapi_url="/openapi.json" if ENV == "dev" else None, redoc_url=None)
+app = FastAPI(
+    docs_url="/docs" if ENV == "dev" else None,
+    openapi_url="/openapi.json" if ENV == "dev" else None,
+    redoc_url=None,
+)
 
 log = logging.getLogger(__name__)
 
@@ -161,13 +166,15 @@ app.include_router(models.router, prefix="/models", tags=["models"])
 app.include_router(knowledge.router, prefix="/knowledge", tags=["knowledge"])
 app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
 app.include_router(tools.router, prefix="/tools", tags=["tools"])
-app.include_router(functions.router, prefix="/functions", tags=["functions"])
 
 app.include_router(memories.router, prefix="/memories", tags=["memories"])
-app.include_router(evaluations.router, prefix="/evaluations", tags=["evaluations"])
-
 app.include_router(folders.router, prefix="/folders", tags=["folders"])
+
+app.include_router(groups.router, prefix="/groups", tags=["groups"])
 app.include_router(files.router, prefix="/files", tags=["files"])
+app.include_router(functions.router, prefix="/functions", tags=["functions"])
+app.include_router(evaluations.router, prefix="/evaluations", tags=["evaluations"])
+
 
 app.include_router(utils.router, prefix="/utils", tags=["utils"])
 

+ 169 - 0
backend/open_webui/apps/webui/models/groups.py

@@ -0,0 +1,169 @@
+import json
+import logging
+import time
+from typing import Optional
+import uuid
+
+from open_webui.apps.webui.internal.db import Base, get_db
+from open_webui.env import SRC_LOG_LEVELS
+
+from open_webui.apps.webui.models.files import FileMetadataResponse
+
+
+from pydantic import BaseModel, ConfigDict
+from sqlalchemy import BigInteger, Column, String, Text, JSON
+
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MODELS"])
+
+####################
+# UserGroup DB Schema
+####################
+
+
+class Group(Base):
+    __tablename__ = "group"
+
+    id = Column(Text, unique=True, primary_key=True)
+    user_id = Column(Text)
+
+    name = Column(Text)
+    description = Column(Text)
+    meta = Column(JSON, nullable=True)
+
+    permissions = Column(JSON, nullable=True)
+    user_ids = Column(JSON, nullable=True)
+    admin_ids = Column(JSON, nullable=True)
+
+    created_at = Column(BigInteger)
+    updated_at = Column(BigInteger)
+
+
+class GroupModel(BaseModel):
+    model_config = ConfigDict(from_attributes=True)
+    id: str
+    user_id: str
+
+    name: str
+    description: str
+    meta: Optional[dict] = None
+
+    permissions: Optional[dict] = None
+    user_ids: list[str] = []
+    admin_ids: list[str] = []
+
+    created_at: int  # timestamp in epoch
+    updated_at: int  # timestamp in epoch
+
+
+####################
+# Forms
+####################
+
+
+class GroupResponse(BaseModel):
+    id: str
+    user_id: str
+    name: str
+    description: str
+    permissions: Optional[dict] = None
+    meta: Optional[dict] = None
+    user_ids: list[str] = []
+    admin_ids: list[str] = []
+    created_at: int  # timestamp in epoch
+    updated_at: int  # timestamp in epoch
+
+
+class GroupForm(BaseModel):
+    name: str
+    description: str
+
+
+class GroupUpdateForm(GroupForm):
+    permissions: Optional[dict] = None
+    user_ids: Optional[list[str]] = None
+    admin_ids: Optional[list[str]] = None
+
+
+class GroupTable:
+    def insert_new_group(
+        self, user_id: str, form_data: GroupForm
+    ) -> Optional[GroupModel]:
+        with get_db() as db:
+            group = GroupModel(
+                **{
+                    **form_data.model_dump(),
+                    "id": str(uuid.uuid4()),
+                    "user_id": user_id,
+                    "created_at": int(time.time()),
+                    "updated_at": int(time.time()),
+                }
+            )
+
+            try:
+                result = Groups(**group.model_dump())
+                db.add(result)
+                db.commit()
+                db.refresh(result)
+                if result:
+                    return GroupModel.model_validate(result)
+                else:
+                    return None
+
+            except Exception:
+                return None
+
+    def get_groups(self) -> list[GroupModel]:
+        with get_db() as db:
+            return [
+                GroupModel.model_validate(group)
+                for group in db.query(Groups).order_by(Groups.updated_at.desc()).all()
+            ]
+
+    def get_group_by_id(self, id: str) -> Optional[GroupModel]:
+        try:
+            with get_db() as db:
+                group = db.query(Groups).filter_by(id=id).first()
+                return GroupModel.model_validate(group) if group else None
+        except Exception:
+            return None
+
+    def update_group_by_id(
+        self, id: str, form_data: GroupUpdateForm, overwrite: bool = False
+    ) -> Optional[GroupModel]:
+        try:
+            with get_db() as db:
+                db.query(Groups).filter_by(id=id).update(
+                    {
+                        **form_data.model_dump(exclude_none=True),
+                        "updated_at": int(time.time()),
+                    }
+                )
+                db.commit()
+                return self.get_group_by_id(id=id)
+        except Exception as e:
+            log.exception(e)
+            return None
+
+    def delete_group_by_id(self, id: str) -> bool:
+        try:
+            with get_db() as db:
+                db.query(Groups).filter_by(id=id).delete()
+                db.commit()
+                return True
+        except Exception:
+            return False
+
+    def delete_all_groups(self) -> bool:
+        with get_db() as db:
+            try:
+                db.query(Groups).delete()
+                db.commit()
+
+                return True
+            except Exception:
+                return False
+
+
+Groups = GroupTable()

+ 117 - 0
backend/open_webui/apps/webui/routers/groups.py

@@ -0,0 +1,117 @@
+import os
+from pathlib import Path
+from typing import Optional
+
+from open_webui.apps.webui.models.groups import (
+    Groups,
+    GroupForm,
+    GroupUpdateForm,
+    GroupResponse,
+)
+
+from open_webui.config import CACHE_DIR
+from open_webui.constants import ERROR_MESSAGES
+from fastapi import APIRouter, Depends, HTTPException, Request, status
+from open_webui.utils.utils import get_admin_user, get_verified_user
+
+router = APIRouter()
+
+############################
+# GetFunctions
+############################
+
+
+@router.get("/", response_model=list[GroupResponse])
+async def get_groups(user=Depends(get_admin_user)):
+    return Groups.get_groups()
+
+
+############################
+# CreateNewGroup
+############################
+
+
+@router.post("/create", response_model=Optional[GroupResponse])
+async def create_new_function(form_data: GroupForm, user=Depends(get_admin_user)):
+    try:
+        group = Groups.insert_new_group(user.id, form_data)
+        if group:
+            return group
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.DEFAULT("Error creating group"),
+            )
+    except Exception as e:
+        print(e)
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.DEFAULT(e),
+        )
+
+
+############################
+# GetGroupById
+############################
+
+
+@router.get("/id/{id}", response_model=Optional[GroupResponse])
+async def get_group_by_id(id: str, user=Depends(get_admin_user)):
+    group = Groups.get_group_by_id(id)
+    if group:
+        return group
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+
+############################
+# UpdateGroupById
+############################
+
+
+@router.post("/id/{id}/update", response_model=Optional[GroupResponse])
+async def update_group_by_id(
+    id: str, form_data: GroupUpdateForm, user=Depends(get_admin_user)
+):
+    try:
+        group = Groups.update_group_by_id(id, form_data)
+        if group:
+            return group
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.DEFAULT("Error updating group"),
+            )
+    except Exception as e:
+        print(e)
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.DEFAULT(e),
+        )
+
+
+############################
+# DeleteGroupById
+############################
+
+
+@router.delete("/id/{id}/delete", response_model=bool)
+async def delete_group_by_id(id: str, user=Depends(get_admin_user)):
+    try:
+        result = Groups.delete_group_by_id(id)
+        if result:
+            return result
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.DEFAULT("Error deleting group"),
+            )
+    except Exception as e:
+        print(e)
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.DEFAULT(e),
+        )