浏览代码

wip: access control backend

Timothy Jaeryang Baek 5 月之前
父节点
当前提交
2ab5b2fd71

+ 10 - 1
backend/open_webui/apps/webui/models/groups.py

@@ -68,7 +68,6 @@ class GroupResponse(BaseModel):
     permissions: Optional[dict] = None
     meta: Optional[dict] = None
     user_ids: list[str] = []
-    admin_ids: list[str] = []
     created_at: int  # timestamp in epoch
     updated_at: int  # timestamp in epoch
 
@@ -119,6 +118,16 @@ class GroupTable:
                 for group in db.query(Group).order_by(Group.updated_at.desc()).all()
             ]
 
+    def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]:
+        with get_db() as db:
+            return [
+                GroupModel.model_validate(group)
+                for group in db.query(Group)
+                .filter(Group.user_ids.contains([user_id]))
+                .order_by(Group.updated_at.desc())
+                .all()
+            ]
+
     def get_group_by_id(self, id: str) -> Optional[GroupModel]:
         try:
             with get_db() as db:

+ 35 - 0
backend/open_webui/apps/webui/models/models.py

@@ -4,9 +4,20 @@ from typing import Optional
 
 from open_webui.apps.webui.internal.db import Base, JSONField, get_db
 from open_webui.env import SRC_LOG_LEVELS
+
+from open_webui.apps.webui.models.groups import Groups
+
+
 from pydantic import BaseModel, ConfigDict
+
+from sqlalchemy import or_, and_, func
+from sqlalchemy.dialects import postgresql, sqlite
 from sqlalchemy import BigInteger, Column, Text, JSON
 
+
+from open_webui.utils.utils import has_access
+
+
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
 
@@ -112,8 +123,14 @@ class ModelModel(BaseModel):
 
 class ModelResponse(BaseModel):
     id: str
+    user_id: str
+    base_model_id: Optional[str] = None
+
     name: str
+    params: ModelParams
     meta: ModelMeta
+
+    access_control: Optional[dict] = None
     updated_at: int  # timestamp in epoch
     created_at: int  # timestamp in epoch
 
@@ -157,6 +174,24 @@ class ModelsTable:
         with get_db() as db:
             return [ModelModel.model_validate(model) for model in db.query(Model).all()]
 
+    def get_models(self) -> list[ModelModel]:
+        with get_db() as db:
+            return [
+                ModelModel.model_validate(model)
+                for model in db.query(Model).filter(Model.base_model_id != None).all()
+            ]
+
+    def get_models_by_user_id(
+        self, user_id: str, permission: str = "write"
+    ) -> list[ModelModel]:
+        models = self.get_all_models()
+        return [
+            model
+            for model in models
+            if model.user_id == user_id
+            or has_access(user_id, permission, model.access_control)
+        ]
+
     def get_model_by_id(self, id: str) -> Optional[ModelModel]:
         try:
             with get_db() as db:

+ 60 - 0
backend/open_webui/apps/webui/models/prompts.py

@@ -2,6 +2,8 @@ import time
 from typing import Optional
 
 from open_webui.apps.webui.internal.db import Base, get_db
+from open_webui.apps.webui.models.groups import Groups
+
 from pydantic import BaseModel, ConfigDict
 from sqlalchemy import BigInteger, Column, String, Text, JSON
 
@@ -100,6 +102,64 @@ class PromptsTable:
                 PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()
             ]
 
+    def get_prompts_by_user_id(
+        self, user_id: str, permission: str = "write"
+    ) -> list[PromptModel]:
+        prompts = self.get_prompts()
+
+        groups = Groups.get_groups_by_member_id(user_id)
+        group_ids = [group.id for group in groups]
+
+        if permission == "write":
+            return [
+                prompt
+                for prompt in prompts
+                if prompt.user_id == user_id
+                or (
+                    prompt.access_control
+                    and (
+                        any(
+                            group_id
+                            in prompt.access_control.get(permission, {}).get(
+                                "group_ids", []
+                            )
+                            for group_id in group_ids
+                        )
+                        or (
+                            user_id
+                            in prompt.access_control.get(permission, {}).get(
+                                "user_ids", []
+                            )
+                        )
+                    )
+                )
+            ]
+        elif permission == "read":
+            return [
+                prompt
+                for prompt in prompts
+                if prompt.user_id == user_id
+                or prompt.access_control is None
+                or (
+                    prompt.access_control
+                    and (
+                        any(
+                            prompt.access_control.get(permission, {}).get(
+                                "group_ids", []
+                            )
+                            in group_id
+                            for group_id in group_ids
+                        )
+                        or (
+                            user_id
+                            in prompt.access_control.get(permission, {}).get(
+                                "user_ids", []
+                            )
+                        )
+                    )
+                )
+            ]
+
     def update_prompt_by_command(
         self, command: str, form_data: PromptForm
     ) -> Optional[PromptModel]:

