Prechádzať zdrojové kódy

WIP
- refactoring oauth functions to enable refresh functionality

Willnow, Patrick 6 mesiacov pred
rodič
commit
08ff494754

+ 1 - 1
backend/open_webui/apps/webui/main.py

@@ -44,7 +44,7 @@ from open_webui.config import (
     WEBHOOK_URL,
     WEBHOOK_URL,
     WEBUI_AUTH,
     WEBUI_AUTH,
     WEBUI_BANNERS,
     WEBUI_BANNERS,
-    AppConfig, OAUTH_ALLOWED_ROLES, OAUTH_ADMIN_ROLES,
+    AppConfig,
 )
 )
 from open_webui.env import (
 from open_webui.env import (
     WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
     WEBUI_AUTH_TRUSTED_EMAIL_HEADER,

+ 4 - 186
backend/open_webui/main.py

@@ -46,8 +46,6 @@ from open_webui.apps.webui.models.models import Models
 from open_webui.apps.webui.models.users import UserModel, Users
 from open_webui.apps.webui.models.users import UserModel, Users
 from open_webui.apps.webui.utils import load_function_module_by_id
 from open_webui.apps.webui.utils import load_function_module_by_id
 
 
-from authlib.integrations.starlette_client import OAuth
-from authlib.oidc.core import UserInfo
 
 
 from open_webui.config import (
 from open_webui.config import (
     CACHE_DIR,
     CACHE_DIR,
@@ -144,6 +142,8 @@ from open_webui.utils.response import (
     convert_streaming_response_ollama_to_openai,
     convert_streaming_response_ollama_to_openai,
 )
 )
 
 
+from open_webui.utils.oauth import oauth_manager
+
 if SAFE_MODE:
 if SAFE_MODE:
     print("SAFE MODE ENABLED")
     print("SAFE MODE ENABLED")
     Functions.deactivate_all_functions()
     Functions.deactivate_all_functions()
@@ -2168,20 +2168,6 @@ async def get_app_latest_release_version():
 # OAuth Login & Callback
 # OAuth Login & Callback
 ############################
 ############################
 
 
-oauth = OAuth()
-
-for provider_name, provider_config in OAUTH_PROVIDERS.items():
-    oauth.register(
-        name=provider_name,
-        client_id=provider_config["client_id"],
-        client_secret=provider_config["client_secret"],
-        server_metadata_url=provider_config["server_metadata_url"],
-        client_kwargs={
-            "scope": provider_config["scope"],
-        },
-        redirect_uri=provider_config["redirect_uri"],
-    )
-
 # SessionMiddleware is used by authlib for oauth
 # SessionMiddleware is used by authlib for oauth
 if len(OAUTH_PROVIDERS) > 0:
 if len(OAUTH_PROVIDERS) > 0:
     app.add_middleware(
     app.add_middleware(
@@ -2193,65 +2179,9 @@ if len(OAUTH_PROVIDERS) > 0:
     )
     )
 
 
 
 
-def get_user_role(user: UserModel, user_data: UserInfo) -> str:
-    if user and Users.get_num_users() == 1:
-        # If the user is the only user, assign the role "admin" - actually repairs role for single user on login
-        return "admin"
-    if not user and Users.get_num_users() == 0:
-        # If there are no users, assign the role "admin", as the first user will be an admin
-        return "admin"
-
-    if webui_app.state.config.ENABLE_OAUTH_ROLE_MANAGEMENT:
-        oauth_claim = webui_app.state.config.OAUTH_ROLES_CLAIM
-        oauth_allowed_roles = webui_app.state.config.OAUTH_ALLOWED_ROLES
-        oauth_admin_roles = webui_app.state.config.OAUTH_ADMIN_ROLES
-        oauth_roles = None
-        role = "pending"  # Default/fallback role if no matching roles are found
-
-        # Next block extracts the roles from the user data, accepting nested claims of any depth
-        if oauth_claim and oauth_allowed_roles and oauth_admin_roles:
-            claim_data = user_data
-            nested_claims = oauth_claim.split(".")
-            for nested_claim in nested_claims:
-                claim_data = claim_data.get(nested_claim, {})
-            oauth_roles = claim_data if isinstance(claim_data, list) else None
-
-        # If any roles are found, check if they match the allowed or admin roles
-        if oauth_roles:
-            # If role management is enabled, and matching roles are provided, use the roles
-            for allowed_role in oauth_allowed_roles:
-                # If the user has any of the allowed roles, assign the role "user"
-                if allowed_role in oauth_roles:
-                    role = "user"
-                    break
-            for admin_role in oauth_admin_roles:
-                # If the user has any of the admin roles, assign the role "admin"
-                if admin_role in oauth_roles:
-                    role = "admin"
-                    break
-    else:
-        if not user:
-            # If role management is disabled, use the default role for new users
-            role = webui_app.state.config.DEFAULT_USER_ROLE
-        else:
-            # If role management is disabled, use the existing role for existing users
-            role = user.role
-
-    return role
-
-
 @app.get("/oauth/{provider}/login")
 @app.get("/oauth/{provider}/login")
 async def oauth_login(provider: str, request: Request):
 async def oauth_login(provider: str, request: Request):
-    if provider not in OAUTH_PROVIDERS:
-        raise HTTPException(404)
-    # If the provider has a custom redirect URL, use that, otherwise automatically generate one
-    redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for(
-        "oauth_callback", provider=provider
-    )
-    client = oauth.create_client(provider)
-    if client is None:
-        raise HTTPException(404)
-    return await client.authorize_redirect(request, redirect_uri)
+    return oauth_manager.handle_login(provider, request)
 
 
 
 
 # OAuth login logic is as follows:
 # OAuth login logic is as follows:
@@ -2262,119 +2192,7 @@ async def oauth_login(provider: str, request: Request):
 #    - Email addresses are considered unique, so we fail registration if the email address is alreayd taken
 #    - Email addresses are considered unique, so we fail registration if the email address is alreayd taken
 @app.get("/oauth/{provider}/callback")
 @app.get("/oauth/{provider}/callback")
 async def oauth_callback(provider: str, request: Request, response: Response):
 async def oauth_callback(provider: str, request: Request, response: Response):
-    if provider not in OAUTH_PROVIDERS:
-        raise HTTPException(404)
-    client = oauth.create_client(provider)
-    try:
-        token = await client.authorize_access_token(request)
-    except Exception as e:
-        log.warning(f"OAuth callback error: {e}")
-        raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
-    user_data: UserInfo = token["userinfo"]
-
-    sub = user_data.get("sub")
-    if not sub:
-        log.warning(f"OAuth callback failed, sub is missing: {user_data}")
-        raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
-    provider_sub = f"{provider}@{sub}"
-    email_claim = webui_app.state.config.OAUTH_EMAIL_CLAIM
-    email = user_data.get(email_claim, "").lower()
-    # We currently mandate that email addresses are provided
-    if not email:
-        log.warning(f"OAuth callback failed, email is missing: {user_data}")
-        raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
-
-    # Check if the user exists
-    user = Users.get_user_by_oauth_sub(provider_sub)
-
-    if not user:
-        # If the user does not exist, check if merging is enabled
-        if OAUTH_MERGE_ACCOUNTS_BY_EMAIL.value:
-            # Check if the user exists by email
-            user = Users.get_user_by_email(email)
-            if user:
-                # Update the user with the new oauth sub
-                Users.update_user_oauth_sub_by_id(user.id, provider_sub)
-
-    if user:
-        determined_role = get_user_role(user, user_data)
-        if user.role != determined_role:
-            Users.update_user_role_by_id(user.id, determined_role)
-
-    if not user:
-        # If the user does not exist, check if signups are enabled
-        if ENABLE_OAUTH_SIGNUP.value:
-            # Check if an existing user with the same email already exists
-            existing_user = Users.get_user_by_email(user_data.get("email", "").lower())
-            if existing_user:
-                raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
-
-            picture_claim = webui_app.state.config.OAUTH_PICTURE_CLAIM
-            picture_url = user_data.get(picture_claim, "")
-            if picture_url:
-                # Download the profile image into a base64 string
-                try:
-                    async with aiohttp.ClientSession() as session:
-                        async with session.get(picture_url) as resp:
-                            picture = await resp.read()
-                            base64_encoded_picture = base64.b64encode(picture).decode(
-                                "utf-8"
-                            )
-                            guessed_mime_type = mimetypes.guess_type(picture_url)[0]
-                            if guessed_mime_type is None:
-                                # assume JPG, browsers are tolerant enough of image formats
-                                guessed_mime_type = "image/jpeg"
-                            picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}"
-                except Exception as e:
-                    log.error(f"Error downloading profile image '{picture_url}': {e}")
-                    picture_url = ""
-            if not picture_url:
-                picture_url = "/user.png"
-            username_claim = webui_app.state.config.OAUTH_USERNAME_CLAIM
-
-            role = get_user_role(None, user_data)
-
-            user = Auths.insert_new_auth(
-                email=email,
-                password=get_password_hash(
-                    str(uuid.uuid4())
-                ),  # Random password, not used
-                name=user_data.get(username_claim, "User"),
-                profile_image_url=picture_url,
-                role=role,
-                oauth_sub=provider_sub,
-            )
-
-            if webui_app.state.config.WEBHOOK_URL:
-                post_webhook(
-                    webui_app.state.config.WEBHOOK_URL,
-                    WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
-                    {
-                        "action": "signup",
-                        "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
-                        "user": user.model_dump_json(exclude_none=True),
-                    },
-                )
-        else:
-            raise HTTPException(
-                status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
-            )
-
-    jwt_token = create_token(
-        data={"id": user.id},
-        expires_delta=parse_duration(webui_app.state.config.JWT_EXPIRES_IN),
-    )
-
-    # Set the cookie token
-    response.set_cookie(
-        key="token",
-        value=jwt_token,
-        httponly=True,  # Ensures the cookie is not accessible via JavaScript
-    )
-
-    # Redirect back to the frontend with the JWT token
-    redirect_url = f"{request.base_url}auth#token={jwt_token}"
-    return RedirectResponse(url=redirect_url)
+    return oauth_manager.handle_callback(provider, request, response)
 
 
 
 
 @app.get("/manifest.json")
 @app.get("/manifest.json")

+ 234 - 0
backend/open_webui/utils/oauth.py

@@ -0,0 +1,234 @@
+import base64
+import mimetypes
+import uuid
+
+import aiohttp
+import logging
+from fastapi import (
+    HTTPException,
+    Request,
+    status,
+)
+from starlette.responses import RedirectResponse, Response, StreamingResponse
+from authlib.oidc.core import UserInfo
+
+from open_webui.apps.webui.models.auths import Auths
+from open_webui.apps.webui.models.users import Users, UserModel
+from open_webui.config import (
+    DEFAULT_USER_ROLE,
+    ENABLE_OAUTH_SIGNUP,
+    OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
+    OAUTH_PROVIDERS,
+    ENABLE_OAUTH_ROLE_MANAGEMENT,
+    OAUTH_ROLES_CLAIM,
+    OAUTH_EMAIL_CLAIM,
+    OAUTH_PICTURE_CLAIM,
+    OAUTH_USERNAME_CLAIM,
+    OAUTH_ALLOWED_ROLES,
+    OAUTH_ADMIN_ROLES, WEBHOOK_URL, JWT_EXPIRES_IN,
+)
+
+from authlib.integrations.starlette_client import OAuth
+
+from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
+from open_webui.utils.misc import parse_duration
+from open_webui.utils.utils import get_password_hash, create_token
+from open_webui.utils.webhook import post_webhook
+
+log = logging.getLogger(__name__)
+
+oauth_manager = {}
+oauth_manager.oauth = OAuth()
+
+for provider_name, provider_config in OAUTH_PROVIDERS.items():
+    oauth_manager.oauth.register(
+        name=provider_name,
+        client_id=provider_config["client_id"],
+        client_secret=provider_config["client_secret"],
+        server_metadata_url=provider_config["server_metadata_url"],
+        client_kwargs={
+            "scope": provider_config["scope"],
+        },
+        redirect_uri=provider_config["redirect_uri"],
+    )
+
+oauth_manager.get_client = oauth_manager.oauth.create_client
+
+def get_user_role(user: UserModel, user_data: UserInfo) -> str:
+    if user and Users.get_num_users() == 1:
+        # If the user is the only user, assign the role "admin" - actually repairs role for single user on login
+        return "admin"
+    if not user and Users.get_num_users() == 0:
+        # If there are no users, assign the role "admin", as the first user will be an admin
+        return "admin"
+
+    if ENABLE_OAUTH_ROLE_MANAGEMENT:
+        oauth_claim = OAUTH_ROLES_CLAIM
+        oauth_allowed_roles = OAUTH_ALLOWED_ROLES
+        oauth_admin_roles = OAUTH_ADMIN_ROLES
+        oauth_roles = None
+        role = "pending"  # Default/fallback role if no matching roles are found
+
+        # Next block extracts the roles from the user data, accepting nested claims of any depth
+        if oauth_claim and oauth_allowed_roles and oauth_admin_roles:
+            claim_data = user_data
+            nested_claims = oauth_claim.split(".")
+            for nested_claim in nested_claims:
+                claim_data = claim_data.get(nested_claim, {})
+            oauth_roles = claim_data if isinstance(claim_data, list) else None
+
+        # If any roles are found, check if they match the allowed or admin roles
+        if oauth_roles:
+            # If role management is enabled, and matching roles are provided, use the roles
+            for allowed_role in oauth_allowed_roles:
+                # If the user has any of the allowed roles, assign the role "user"
+                if allowed_role in oauth_roles:
+                    role = "user"
+                    break
+            for admin_role in oauth_admin_roles:
+                # If the user has any of the admin roles, assign the role "admin"
+                if admin_role in oauth_roles:
+                    role = "admin"
+                    break
+    else:
+        if not user:
+            # If role management is disabled, use the default role for new users
+            role = DEFAULT_USER_ROLE
+        else:
+            # If role management is disabled, use the existing role for existing users
+            role = user.role
+
+    return role
+
+oauth_manager.get_user_role = get_user_role
+
+async def handle_login(provider: str, request: Request):
+    if provider not in OAUTH_PROVIDERS:
+        raise HTTPException(404)
+    # If the provider has a custom redirect URL, use that, otherwise automatically generate one
+    redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for(
+        "oauth_callback", provider=provider
+    )
+    client = oauth_manager.get_client(provider)
+    if client is None:
+        raise HTTPException(404)
+    return await client.authorize_redirect(request, redirect_uri)
+
+oauth_manager.handle_login = handle_login
+
+async def handle_callback(provider: str, request: Request, response: Response):
+    if provider not in OAUTH_PROVIDERS:
+        raise HTTPException(404)
+    client = oauth_manager.get_client(provider)
+    try:
+        token = await client.authorize_access_token(request)
+    except Exception as e:
+        log.warning(f"OAuth callback error: {e}")
+        raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
+    user_data: UserInfo = token["userinfo"]
+
+    sub = user_data.get("sub")
+    if not sub:
+        log.warning(f"OAuth callback failed, sub is missing: {user_data}")
+        raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
+    provider_sub = f"{provider}@{sub}"
+    email_claim = OAUTH_EMAIL_CLAIM
+    email = user_data.get(email_claim, "").lower()
+    # We currently mandate that email addresses are provided
+    if not email:
+        log.warning(f"OAuth callback failed, email is missing: {user_data}")
+        raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
+
+    # Check if the user exists
+    user = Users.get_user_by_oauth_sub(provider_sub)
+
+    if not user:
+        # If the user does not exist, check if merging is enabled
+        if OAUTH_MERGE_ACCOUNTS_BY_EMAIL.value:
+            # Check if the user exists by email
+            user = Users.get_user_by_email(email)
+            if user:
+                # Update the user with the new oauth sub
+                Users.update_user_oauth_sub_by_id(user.id, provider_sub)
+
+    if user:
+        determined_role = get_user_role(user, user_data)
+        if user.role != determined_role:
+            Users.update_user_role_by_id(user.id, determined_role)
+
+    if not user:
+        # If the user does not exist, check if signups are enabled
+        if ENABLE_OAUTH_SIGNUP.value:
+            # Check if an existing user with the same email already exists
+            existing_user = Users.get_user_by_email(user_data.get("email", "").lower())
+            if existing_user:
+                raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
+
+            picture_claim = OAUTH_PICTURE_CLAIM
+            picture_url = user_data.get(picture_claim, "")
+            if picture_url:
+                # Download the profile image into a base64 string
+                try:
+                    async with aiohttp.ClientSession() as session:
+                        async with session.get(picture_url) as resp:
+                            picture = await resp.read()
+                            base64_encoded_picture = base64.b64encode(picture).decode(
+                                "utf-8"
+                            )
+                            guessed_mime_type = mimetypes.guess_type(picture_url)[0]
+                            if guessed_mime_type is None:
+                                # assume JPG, browsers are tolerant enough of image formats
+                                guessed_mime_type = "image/jpeg"
+                            picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}"
+                except Exception as e:
+                    log.error(f"Error downloading profile image '{picture_url}': {e}")
+                    picture_url = ""
+            if not picture_url:
+                picture_url = "/user.png"
+            username_claim = OAUTH_USERNAME_CLAIM
+
+            role = get_user_role(None, user_data)
+
+            user = Auths.insert_new_auth(
+                email=email,
+                password=get_password_hash(
+                    str(uuid.uuid4())
+                ),  # Random password, not used
+                name=user_data.get(username_claim, "User"),
+                profile_image_url=picture_url,
+                role=role,
+                oauth_sub=provider_sub,
+            )
+
+            if WEBHOOK_URL:
+                post_webhook(
+                    WEBHOOK_URL,
+                    WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
+                    {
+                        "action": "signup",
+                        "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
+                        "user": user.model_dump_json(exclude_none=True),
+                    },
+                )
+        else:
+            raise HTTPException(
+                status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
+            )
+
+    jwt_token = create_token(
+        data={"id": user.id},
+        expires_delta=parse_duration(JWT_EXPIRES_IN),
+    )
+
+    # Set the cookie token
+    response.set_cookie(
+        key="token",
+        value=jwt_token,
+        httponly=True,  # Ensures the cookie is not accessible via JavaScript
+    )
+
+    # Redirect back to the frontend with the JWT token
+    redirect_url = f"{request.base_url}auth#token={jwt_token}"
+    return RedirectResponse(url=redirect_url)
+
+oauth_manager.handle_callback = handle_callback