瀏覽代碼

rewrite oauth role management logic to allow any custom roles to be used for oauth role to open webui role mapping

Willnow, Patrick 6 月之前
父節點
當前提交
edc15d0d7c
共有 3 個文件被更改,包括 105 次插入85 次删除
  1. 2 2
      backend/open_webui/apps/webui/main.py
  2. 14 3
      backend/open_webui/config.py
  3. 89 80
      backend/open_webui/main.py

+ 2 - 2
backend/open_webui/apps/webui/main.py

@@ -32,7 +32,7 @@ from open_webui.config import (
     ENABLE_MESSAGE_RATING,
     ENABLE_MESSAGE_RATING,
     ENABLE_SIGNUP,
     ENABLE_SIGNUP,
     JWT_EXPIRES_IN,
     JWT_EXPIRES_IN,
-    ENABLE_OAUTH_ROLE_MAPPING,
+    ENABLE_OAUTH_ROLE_MANAGEMENT,
     OAUTH_ROLES_CLAIM,
     OAUTH_ROLES_CLAIM,
     OAUTH_EMAIL_CLAIM,
     OAUTH_EMAIL_CLAIM,
     OAUTH_PICTURE_CLAIM,
     OAUTH_PICTURE_CLAIM,
@@ -95,7 +95,7 @@ app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
 app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
 app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
 app.state.config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM
 app.state.config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM
 
 
-app.state.config.ENABLE_OAUTH_ROLE_MAPPING = ENABLE_OAUTH_ROLE_MAPPING
+app.state.config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT
 app.state.config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM
 app.state.config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM
 
 
 app.state.MODELS = {}
 app.state.MODELS = {}

+ 14 - 3
backend/open_webui/config.py

@@ -394,10 +394,10 @@ OAUTH_EMAIL_CLAIM = PersistentConfig(
     os.environ.get("OAUTH_EMAIL_CLAIM", "email"),
     os.environ.get("OAUTH_EMAIL_CLAIM", "email"),
 )
 )
 
 