+ 64 - 42
backend/open_webui/apps/webui/routers/models.py

@@ -8,49 +8,46 @@ from open_webui.apps.webui.models.models import (
 )
 from open_webui.constants import ERROR_MESSAGES
 from fastapi import APIRouter, Depends, HTTPException, Request, status
-from open_webui.utils.utils import get_admin_user, get_verified_user
+
+
+from open_webui.utils.utils import get_admin_user, get_verified_user, has_access
 
 router = APIRouter()
 
+
 ###########################
-# getModels
+# GetModels
 ###########################
 
 
 @router.get("/", response_model=list[ModelResponse])
 async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)):
-    if id:
-        model = Models.get_model_by_id(id)
-        if model:
-            return [model]
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail=ERROR_MESSAGES.NOT_FOUND,
-            )
+    if user.role == "admin":
+        return Models.get_models()
     else:
-        return Models.get_all_models()
+        return Models.get_models_by_user_id(user.id)
 
 
 ############################
-# AddNewModel
+# CreateNewModel
 ############################
 
 
-@router.post("/add", response_model=Optional[ModelModel])
-async def add_new_model(
-    request: Request,
+@router.post("/create", response_model=Optional[ModelModel])
+async def create_new_model(
     form_data: ModelForm,
-    user=Depends(get_admin_user),
+    user=Depends(get_verified_user),
 ):
-    if form_data.id in request.app.state.MODELS:
+
+    model = Models.get_model_by_id(form_data.id)
+    if model:
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
             detail=ERROR_MESSAGES.MODEL_ID_TAKEN,
         )
+
     else:
         model = Models.insert_new_model(form_data, user.id)
-
         if model:
             return model
         else:
@@ -60,37 +57,49 @@ async def add_new_model(
             )
 
 
+###########################
+# GetModelById
+###########################
+
+
+@router.get("/id/{id}", response_model=Optional[ModelResponse])
+async def get_model_by_id(id: str, user=Depends(get_verified_user)):
+    model = Models.get_model_by_id(id)
+    if model:
+        if (
+            user.role == "admin"
+            or model.user_id == user.id
+            or has_access(user.id, "read", model.access_control)
+        ):
+            return model
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+
 ############################
 # UpdateModelById
 ############################
 
 
-@router.post("/update", response_model=Optional[ModelModel])
+@router.post("/id/{id}/update", response_model=Optional[ModelModel])
 async def update_model_by_id(
-    request: Request,
     id: str,
     form_data: ModelForm,
-    user=Depends(get_admin_user),
+    user=Depends(get_verified_user),
 ):
     model = Models.get_model_by_id(id)
-    if model:
-        model = Models.update_model_by_id(id, form_data)
-        return model
-    else:
-        if form_data.id in request.app.state.MODELS:
-            model = Models.insert_new_model(form_data, user.id)
-            if model:
-                return model
-            else:
-                raise HTTPException(
-                    status_code=status.HTTP_401_UNAUTHORIZED,
-                    detail=ERROR_MESSAGES.DEFAULT(),
-                )
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail=ERROR_MESSAGES.DEFAULT(),
-            )
+
+    if not model:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+    model = Models.update_model_by_id(id, form_data)
+    return model
 
 
 ############################
@@ -98,7 +107,20 @@ async def update_model_by_id(
 ############################
 
 
-@router.delete("/delete", response_model=bool)
-async def delete_model_by_id(id: str, user=Depends(get_admin_user)):
+@router.delete("/id/{id}/delete", response_model=bool)
+async def delete_model_by_id(id: str, user=Depends(get_verified_user)):
+    model = Models.get_model_by_id(id)
+    if not model:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+    if model.user_id != user.id:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.UNAUTHORIZED,
+        )
+
     result = Models.delete_model_by_id(id)
     return result

+ 22 - 4
backend/open_webui/apps/webui/routers/users.py

@@ -36,16 +36,34 @@ async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user)
 ############################
 
 
-@router.get("/permissions/user")
+class WorkspacePermissions(BaseModel):
+    models: bool
+    knowledge: bool
+    prompts: bool
+    tools: bool
+
+
+class ChatPermissions(BaseModel):
+    delete: bool
+    edit: bool
+    temporary: bool
+
+
+class UserPermissions(BaseModel):
+    workspace: WorkspacePermissions
+    chat: ChatPermissions
+
+
+@router.get("/permissions")
 async def get_user_permissions(request: Request, user=Depends(get_admin_user)):
     return request.app.state.config.USER_PERMISSIONS
 
 
