浏览代码

wip: access control backend

Timothy Jaeryang Baek 5 月之前
父节点
当前提交
2ab5b2fd71

+ 10 - 1
backend/open_webui/apps/webui/models/groups.py

@@ -68,7 +68,6 @@ class GroupResponse(BaseModel):
     permissions: Optional[dict] = None
     permissions: Optional[dict] = None
     meta: Optional[dict] = None
     meta: Optional[dict] = None
     user_ids: list[str] = []
     user_ids: list[str] = []
-    admin_ids: list[str] = []
     created_at: int  # timestamp in epoch
     created_at: int  # timestamp in epoch
     updated_at: int  # timestamp in epoch
     updated_at: int  # timestamp in epoch
 
 
@@ -119,6 +118,16 @@ class GroupTable:
                 for group in db.query(Group).order_by(Group.updated_at.desc()).all()
                 for group in db.query(Group).order_by(Group.updated_at.desc()).all()
             ]
             ]
 
 
+    def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]:
+        with get_db() as db:
+            return [
+                GroupModel.model_validate(group)
+                for group in db.query(Group)
+                .filter(Group.user_ids.contains([user_id]))
+                .order_by(Group.updated_at.desc())
+                .all()
+            ]
+
     def get_group_by_id(self, id: str) -> Optional[GroupModel]:
     def get_group_by_id(self, id: str) -> Optional[GroupModel]:
         try:
         try:
             with get_db() as db:
             with get_db() as db:

+ 35 - 0
backend/open_webui/apps/webui/models/models.py

@@ -4,9 +4,20 @@ from typing import Optional
 
 
 from open_webui.apps.webui.internal.db import Base, JSONField, get_db
 from open_webui.apps.webui.internal.db import Base, JSONField, get_db
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
+
+from open_webui.apps.webui.models.groups import Groups
+
+
 from pydantic import BaseModel, ConfigDict
 from pydantic import BaseModel, ConfigDict
+
+from sqlalchemy import or_, and_, func
+from sqlalchemy.dialects import postgresql, sqlite
 from sqlalchemy import BigInteger, Column, Text, JSON
 from sqlalchemy import BigInteger, Column, Text, JSON
 
 
+
+from open_webui.utils.utils import has_access
+
+
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
 
 
@@ -112,8 +123,14 @@ class ModelModel(BaseModel):
 
 
 class ModelResponse(BaseModel):
 class ModelResponse(BaseModel):
     id: str
     id: str
+    user_id: str
+    base_model_id: Optional[str] = None
+
     name: str
     name: str
+    params: ModelParams
     meta: ModelMeta
     meta: ModelMeta
+
+    access_control: Optional[dict] = None
     updated_at: int  # timestamp in epoch
     updated_at: int  # timestamp in epoch
     created_at: int  # timestamp in epoch
     created_at: int  # timestamp in epoch
 
 
@@ -157,6 +174,24 @@ class ModelsTable:
         with get_db() as db:
         with get_db() as db:
             return [ModelModel.model_validate(model) for model in db.query(Model).all()]
             return [ModelModel.model_validate(model) for model in db.query(Model).all()]
 
 
+    def get_models(self) -> list[ModelModel]:
+        with get_db() as db:
+            return [
+                ModelModel.model_validate(model)
+                for model in db.query(Model).filter(Model.base_model_id != None).all()
+            ]
+
+    def get_models_by_user_id(
+        self, user_id: str, permission: str = "write"
+    ) -> list[ModelModel]:
+        models = self.get_all_models()
+        return [
+            model
+            for model in models
+            if model.user_id == user_id
+            or has_access(user_id, permission, model.access_control)
+        ]
+
     def get_model_by_id(self, id: str) -> Optional[ModelModel]:
     def get_model_by_id(self, id: str) -> Optional[ModelModel]:
         try:
         try:
             with get_db() as db:
             with get_db() as db:

+ 60 - 0
backend/open_webui/apps/webui/models/prompts.py

@@ -2,6 +2,8 @@ import time
 from typing import Optional
 from typing import Optional
 
 
 from open_webui.apps.webui.internal.db import Base, get_db
 from open_webui.apps.webui.internal.db import Base, get_db
+from open_webui.apps.webui.models.groups import Groups
+
 from pydantic import BaseModel, ConfigDict
 from pydantic import BaseModel, ConfigDict
 from sqlalchemy import BigInteger, Column, String, Text, JSON
 from sqlalchemy import BigInteger, Column, String, Text, JSON
 
 
