瀏覽代碼

enh: user permissions

Timothy Jaeryang Baek 5 月之前
父節點
當前提交
c0371f6525

+ 2 - 0
backend/open_webui/apps/webui/main.py

@@ -110,6 +110,8 @@ app.state.config.ADMIN_EMAIL = ADMIN_EMAIL
 app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
 app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
 app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
+
+
 app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
 app.state.config.WEBHOOK_URL = WEBHOOK_URL
 app.state.config.BANNERS = WEBUI_BANNERS

+ 95 - 31
backend/open_webui/apps/webui/routers/auths.py

@@ -40,10 +40,12 @@ from open_webui.utils.utils import (
     get_password_hash,
 )
 from open_webui.utils.webhook import post_webhook
+from open_webui.utils.access_control import get_permissions
+
 from typing import Optional, List
 
-from ldap3 import Server, Connection, ALL, Tls
 from ssl import CERT_REQUIRED, PROTOCOL_TLS
+from ldap3 import Server, Connection, ALL, Tls
 from ldap3.utils.conv import escape_filter_chars
 
 router = APIRouter()
@@ -58,6 +60,7 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"])
 
 class SessionUserResponse(Token, UserResponse):
     expires_at: Optional[int] = None
+    permissions: Optional[dict] = None
 
 
 @router.get("/", response_model=SessionUserResponse)
@@ -90,6 +93,10 @@ async def get_session_user(
         secure=WEBUI_SESSION_COOKIE_SECURE,
     )
 
+    user_permissions = get_permissions(
+        user.id, request.app.state.config.USER_PERMISSIONS
+    )
+
     return {
         "token": token,
         "token_type": "Bearer",
@@ -99,6 +106,7 @@ async def get_session_user(
         "name": user.name,
         "role": user.role,
         "profile_image_url": user.profile_image_url,
+        "permissions": user_permissions,
     }
 
 
@@ -163,40 +171,67 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
     LDAP_APP_PASSWORD = request.app.state.config.LDAP_APP_PASSWORD
     LDAP_USE_TLS = request.app.state.config.LDAP_USE_TLS
     LDAP_CA_CERT_FILE = request.app.state.config.LDAP_CA_CERT_FILE
-    LDAP_CIPHERS = request.app.state.config.LDAP_CIPHERS if request.app.state.config.LDAP_CIPHERS else 'ALL'
+    LDAP_CIPHERS = (
+        request.app.state.config.LDAP_CIPHERS
+        if request.app.state.config.LDAP_CIPHERS
+        else "ALL"
+    )
 
     if not ENABLE_LDAP:
         raise HTTPException(400, detail="LDAP authentication is not enabled")
 
     try:
-        tls = Tls(validate=CERT_REQUIRED, version=PROTOCOL_TLS, ca_certs_file=LDAP_CA_CERT_FILE, ciphers=LDAP_CIPHERS)
+        tls = Tls(
+            validate=CERT_REQUIRED,
+            version=PROTOCOL_TLS,
+            ca_certs_file=LDAP_CA_CERT_FILE,
+            ciphers=LDAP_CIPHERS,
+        )
     except Exception as e:
         log.error(f"An error occurred on TLS: {str(e)}")
         raise HTTPException(400, detail=str(e))
 
     try:
-        server = Server(host=LDAP_SERVER_HOST, port=LDAP_SERVER_PORT, get_info=ALL, use_ssl=LDAP_USE_TLS, tls=tls)
-        connection_app = Connection(server, LDAP_APP_DN, LDAP_APP_PASSWORD, auto_bind='NONE', authentication='SIMPLE')
+        server = Server(
+            host=LDAP_SERVER_HOST,
+            port=LDAP_SERVER_PORT,
+            get_info=ALL,
+            use_ssl=LDAP_USE_TLS,
+            tls=tls,
+        )
+        connection_app = Connection(
+            server,
+            LDAP_APP_DN,
+            LDAP_APP_PASSWORD,
+            auto_bind="NONE",
+            authentication="SIMPLE",
+        )
         if not connection_app.bind():
             raise HTTPException(400, detail="Application account bind failed")
 
         search_success = connection_app.search(
             search_base=LDAP_SEARCH_BASE,
-            search_filter=f'(&({LDAP_ATTRIBUTE_FOR_USERNAME}={escape_filter_chars(form_data.user.lower())}){LDAP_SEARCH_FILTERS})',
-            attributes=[f'{LDAP_ATTRIBUTE_FOR_USERNAME}', 'mail', 'cn']
+            search_filter=f"(&({LDAP_ATTRIBUTE_FOR_USERNAME}={escape_filter_chars(form_data.user.lower())}){LDAP_SEARCH_FILTERS})",
+            attributes=[f"{LDAP_ATTRIBUTE_FOR_USERNAME}", "mail", "cn"],
         )
 
         if not search_success:
             raise HTTPException(400, detail="User not found in the LDAP server")
 
         entry = connection_app.entries[0]
-        username = str(entry[f'{LDAP_ATTRIBUTE_FOR_USERNAME}']).lower()
-        mail = str(entry['mail'])
-        cn = str(entry['cn'])
+        username = str(entry[f"{LDAP_ATTRIBUTE_FOR_USERNAME}"]).lower()
+        mail = str(entry["mail"])
+        cn = str(entry["cn"])
         user_dn = entry.entry_dn
 
         if username == form_data.user.lower():
-            connection_user = Connection(server, user_dn, form_data.password, auto_bind='NONE', authentication='SIMPLE')
+            connection_user = Connection(
+                server,
+                user_dn,
+                form_data.password,
+                auto_bind="NONE",
+                authentication="SIMPLE",
+            )
             if not connection_user.bind():
                 raise HTTPException(400, f"Authentication failed for {form_data.user}")
 
@@ -205,14 +240,12 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
 
                 try:
                     hashed = get_password_hash(form_data.password)
-                    user = Auths.insert_new_auth(
-                        mail,
-                        hashed,
-                        cn
-                    )
+                    user = Auths.insert_new_auth(mail, hashed, cn)
 
                     if not user:
-                        raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
+                        raise HTTPException(
+                            500, detail=ERROR_MESSAGES.CREATE_USER_ERROR
+                        )
 
                 except HTTPException:
                     raise
@@ -224,7 +257,9 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
             if user:
                 token = create_token(
                     data={"id": user.id},
-                    expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
+                    expires_delta=parse_duration(
+                        request.app.state.config.JWT_EXPIRES_IN
+                    ),
                 )
 
                 # Set the cookie token
@@ -246,7 +281,10 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
             else:
                 raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
         else:
-            raise HTTPException(400, f"User {form_data.user} does not match the record. Search result: {str(entry[f'{LDAP_ATTRIBUTE_FOR_USERNAME}'])}")
+            raise HTTPException(
+                400,
+                f"User {form_data.user} does not match the record. Search result: {str(entry[f'{LDAP_ATTRIBUTE_FOR_USERNAME}'])}",
+            )
     except Exception as e:
         raise HTTPException(400, detail=str(e))
 
@@ -325,6 +363,10 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
             secure=WEBUI_SESSION_COOKIE_SECURE,
         )
 