-ENABLE_OAUTH_ROLE_MAPPING = PersistentConfig(
-    "ENABLE_OAUTH_ROLE_MAPPING",
+ENABLE_OAUTH_ROLE_MANAGEMENT = PersistentConfig(
+    "ENABLE_OAUTH_ROLE_MANAGEMENT",
     "oauth.enable_role_mapping",
     "oauth.enable_role_mapping",
-    os.environ.get("ENABLE_OAUTH_ROLE_MAPPING", "False").lower() == "true",
+    os.environ.get("ENABLE_OAUTH_ROLE_MANAGEMENT", "False").lower() == "true",
 )
 )
 
 
 OAUTH_ROLES_CLAIM = PersistentConfig(
 OAUTH_ROLES_CLAIM = PersistentConfig(
@@ -406,6 +406,17 @@ OAUTH_ROLES_CLAIM = PersistentConfig(
     os.environ.get("OAUTH_ROLES_CLAIM", "roles"),
     os.environ.get("OAUTH_ROLES_CLAIM", "roles"),
 )
 )
 
 
+OAUTH_ALLOWED_ROLES = PersistentConfig(
+    "OAUTH_ALLOWED_ROLES",
+    "oauth.allowed_roles",
+    [role.strip() for role in os.environ.get("OAUTH_ALLOWED_ROLES", "pending,user,admin").split(",")],
+)
+
+OAUTH_ADMIN_ROLES = PersistentConfig(
+    "OAUTH_ADMIN_ROLES",
+    "oauth.admin_roles",
+    [role.strip() for role in os.environ.get("OAUTH_ADMIN_ROLES", "admin").split(",")],
+)
 
 
 def load_oauth_providers():
 def load_oauth_providers():
     OAUTH_PROVIDERS.clear()
     OAUTH_PROVIDERS.clear()

+ 89 - 80
backend/open_webui/main.py

@@ -16,7 +16,6 @@ from typing import Optional
 import aiohttp
 import aiohttp
 import requests
 import requests
 
 
-
 from open_webui.apps.audio.main import app as audio_app
 from open_webui.apps.audio.main import app as audio_app
 from open_webui.apps.images.main import app as images_app
 from open_webui.apps.images.main import app as images_app
 from open_webui.apps.ollama.main import app as ollama_app
 from open_webui.apps.ollama.main import app as ollama_app
@@ -47,11 +46,9 @@ from open_webui.apps.webui.models.models import Models
 from open_webui.apps.webui.models.users import UserModel, Users
 from open_webui.apps.webui.models.users import UserModel, Users
 from open_webui.apps.webui.utils import load_function_module_by_id
 from open_webui.apps.webui.utils import load_function_module_by_id
 
 
-
 from authlib.integrations.starlette_client import OAuth
 from authlib.integrations.starlette_client import OAuth
 from authlib.oidc.core import UserInfo
 from authlib.oidc.core import UserInfo
 
 
-
 from open_webui.config import (
 from open_webui.config import (
     CACHE_DIR,
     CACHE_DIR,
     CORS_ALLOW_ORIGIN,
     CORS_ALLOW_ORIGIN,
@@ -151,7 +148,6 @@ if SAFE_MODE:
     print("SAFE MODE ENABLED")
     print("SAFE MODE ENABLED")
     Functions.deactivate_all_functions()
     Functions.deactivate_all_functions()
 
 
-
 logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
 logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MAIN"])
 log.setLevel(SRC_LOG_LEVELS["MAIN"])
@@ -210,7 +206,6 @@ app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
 
 
 app.state.config.WEBHOOK_URL = WEBHOOK_URL
 app.state.config.WEBHOOK_URL = WEBHOOK_URL
 
 
-
 app.state.config.TASK_MODEL = TASK_MODEL
 app.state.config.TASK_MODEL = TASK_MODEL
 app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL
 app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL
 app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
 app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
@@ -238,14 +233,14 @@ def get_task_model_id(default_model_id):
     # Check if the user has a custom task model and use that model
     # Check if the user has a custom task model and use that model
     if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
     if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
         if (
         if (
-            app.state.config.TASK_MODEL
-            and app.state.config.TASK_MODEL in app.state.MODELS
+                app.state.config.TASK_MODEL
+                and app.state.config.TASK_MODEL in app.state.MODELS
         ):
         ):
             task_model_id = app.state.config.TASK_MODEL
             task_model_id = app.state.config.TASK_MODEL
     else:
     else:
         if (
         if (
-            app.state.config.TASK_MODEL_EXTERNAL
-            and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
+                app.state.config.TASK_MODEL_EXTERNAL
+                and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
         ):
         ):
             task_model_id = app.state.config.TASK_MODEL_EXTERNAL
             task_model_id = app.state.config.TASK_MODEL_EXTERNAL
 
 
@@ -382,7 +377,7 @@ async def get_content_from_response(response) -> Optional[str]:
 
 
 
 
 async def chat_completion_tools_handler(
 async def chat_completion_tools_handler(
-    body: dict, user: UserModel, extra_params: dict
+        body: dict, user: UserModel, extra_params: dict
 ) -> tuple[dict, dict]:
 ) -> tuple[dict, dict]:
     # If tool_ids field is present, call the functions
     # If tool_ids field is present, call the functions
     metadata = body.get("metadata", {})
     metadata = body.get("metadata", {})
@@ -608,8 +603,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             if prompt is None:
             if prompt is None:
                 raise Exception("No user message found")
                 raise Exception("No user message found")
             if (
             if (
-                rag_app.state.config.RELEVANCE_THRESHOLD == 0
-                and context_string.strip() == ""
+                    rag_app.state.config.RELEVANCE_THRESHOLD == 0
+                    and context_string.strip() == ""
             ):
             ):
                 log.debug(
                 log.debug(
                     f"With a 0 relevancy threshold for RAG, the context cannot be empty"
                     f"With a 0 relevancy threshold for RAG, the context cannot be empty"
@@ -676,6 +671,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
 
 
 app.add_middleware(ChatCompletionMiddleware)
 app.add_middleware(ChatCompletionMiddleware)
 
 
+
 ##################################
 ##################################
 #
 #
 # Pipeline Middleware
 # Pipeline Middleware
@@ -688,15 +684,15 @@ def get_sorted_filters(model_id):
         model
         model
         for model in app.state.MODELS.values()
         for model in app.state.MODELS.values()
         if "pipeline" in model
         if "pipeline" in model
-        and "type" in model["pipeline"]
-        and model["pipeline"]["type"] == "filter"
-        and (
-            model["pipeline"]["pipelines"] == ["*"]
-            or any(
-                model_id == target_model_id
-                for target_model_id in model["pipeline"]["pipelines"]
-            )
-        )
+           and "type" in model["pipeline"]
+           and model["pipeline"]["type"] == "filter"
+           and (
+                   model["pipeline"]["pipelines"] == ["*"]
+                   or any(
+               model_id == target_model_id
+               for target_model_id in model["pipeline"]["pipelines"]
+           )
+           )
     ]
     ]
     sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
     sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
     return sorted_filters
     return sorted_filters
@@ -798,7 +794,6 @@ class PipelineMiddleware(BaseHTTPMiddleware):
 
 
 app.add_middleware(PipelineMiddleware)
 app.add_middleware(PipelineMiddleware)
 
 
-
 app.add_middleware(
 app.add_middleware(
     CORSMiddleware,
     CORSMiddleware,
     allow_origins=CORS_ALLOW_ORIGIN,
     allow_origins=CORS_ALLOW_ORIGIN,
@@ -844,8 +839,8 @@ async def update_embedding_function(request: Request, call_next):
 @app.middleware("http")
 @app.middleware("http")
 async def inspect_websocket(request: Request, call_next):
 async def inspect_websocket(request: Request, call_next):
     if (
     if (
-        "/ws/socket.io" in request.url.path
-        and request.query_params.get("transport") == "websocket"
+            "/ws/socket.io" in request.url.path
+            and request.query_params.get("transport") == "websocket"
     ):
     ):
         upgrade = (request.headers.get("Upgrade") or "").lower()
         upgrade = (request.headers.get("Upgrade") or "").lower()
         connection = (request.headers.get("Connection") or "").lower().split(",")
         connection = (request.headers.get("Connection") or "").lower().split(",")
@@ -913,8 +908,8 @@ async def get_all_models():
         if custom_model.base_model_id is None:
         if custom_model.base_model_id is None:
             for model in models:
             for model in models:
                 if (
                 if (
-                    custom_model.id == model["id"]
-                    or custom_model.id == model["id"].split(":")[0]
+                        custom_model.id == model["id"]
+                        or custom_model.id == model["id"].split(":")[0]
                 ):
                 ):
                     model["name"] = custom_model.name
                     model["name"] = custom_model.name
                     model["info"] = custom_model.model_dump()
                     model["info"] = custom_model.model_dump()
@@ -931,8 +926,8 @@ async def get_all_models():
 
 
             for model in models:
             for model in models:
                 if (
                 if (
-                    custom_model.base_model_id == model["id"]
-                    or custom_model.base_model_id == model["id"].split(":")[0]
+                        custom_model.base_model_id == model["id"]
+                        or custom_model.base_model_id == model["id"].split(":")[0]
                 ):
                 ):
                     owned_by = model["owned_by"]
                     owned_by = model["owned_by"]
                     if "pipe" in model:
                     if "pipe" in model:
@@ -1727,7 +1722,7 @@ async def get_pipelines_list(user=Depends(get_admin_user)):
 
 
 @app.post("/api/pipelines/upload")
 @app.post("/api/pipelines/upload")
 async def upload_pipeline(
 async def upload_pipeline(
-    urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user)
+        urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user)
 ):
 ):
     print("upload_pipeline", urlIdx, file.filename)
     print("upload_pipeline", urlIdx, file.filename)
     # Check if the uploaded file is a python file
     # Check if the uploaded file is a python file
@@ -1904,9 +1899,9 @@ async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_use
 
 
 @app.get("/api/pipelines/{pipeline_id}/valves")
 @app.get("/api/pipelines/{pipeline_id}/valves")
 async def get_pipeline_valves(
 async def get_pipeline_valves(
-    urlIdx: Optional[int],
-    pipeline_id: str,
-    user=Depends(get_admin_user),
+        urlIdx: Optional[int],
+        pipeline_id: str,
+        user=Depends(get_admin_user),
 ):
 ):
     r = None
     r = None
     try:
     try:
@@ -1942,9 +1937,9 @@ async def get_pipeline_valves(
 
 
 @app.get("/api/pipelines/{pipeline_id}/valves/spec")
 @app.get("/api/pipelines/{pipeline_id}/valves/spec")
 async def get_pipeline_valves_spec(
 async def get_pipeline_valves_spec(
-    urlIdx: Optional[int],
-    pipeline_id: str,
-    user=Depends(get_admin_user),
+        urlIdx: Optional[int],
+        pipeline_id: str,
+        user=Depends(get_admin_user),
 ):
 ):
     r = None
     r = None
     try:
     try:
@@ -1979,10 +1974,10 @@ async def get_pipeline_valves_spec(
 
 
 @app.post("/api/pipelines/{pipeline_id}/valves/update")
 @app.post("/api/pipelines/{pipeline_id}/valves/update")
 async def update_pipeline_valves(
 async def update_pipeline_valves(
-    urlIdx: Optional[int],
-    pipeline_id: str,
-    form_data: dict,
-    user=Depends(get_admin_user),
+        urlIdx: Optional[int],
+        pipeline_id: str,
+        form_data: dict,
+        user=Depends(get_admin_user),
 ):
 ):
     r = None
     r = None
     try:
     try:
@@ -2106,7 +2101,7 @@ class ModelFilterConfigForm(BaseModel):
 
 
 @app.post("/api/config/model/filter")
 @app.post("/api/config/model/filter")
 async def update_model_filter_config(
 async def update_model_filter_config(
-    form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
+        form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
 ):
 ):
     app.state.config.ENABLE_MODEL_FILTER = form_data.enabled
     app.state.config.ENABLE_MODEL_FILTER = form_data.enabled
     app.state.config.MODEL_FILTER_LIST = form_data.models
     app.state.config.MODEL_FILTER_LIST = form_data.models
@@ -2155,7 +2150,7 @@ async def get_app_latest_release_version():
     try:
     try:
         async with aiohttp.ClientSession(trust_env=True) as session:
         async with aiohttp.ClientSession(trust_env=True) as session:
             async with session.get(
             async with session.get(
-                "https://api.github.com/repos/open-webui/open-webui/releases/latest"
+                    "https://api.github.com/repos/open-webui/open-webui/releases/latest"
             ) as response:
             ) as response:
                 response.raise_for_status()
                 response.raise_for_status()
                 data = await response.json()
                 data = await response.json()
@@ -2198,6 +2193,53 @@ if len(OAUTH_PROVIDERS) > 0:
     )
     )
 
 
 
 
+def get_user_role(user: UserModel, user_data: UserInfo) -> str:
+    if user and Users.get_num_users() == 1:
+        # If the user is the only user, assign the role "admin" - actually repairs role for single user on login
+        return "admin"
+    if not user and Users.get_num_users() == 0:
+        # If there are no users, assign the role "admin", as the first user will be an admin
+        return "admin"
+
+    if webui_app.state.config.ENABLE_OAUTH_ROLE_MANAGEMENT:
+        oauth_claim = webui_app.state.config.OAUTH_ROLES_CLAIM
+        oauth_allowed_roles = webui_app.state.config.OAUTH_ALLOWED_ROLES
+        oauth_admin_roles = webui_app.state.config.OAUTH_ADMIN_ROLES
+        oauth_roles = None
+        role = "pending"  # Default/fallback role if no matching roles are found
+
+        # Next block extracts the roles from the user data, accepting nested claims of any depth
+        if oauth_claim and oauth_allowed_roles and oauth_admin_roles:
+            claim_data = user_data
+            nested_claims = oauth_claim.split(".")
+            for nested_claim in nested_claims:
+                claim_data = claim_data.get(nested_claim, {})
+            oauth_roles = claim_data if isinstance(claim_data, list) else None
+
+        # If any roles are found, check if they match the allowed or admin roles
+        if oauth_roles:
+            # If role management is enabled, and matching roles are provided, use the roles
+            for allowed_role in oauth_allowed_roles:
+                # If the user has any of the allowed roles, assign the role "user"
+                if allowed_role in oauth_roles:
+                    role = "user"
+                    break
+            for admin_role in oauth_admin_roles:
+                # If the user has any of the admin roles, assign the role "admin"
+                if admin_role in oauth_roles:
+                    role = "admin"
+                    break
+    else:
+        if not user:
+            # If role management is disabled, use the default role for new users
+            role = webui_app.state.config.DEFAULT_USER_ROLE
+        else:
+            # If role management is disabled, use the existing role for existing users
+            role = user.role
+
+    return role
+
+
 @app.get("/oauth/{provider}/login")
 @app.get("/oauth/{provider}/login")
 async def oauth_login(provider: str, request: Request):
 async def oauth_login(provider: str, request: Request):
     if provider not in OAUTH_PROVIDERS:
     if provider not in OAUTH_PROVIDERS:
@@ -2244,34 +2286,6 @@ async def oauth_callback(provider: str, request: Request, response: Response):
 
 
     # Check if the user exists
     # Check if the user exists
     user = Users.get_user_by_oauth_sub(provider_sub)
     user = Users.get_user_by_oauth_sub(provider_sub)
-    # print all user data content for debugging
-    log.info(f"User data: {user_data}")
-    if user:
-        role = user.role
-        if Users.get_num_users() == 1:
-            role = "admin"
-        elif webui_app.state.config.ENABLE_OAUTH_ROLE_MAPPING:
-            oauth_claim = webui_app.state.config.OAUTH_ROLES_CLAIM
-            oauth_roles = None
-
-            if oauth_claim:
-                claim_data = user_data
-                nested_claims = oauth_claim.split(".")
-                for nested_claim in nested_claims:
-                    claim_data = claim_data.get(nested_claim, {})
-                oauth_roles = claim_data if isinstance(claim_data, list) else None
-
-            log.info(f"User {user.name} has OAuth roles: {oauth_roles}")
-            if oauth_roles:
-                for allowed_role in ["pending", "user", "admin"]:
-                    role = allowed_role if allowed_role in oauth_roles else role
-            else:
-                # If role mapping is enabled, but no roles are provided, fall back to pending
-                role = "pending"
-            log.info(f"Applied role: {role} to user {user.name}")
-
-        if role != user.role:
-            Users.update_user_role_by_id(user.id, role)
 
 
     if not user:
     if not user:
         # If the user does not exist, check if merging is enabled
         # If the user does not exist, check if merging is enabled
@@ -2282,6 +2296,11 @@ async def oauth_callback(provider: str, request: Request, response: Response):
                 # Update the user with the new oauth sub
                 # Update the user with the new oauth sub
                 Users.update_user_oauth_sub_by_id(user.id, provider_sub)
                 Users.update_user_oauth_sub_by_id(user.id, provider_sub)
 
 
+    if user:
+        determined_role = get_user_role(user, user_data)
+        if user.role != determined_role:
+            Users.update_user_role_by_id(user.id, determined_role)
+
     if not user:
     if not user:
         # If the user does not exist, check if signups are enabled
         # If the user does not exist, check if signups are enabled
         if ENABLE_OAUTH_SIGNUP.value:
         if ENABLE_OAUTH_SIGNUP.value:
@@ -2313,17 +2332,7 @@ async def oauth_callback(provider: str, request: Request, response: Response):
                 picture_url = "/user.png"
                 picture_url = "/user.png"
             username_claim = webui_app.state.config.OAUTH_USERNAME_CLAIM
             username_claim = webui_app.state.config.OAUTH_USERNAME_CLAIM
 
 
-            role = webui_app.state.config.DEFAULT_USER_ROLE
-            if Users.get_num_users() == 0:
-                role = "admin"
-            elif webui_app.state.config.ENABLE_OAUTH_ROLE_MAPPING:
-                oauth_roles = user_data.get(webui_app.state.config.OAUTH_ROLE_CLAIM)
-                if oauth_roles:
-                    for allowed_role in ["pending", "user", "admin"]:
-                        role = allowed_role if allowed_role in oauth_roles else role
-                else:
-                    # If role mapping is enabled, but no roles are provided, fall back to pending
-                    role = "pending"
+            role = get_user_role(None, user_data)
 
 
             user = Auths.insert_new_auth(
             user = Auths.insert_new_auth(
                 email=email,
                 email=email,