@@ -100,6 +102,64 @@ class PromptsTable:
                 PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()
                 PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()
             ]
             ]
 
 
+    def get_prompts_by_user_id(
+        self, user_id: str, permission: str = "write"
+    ) -> list[PromptModel]:
+        prompts = self.get_prompts()
+
+        groups = Groups.get_groups_by_member_id(user_id)
+        group_ids = [group.id for group in groups]
+
+        if permission == "write":
+            return [
+                prompt
+                for prompt in prompts
+                if prompt.user_id == user_id
+                or (
+                    prompt.access_control
+                    and (
+                        any(
+                            group_id
+                            in prompt.access_control.get(permission, {}).get(
+                                "group_ids", []
+                            )
+                            for group_id in group_ids
+                        )
+                        or (
+                            user_id
+                            in prompt.access_control.get(permission, {}).get(
+                                "user_ids", []
+                            )
+                        )
+                    )
+                )
+            ]
+        elif permission == "read":
+            return [
+                prompt
+                for prompt in prompts
+                if prompt.user_id == user_id
+                or prompt.access_control is None
+                or (
+                    prompt.access_control
+                    and (
+                        any(
+                            prompt.access_control.get(permission, {}).get(
+                                "group_ids", []
+                            )
+                            in group_id
+                            for group_id in group_ids
+                        )
+                        or (
+                            user_id
+                            in prompt.access_control.get(permission, {}).get(
+                                "user_ids", []
+                            )
+                        )
+                    )
+                )
+            ]
+
     def update_prompt_by_command(
     def update_prompt_by_command(
         self, command: str, form_data: PromptForm
         self, command: str, form_data: PromptForm
     ) -> Optional[PromptModel]:
     ) -> Optional[PromptModel]:

+ 64 - 42
backend/open_webui/apps/webui/routers/models.py

@@ -8,49 +8,46 @@ from open_webui.apps.webui.models.models import (
 )
 )
 from open_webui.constants import ERROR_MESSAGES
 from open_webui.constants import ERROR_MESSAGES
 from fastapi import APIRouter, Depends, HTTPException, Request, status
 from fastapi import APIRouter, Depends, HTTPException, Request, status
-from open_webui.utils.utils import get_admin_user, get_verified_user
+
+
+from open_webui.utils.utils import get_admin_user, get_verified_user, has_access
 
 
 router = APIRouter()
 router = APIRouter()
 
 
+
 ###########################
 ###########################
-# getModels
+# GetModels
 ###########################
 ###########################
 
 
 
 
 @router.get("/", response_model=list[ModelResponse])
 @router.get("/", response_model=list[ModelResponse])
 async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)):
 async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)):
-    if id:
-        model = Models.get_model_by_id(id)
-        if model:
-            return [model]
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail=ERROR_MESSAGES.NOT_FOUND,
-            )
+    if user.role == "admin":
+        return Models.get_models()
     else:
     else:
-        return Models.get_all_models()
+        return Models.get_models_by_user_id(user.id)
 
 
 
 
 ############################
 ############################
-# AddNewModel
+# CreateNewModel
 ############################
 ############################
 
 
 
 
-@router.post("/add", response_model=Optional[ModelModel])
-async def add_new_model(
-    request: Request,
+@router.post("/create", response_model=Optional[ModelModel])
+async def create_new_model(
     form_data: ModelForm,
     form_data: ModelForm,
-    user=Depends(get_admin_user),
+    user=Depends(get_verified_user),
 ):
 ):
-    if form_data.id in request.app.state.MODELS:
+
+    model = Models.get_model_by_id(form_data.id)
+    if model:
         raise HTTPException(
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
             status_code=status.HTTP_401_UNAUTHORIZED,
             detail=ERROR_MESSAGES.MODEL_ID_TAKEN,
             detail=ERROR_MESSAGES.MODEL_ID_TAKEN,
         )
         )
+
     else:
     else:
         model = Models.insert_new_model(form_data, user.id)
         model = Models.insert_new_model(form_data, user.id)
-
         if model:
         if model:
             return model
             return model
         else:
         else:
@@ -60,37 +57,49 @@ async def add_new_model(
             )
             )
 
 
 
 