+        user_permissions = get_permissions(
+            user.id, request.app.state.config.USER_PERMISSIONS
+        )
+
         return {
             "token": token,
             "token_type": "Bearer",
@@ -334,6 +376,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
             "name": user.name,
             "role": user.role,
             "profile_image_url": user.profile_image_url,
+            "permissions": user_permissions,
         }
     else:
         raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
@@ -426,6 +469,10 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
                     },
                 )
 
+            user_permissions = get_permissions(
+                user.id, request.app.state.config.USER_PERMISSIONS
+            )
+
             return {
                 "token": token,
                 "token_type": "Bearer",
@@ -435,6 +482,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
                 "name": user.name,
                 "role": user.role,
                 "profile_image_url": user.profile_image_url,
+                "permissions": user_permissions,
             }
         else:
             raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
@@ -583,19 +631,18 @@ class LdapServerConfig(BaseModel):
     label: str
     host: str
     port: Optional[int] = None
-    attribute_for_username: str = 'uid'
+    attribute_for_username: str = "uid"
     app_dn: str
     app_dn_password: str
     search_base: str
-    search_filters: str = ''
+    search_filters: str = ""
     use_tls: bool = True
     certificate_path: Optional[str] = None
-    ciphers: Optional[str] = 'ALL'
+    ciphers: Optional[str] = "ALL"
+
 
 @router.get("/admin/config/ldap/server", response_model=LdapServerConfig)