-@router.post("/permissions/user")
+@router.post("/permissions")
 async def update_user_permissions(
-    request: Request, form_data: dict, user=Depends(get_admin_user)
+    request: Request, form_data: UserPermissions, user=Depends(get_admin_user)
 ):
-    request.app.state.config.USER_PERMISSIONS = form_data
+    request.app.state.config.USER_PERMISSIONS = form_data.model_dump()
     return request.app.state.config.USER_PERMISSIONS
 
 

+ 27 - 1
backend/open_webui/config.py

@@ -739,6 +739,26 @@ DEFAULT_USER_ROLE = PersistentConfig(
     os.getenv("DEFAULT_USER_ROLE", "pending"),
 )
 
+
+USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS = (
+    os.environ.get("USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS", "False").lower()
+    == "true"
+)
+
+USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS = (
+    os.environ.get("USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS", "False").lower()
+    == "true"
+)
+
+USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS = (
+    os.environ.get("USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS", "False").lower()
+    == "true"
+)
+
+USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS = (
+    os.environ.get("USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS", "False").lower() == "true"
+)
+
 USER_PERMISSIONS_CHAT_DELETE = (
     os.environ.get("USER_PERMISSIONS_CHAT_DELETE", "True").lower() == "true"
 )
@@ -755,11 +775,17 @@ USER_PERMISSIONS = PersistentConfig(
     "USER_PERMISSIONS",
     "user.permissions",
     {
+        "workspace": {
+            "models": USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS,
+            "knowledge": USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS,
+            "prompts": USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS,
+            "tools": USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS,
+        },
         "chat": {
             "deletion": USER_PERMISSIONS_CHAT_DELETE,
             "editing": USER_PERMISSIONS_CHAT_EDIT,
             "temporary": USER_PERMISSIONS_CHAT_TEMPORARY,
-        }
+        },
     },
 )
 

+ 1 - 1
backend/open_webui/main.py

@@ -993,7 +993,7 @@ async def get_all_models():
 
             models.append(
                 {
-                    "id": custom_model.id,
+                    "id": f"open-webui-{custom_model.id}",
                     "name": custom_model.name,
                     "object": "model",
                     "created": custom_model.created_at,

+ 62 - 2
backend/open_webui/utils/utils.py

@@ -1,12 +1,18 @@
 import logging
 import uuid
+import jwt
+
 from datetime import UTC, datetime, timedelta
-from typing import Optional, Union
+from typing import Optional, Union, List, Dict
+
 
-import jwt
 from open_webui.apps.webui.models.users import Users
+from open_webui.apps.webui.models.groups import Groups
+
 from open_webui.constants import ERROR_MESSAGES
 from open_webui.env import WEBUI_SECRET_KEY
+
+
 from fastapi import Depends, HTTPException, Request, Response, status
 from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
 from passlib.context import CryptContext
@@ -147,3 +153,57 @@ def get_admin_user(user=Depends(get_current_user)):
             detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
         )
     return user
+
+
+def has_permission(
+    user_id: str,
+    permission_key: str,
+    default_permissions: Dict[str, bool] = {},
+) -> bool:
+    """
+    Check if a user has a specific permission by checking the group permissions
+    and falls back to default permissions if not found in any group.
+
+    Permission keys can be hierarchical and separated by dots ('.').
+    """
+
+    def get_permission(permissions: Dict[str, bool], keys: List[str]) -> bool:
+        """Traverse permissions dict using a list of keys (from dot-split permission_key)."""
+        for key in keys:
+            if key not in permissions:
+                return False  # If any part of the hierarchy is missing, deny access
+            permissions = permissions[key]  # Go one level deeper
+
+        return bool(permissions)  # Return the boolean at the final level
+
+    permission_hierarchy = permission_key.split(".")
+
+    # Retrieve user group permissions
+    user_groups = Groups.get_groups_by_member_id(user_id)
+
+    for group in user_groups:
+        group_permissions = group.permissions
+        if get_permission(group_permissions, permission_hierarchy):
+            return True
+
+    # Check default permissions afterwards if the group permissions don't allow it
+    return get_permission(default_permissions, permission_hierarchy)
+
+
+def has_access(
+    user_id: str,
+    action: str = "write",
+    access_control: Optional[dict] = None,
+) -> bool:
+    if access_control is None:
+        return action == "read"
+
+    user_groups = Groups.get_groups_by_member_id(user_id)
+    user_group_ids = [group.id for group in user_groups]
+    permission_access = access_control.get(action, {})
+    permitted_group_ids = permission_access.get("group_ids", [])
+    permitted_user_ids = permission_access.get("user_ids", [])
+
+    return user_id in permitted_user_ids or any(
+        group_id in permitted_group_ids for group_id in user_group_ids
+    )