+###########################
+# GetModelById
+###########################
+
+
+@router.get("/id/{id}", response_model=Optional[ModelResponse])
+async def get_model_by_id(id: str, user=Depends(get_verified_user)):
+    model = Models.get_model_by_id(id)
+    if model:
+        if (
+            user.role == "admin"
+            or model.user_id == user.id
+            or has_access(user.id, "read", model.access_control)
+        ):
+            return model
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+
 ############################
 ############################
 # UpdateModelById
 # UpdateModelById
 ############################
 ############################
 
 
 
 
-@router.post("/update", response_model=Optional[ModelModel])
+@router.post("/id/{id}/update", response_model=Optional[ModelModel])
 async def update_model_by_id(
 async def update_model_by_id(
-    request: Request,
     id: str,
     id: str,
     form_data: ModelForm,
     form_data: ModelForm,
-    user=Depends(get_admin_user),
+    user=Depends(get_verified_user),
 ):
 ):
     model = Models.get_model_by_id(id)
     model = Models.get_model_by_id(id)
-    if model:
-        model = Models.update_model_by_id(id, form_data)
-        return model
-    else:
-        if form_data.id in request.app.state.MODELS:
-            model = Models.insert_new_model(form_data, user.id)
-            if model:
-                return model
-            else:
-                raise HTTPException(
-                    status_code=status.HTTP_401_UNAUTHORIZED,
-                    detail=ERROR_MESSAGES.DEFAULT(),
-                )
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail=ERROR_MESSAGES.DEFAULT(),
-            )
+
+    if not model:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+    model = Models.update_model_by_id(id, form_data)
+    return model
 
 
 
 
 ############################
 ############################
@@ -98,7 +107,20 @@ async def update_model_by_id(
 ############################
 ############################
 
 
 
 
-@router.delete("/delete", response_model=bool)
-async def delete_model_by_id(id: str, user=Depends(get_admin_user)):
+@router.delete("/id/{id}/delete", response_model=bool)
+async def delete_model_by_id(id: str, user=Depends(get_verified_user)):
+    model = Models.get_model_by_id(id)
+    if not model:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+    if model.user_id != user.id:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.UNAUTHORIZED,
+        )
+
     result = Models.delete_model_by_id(id)
     result = Models.delete_model_by_id(id)
     return result
     return result

+ 22 - 4
backend/open_webui/apps/webui/routers/users.py

@@ -36,16 +36,34 @@ async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user)
 ############################
 ############################
 
 
 
 
-@router.get("/permissions/user")
+class WorkspacePermissions(BaseModel):
+    models: bool
+    knowledge: bool
+    prompts: bool
+    tools: bool
+
+
+class ChatPermissions(BaseModel):
+    delete: bool
+    edit: bool
+    temporary: bool
+
+
+class UserPermissions(BaseModel):
+    workspace: WorkspacePermissions
+    chat: ChatPermissions
+
+
+@router.get("/permissions")
 async def get_user_permissions(request: Request, user=Depends(get_admin_user)):
 async def get_user_permissions(request: Request, user=Depends(get_admin_user)):
     return request.app.state.config.USER_PERMISSIONS
     return request.app.state.config.USER_PERMISSIONS
 
 
 
 
-@router.post("/permissions/user")
+@router.post("/permissions")
 async def update_user_permissions(
 async def update_user_permissions(
-    request: Request, form_data: dict, user=Depends(get_admin_user)
+    request: Request, form_data: UserPermissions, user=Depends(get_admin_user)
 ):
 ):
-    request.app.state.config.USER_PERMISSIONS = form_data
+    request.app.state.config.USER_PERMISSIONS = form_data.model_dump()
     return request.app.state.config.USER_PERMISSIONS
     return request.app.state.config.USER_PERMISSIONS
 
 
 
 

+ 27 - 1
backend/open_webui/config.py

@@ -739,6 +739,26 @@ DEFAULT_USER_ROLE = PersistentConfig(
     os.getenv("DEFAULT_USER_ROLE", "pending"),
     os.getenv("DEFAULT_USER_ROLE", "pending"),
 )
 )
 
 
+
+USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS = (
+    os.environ.get("USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS", "False").lower()
+    == "true"
+)
+
+USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS = (
+    os.environ.get("USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS", "False").lower()
+    == "true"
+)
+
+USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS = (
+    os.environ.get("USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS", "False").lower()
+    == "true"
+)
+
+USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS = (
+    os.environ.get("USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS", "False").lower() == "true"
+)
+
 USER_PERMISSIONS_CHAT_DELETE = (
 USER_PERMISSIONS_CHAT_DELETE = (
     os.environ.get("USER_PERMISSIONS_CHAT_DELETE", "True").lower() == "true"
     os.environ.get("USER_PERMISSIONS_CHAT_DELETE", "True").lower() == "true"
 )
 )