-async def get_ldap_server(
-    request: Request, user=Depends(get_admin_user)
-):
+async def get_ldap_server(request: Request, user=Depends(get_admin_user)):
     return {
         "label": request.app.state.config.LDAP_SERVER_LABEL,
         "host": request.app.state.config.LDAP_SERVER_HOST,
@@ -607,26 +654,38 @@ async def get_ldap_server(
         "search_filters": request.app.state.config.LDAP_SEARCH_FILTERS,
         "use_tls": request.app.state.config.LDAP_USE_TLS,
         "certificate_path": request.app.state.config.LDAP_CA_CERT_FILE,
-        "ciphers": request.app.state.config.LDAP_CIPHERS
+        "ciphers": request.app.state.config.LDAP_CIPHERS,
     }
 
+
 @router.post("/admin/config/ldap/server")
 async def update_ldap_server(
     request: Request, form_data: LdapServerConfig, user=Depends(get_admin_user)
 ):
-    required_fields = ['label', 'host', 'attribute_for_username', 'app_dn', 'app_dn_password', 'search_base']
+    required_fields = [
+        "label",
+        "host",
+        "attribute_for_username",
+        "app_dn",
+        "app_dn_password",
+        "search_base",
+    ]
     for key in required_fields:
         value = getattr(form_data, key)
         if not value:
             raise HTTPException(400, detail=f"Required field {key} is empty")
 
     if form_data.use_tls and not form_data.certificate_path:
-        raise HTTPException(400, detail="TLS is enabled but certificate file path is missing")
+        raise HTTPException(
+            400, detail="TLS is enabled but certificate file path is missing"
+        )
 
     request.app.state.config.LDAP_SERVER_LABEL = form_data.label
     request.app.state.config.LDAP_SERVER_HOST = form_data.host
     request.app.state.config.LDAP_SERVER_PORT = form_data.port
-    request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME = form_data.attribute_for_username
+    request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME = (
+        form_data.attribute_for_username
+    )
     request.app.state.config.LDAP_APP_DN = form_data.app_dn
     request.app.state.config.LDAP_APP_PASSWORD = form_data.app_dn_password
     request.app.state.config.LDAP_SEARCH_BASE = form_data.search_base
@@ -646,18 +705,23 @@ async def update_ldap_server(
         "search_filters": request.app.state.config.LDAP_SEARCH_FILTERS,
         "use_tls": request.app.state.config.LDAP_USE_TLS,
         "certificate_path": request.app.state.config.LDAP_CA_CERT_FILE,
-        "ciphers": request.app.state.config.LDAP_CIPHERS
+        "ciphers": request.app.state.config.LDAP_CIPHERS,
     }
 
+
 @router.get("/admin/config/ldap")
 async def get_ldap_config(request: Request, user=Depends(get_admin_user)):
     return {"ENABLE_LDAP": request.app.state.config.ENABLE_LDAP}
 
+
 class LdapConfigForm(BaseModel):
     enable_ldap: Optional[bool] = None
 
+
 @router.post("/admin/config/ldap")
-async def update_ldap_config(request: Request, form_data: LdapConfigForm, user=Depends(get_admin_user)):
+async def update_ldap_config(
+    request: Request, form_data: LdapConfigForm, user=Depends(get_admin_user)
+):
     request.app.state.config.ENABLE_LDAP = form_data.enable_ldap
     return {"ENABLE_LDAP": request.app.state.config.ENABLE_LDAP}
 

+ 32 - 0
backend/open_webui/utils/access_control.py

@@ -2,6 +2,38 @@ from typing import Optional, Union, List, Dict
 from open_webui.apps.webui.models.groups import Groups
 
 
+def get_permissions(
+    user_id: str,
+    default_permissions: Dict[str, bool] = {},
+) -> dict:
+    """
+    Get all permissions for a user by combining the permissions of all groups the user is a member of.
+    If a permission is defined in multiple groups, the most permissive value is used.
+    """
+
+    def merge_permissions(
+        permissions: Dict[str, bool], new_permissions: Dict[str, bool]
+    ) -> Dict[str, bool]:
+        """Merge two permission dictionaries, keeping the most permissive value."""
+        for key, value in new_permissions.items():
+            if key not in permissions:
+                permissions[key] = value
+            else:
+                permissions[key] = (
+                    permissions[key] or value
+                )  # Use the most permissive value
+
+        return permissions
+
+    user_groups = Groups.get_groups_by_member_id(user_id)
+    user_permissions = default_permissions.copy()
+
+    for group in user_groups:
+        user_permissions = merge_permissions(user_permissions, group.permissions)
+
+    return user_permissions
+
+
 def has_permission(
     user_id: str,
     permission_key: str,