Browse Source

Merge pull request #3569 from Semihal/custom-openid-claims

feat: Custom claims for OAuth
Timothy Jaeryang Baek 10 months ago
parent
commit
08c024d752
3 changed files with 21 additions and 2 deletions
  1. 5 0
      backend/apps/webui/main.py
  2. 12 0
      backend/config.py
  3. 4 2
      backend/main.py

+ 5 - 0
backend/apps/webui/main.py

@@ -39,6 +39,8 @@ from config import (
     WEBUI_BANNERS,
     ENABLE_COMMUNITY_SHARING,
     AppConfig,
+    OAUTH_USERNAME_CLAIM,
+    OAUTH_PICTURE_CLAIM,
 )
 
 import inspect
@@ -74,6 +76,9 @@ app.state.config.BANNERS = WEBUI_BANNERS
 
 app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
 
+app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
+app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
+
 app.state.MODELS = {}
 app.state.TOOLS = {}
 app.state.FUNCTIONS = {}

+ 12 - 0
backend/config.py

@@ -393,6 +393,18 @@ OAUTH_PROVIDER_NAME = PersistentConfig(
     os.environ.get("OAUTH_PROVIDER_NAME", "SSO"),
 )
 
+OAUTH_USERNAME_CLAIM = PersistentConfig(
+    "OAUTH_USERNAME_CLAIM",
+    "oauth.oidc.username_claim",
+    os.environ.get("OAUTH_USERNAME_CLAIM", "name"),
+)
+
+OAUTH_PICTURE_CLAIM = PersistentConfig(
+    "OAUTH_USERNAME_CLAIM",
+    "oauth.oidc.avatar_claim",
+    os.environ.get("OAUTH_PICTURE_CLAIM", "picture"),
+)
+
 
 def load_oauth_providers():
     OAUTH_PROVIDERS.clear()

+ 4 - 2
backend/main.py

@@ -2064,7 +2064,8 @@ async def oauth_callback(provider: str, request: Request, response: Response):
             if existing_user:
                 raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
 
-            picture_url = user_data.get("picture", "")
+            picture_claim = webui_app.state.config.OAUTH_PICTURE_CLAIM
+            picture_url = user_data.get(picture_claim, "")
             if picture_url:
                 # Download the profile image into a base64 string
                 try:
@@ -2084,6 +2085,7 @@ async def oauth_callback(provider: str, request: Request, response: Response):
                     picture_url = ""
             if not picture_url:
                 picture_url = "/user.png"
+            username_claim = webui_app.state.config.OAUTH_USERNAME_CLAIM
             role = (
                 "admin"
                 if Users.get_num_users() == 0
@@ -2094,7 +2096,7 @@ async def oauth_callback(provider: str, request: Request, response: Response):
                 password=get_password_hash(
                     str(uuid.uuid4())
                 ),  # Random password, not used
-                name=user_data.get("name", "User"),
+                name=user_data.get(username_claim, "User"),
                 profile_image_url=picture_url,
                 role=role,
                 oauth_sub=provider_sub,