@@ -755,11 +775,17 @@ USER_PERMISSIONS = PersistentConfig(
     "USER_PERMISSIONS",
     "USER_PERMISSIONS",
     "user.permissions",
     "user.permissions",
     {
     {
+        "workspace": {
+            "models": USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS,
+            "knowledge": USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS,
+            "prompts": USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS,
+            "tools": USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS,
+        },
         "chat": {
         "chat": {
             "deletion": USER_PERMISSIONS_CHAT_DELETE,
             "deletion": USER_PERMISSIONS_CHAT_DELETE,
             "editing": USER_PERMISSIONS_CHAT_EDIT,
             "editing": USER_PERMISSIONS_CHAT_EDIT,
             "temporary": USER_PERMISSIONS_CHAT_TEMPORARY,
             "temporary": USER_PERMISSIONS_CHAT_TEMPORARY,
-        }
+        },
     },
     },
 )
 )
 
 

+ 1 - 1
backend/open_webui/main.py

@@ -993,7 +993,7 @@ async def get_all_models():
 
 
             models.append(
             models.append(
                 {
                 {
-                    "id": custom_model.id,
+                    "id": f"open-webui-{custom_model.id}",
                     "name": custom_model.name,
                     "name": custom_model.name,
                     "object": "model",
                     "object": "model",
                     "created": custom_model.created_at,
                     "created": custom_model.created_at,

+ 62 - 2
backend/open_webui/utils/utils.py

@@ -1,12 +1,18 @@
 import logging
 import logging
 import uuid
 import uuid
+import jwt
+
 from datetime import UTC, datetime, timedelta
 from datetime import UTC, datetime, timedelta
-from typing import Optional, Union
+from typing import Optional, Union, List, Dict
+
 
 
-import jwt
 from open_webui.apps.webui.models.users import Users
 from open_webui.apps.webui.models.users import Users
+from open_webui.apps.webui.models.groups import Groups
+
 from open_webui.constants import ERROR_MESSAGES
 from open_webui.constants import ERROR_MESSAGES
 from open_webui.env import WEBUI_SECRET_KEY
 from open_webui.env import WEBUI_SECRET_KEY
+
+
 from fastapi import Depends, HTTPException, Request, Response, status
 from fastapi import Depends, HTTPException, Request, Response, status
 from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
 from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
 from passlib.context import CryptContext
 from passlib.context import CryptContext
@@ -147,3 +153,57 @@ def get_admin_user(user=Depends(get_current_user)):
             detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
             detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
         )
         )
     return user
     return user
+
+
+def has_permission(
+    user_id: str,
+    permission_key: str,
+    default_permissions: Dict[str, bool] = {},
+) -> bool:
+    """
+    Check if a user has a specific permission by checking the group permissions
+    and falls back to default permissions if not found in any group.
+
+    Permission keys can be hierarchical and separated by dots ('.').
+    """
+
+    def get_permission(permissions: Dict[str, bool], keys: List[str]) -> bool:
+        """Traverse permissions dict using a list of keys (from dot-split permission_key)."""
+        for key in keys:
+            if key not in permissions:
+                return False  # If any part of the hierarchy is missing, deny access
+            permissions = permissions[key]  # Go one level deeper
+
+        return bool(permissions)  # Return the boolean at the final level
+
+    permission_hierarchy = permission_key.split(".")
+
+    # Retrieve user group permissions
+    user_groups = Groups.get_groups_by_member_id(user_id)
+
+    for group in user_groups:
+        group_permissions = group.permissions
+        if get_permission(group_permissions, permission_hierarchy):
+            return True
+
+    # Check default permissions afterwards if the group permissions don't allow it
+    return get_permission(default_permissions, permission_hierarchy)
+
+
+def has_access(
+    user_id: str,
+    action: str = "write",
+    access_control: Optional[dict] = None,
+) -> bool:
+    if access_control is None:
+        return action == "read"
+
+    user_groups = Groups.get_groups_by_member_id(user_id)
+    user_group_ids = [group.id for group in user_groups]
+    permission_access = access_control.get(action, {})
+    permitted_group_ids = permission_access.get("group_ids", [])
+    permitted_user_ids = permission_access.get("user_ids", [])
+
+    return user_id in permitted_user_ids or any(
+        group_id in permitted_group_ids for group_id in user_group_ids
+    )