Ver código fonte

refac: Extend OIDC support to all OAuth authentication methods

Tryanks 3 meses atrás
pai
commit
f3e6dacf0d
2 arquivos alterados com 39 adições e 25 exclusões
  1. 36 13
      backend/open_webui/config.py
  2. 3 12
      backend/open_webui/utils/oauth.py

+ 36 - 13
backend/open_webui/config.py

@@ -468,12 +468,20 @@ OAUTH_ALLOWED_DOMAINS = PersistentConfig(
 def load_oauth_providers():
 def load_oauth_providers():
     OAUTH_PROVIDERS.clear()
     OAUTH_PROVIDERS.clear()
     if GOOGLE_CLIENT_ID.value and GOOGLE_CLIENT_SECRET.value:
     if GOOGLE_CLIENT_ID.value and GOOGLE_CLIENT_SECRET.value:
+        def google_oauth_register(client):
+            client.register(
+                name="google",
+                client_id=GOOGLE_CLIENT_ID.value,
+                client_secret=GOOGLE_CLIENT_SECRET.value,
+                server_metadata_url="https://accounts.google.com/.well-known/openid-configuration",
+                client_kwargs={
+                    "scope": GOOGLE_OAUTH_SCOPE.value
+                },
+                redirect_uri=GOOGLE_REDIRECT_URI.value,
+            )
         OAUTH_PROVIDERS["google"] = {
         OAUTH_PROVIDERS["google"] = {
-            "client_id": GOOGLE_CLIENT_ID.value,
-            "client_secret": GOOGLE_CLIENT_SECRET.value,
-            "server_metadata_url": "https://accounts.google.com/.well-known/openid-configuration",
-            "scope": GOOGLE_OAUTH_SCOPE.value,
             "redirect_uri": GOOGLE_REDIRECT_URI.value,
             "redirect_uri": GOOGLE_REDIRECT_URI.value,
+            "register": google_oauth_register,
         }
         }
 
 
     if (
     if (
@@ -481,13 +489,21 @@ def load_oauth_providers():
         and MICROSOFT_CLIENT_SECRET.value
         and MICROSOFT_CLIENT_SECRET.value
         and MICROSOFT_CLIENT_TENANT_ID.value
         and MICROSOFT_CLIENT_TENANT_ID.value
     ):
     ):
+        def microsoft_oauth_register(client):
+            client.register(
+                name="microsoft",
+                client_id=MICROSOFT_CLIENT_ID.value,
+                client_secret=MICROSOFT_CLIENT_SECRET.value,
+                server_metadata_url=f"https://login.microsoftonline.com/{MICROSOFT_CLIENT_TENANT_ID.value}/v2.0/.well-known/openid-configuration",
+                client_kwargs={
+                    "scope": MICROSOFT_OAUTH_SCOPE.value,
+                },
+                redirect_uri=MICROSOFT_REDIRECT_URI.value,
+            )
         OAUTH_PROVIDERS["microsoft"] = {
         OAUTH_PROVIDERS["microsoft"] = {
-            "client_id": MICROSOFT_CLIENT_ID.value,
-            "client_secret": MICROSOFT_CLIENT_SECRET.value,
-            "server_metadata_url": f"https://login.microsoftonline.com/{MICROSOFT_CLIENT_TENANT_ID.value}/v2.0/.well-known/openid-configuration",
-            "scope": MICROSOFT_OAUTH_SCOPE.value,
             "redirect_uri": MICROSOFT_REDIRECT_URI.value,
             "redirect_uri": MICROSOFT_REDIRECT_URI.value,
             "picture_url": "https://graph.microsoft.com/v1.0/me/photo/$value",
             "picture_url": "https://graph.microsoft.com/v1.0/me/photo/$value",
+            "register": microsoft_oauth_register,
         }
         }
 
 
     if (
     if (
@@ -495,13 +511,20 @@ def load_oauth_providers():
         and OAUTH_CLIENT_SECRET.value
         and OAUTH_CLIENT_SECRET.value
         and OPENID_PROVIDER_URL.value
         and OPENID_PROVIDER_URL.value
     ):
     ):
+        def oidc_oauth_register(client):
+            client.register(
+                name="oidc",
+                client_id=OAUTH_CLIENT_ID.value,
+                client_secret=OAUTH_CLIENT_SECRET.value,
+                server_metadata_url=OPENID_PROVIDER_URL.value,
+                client_kwargs={
+                    "scope": OAUTH_SCOPES.value,
+                },
+                redirect_uri=OPENID_REDIRECT_URI.value,
+            )
         OAUTH_PROVIDERS["oidc"] = {
         OAUTH_PROVIDERS["oidc"] = {
-            "client_id": OAUTH_CLIENT_ID.value,
-            "client_secret": OAUTH_CLIENT_SECRET.value,
-            "server_metadata_url": OPENID_PROVIDER_URL.value,
-            "scope": OAUTH_SCOPES.value,
             "name": OAUTH_PROVIDER_NAME.value,
             "name": OAUTH_PROVIDER_NAME.value,
-            "redirect_uri": OPENID_REDIRECT_URI.value,
+            "register": oidc_oauth_register,
         }
         }
 
 
 
 

+ 3 - 12
backend/open_webui/utils/oauth.py

@@ -63,17 +63,8 @@ auth_manager_config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
 class OAuthManager:
 class OAuthManager:
     def __init__(self):
     def __init__(self):
         self.oauth = OAuth()
         self.oauth = OAuth()
-        for provider_name, provider_config in OAUTH_PROVIDERS.items():
-            self.oauth.register(
-                name=provider_name,
-                client_id=provider_config["client_id"],
-                client_secret=provider_config["client_secret"],
-                server_metadata_url=provider_config["server_metadata_url"],
-                client_kwargs={
-                    "scope": provider_config["scope"],
-                },
-                redirect_uri=provider_config["redirect_uri"],
-            )
+        for _, provider_config in OAUTH_PROVIDERS.items():
+            provider_config["register"](self.oauth)
 
 
     def get_client(self, provider_name):
     def get_client(self, provider_name):
         return self.oauth.create_client(provider_name)
         return self.oauth.create_client(provider_name)
@@ -207,7 +198,7 @@ class OAuthManager:
             log.warning(f"OAuth callback failed, user data is missing: {token}")
             log.warning(f"OAuth callback failed, user data is missing: {token}")
             raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
             raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
 
 
-        sub = user_data.get("sub")
+        sub = user_data.get(OAUTH_PROVIDERS[provider].get("sub_claim", "sub"))
         if not sub:
         if not sub:
             log.warning(f"OAuth callback failed, sub is missing: {user_data}")
             log.warning(f"OAuth callback failed, sub is missing: {user_data}")
             raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
             raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)