Просмотр исходного кода

refac: modify oauth login logic for unique email addresses

Jun Siang Cheah 10 месяцев назад
Родитель
Сommit
981f384154
2 измененных файлов с 20 добавлено и 15 удалено
  1. 2 9
      backend/apps/webui/models/users.py
  2. 18 6
      backend/main.py

+ 2 - 9
backend/apps/webui/models/users.py

@@ -122,16 +122,9 @@ class UsersTable:
         except:
             return None
 
-    def get_user_by_email(
-        self, email: str, oauth_user: bool = False
-    ) -> Optional[UserModel]:
+    def get_user_by_email(self, email: str) -> Optional[UserModel]:
         try:
-            conditions = (
-                (User.email == email, User.oauth_sub.is_null())
-                if not oauth_user
-                else (User.email == email)
-            )
-            user = User.get(*conditions)
+            user = User.get(User.email == email)
             return UserModel(**model_to_dict(user))
         except:
             return None

+ 18 - 6
backend/main.py

@@ -1869,6 +1869,12 @@ async def oauth_login(provider: str, request: Request):
     return await oauth.create_client(provider).authorize_redirect(request, redirect_uri)
 
 
+# OAuth login logic is as follows:
+# 1. Attempt to find a user with matching subject ID, tied to the provider
+# 2. If OAUTH_MERGE_ACCOUNTS_BY_EMAIL is true, find a user with the email address provided via OAuth
+#    - This is considered insecure in general, as OAuth providers do not always verify email addresses
+# 3. If there is no user, and ENABLE_OAUTH_SIGNUP is true, create a user
+#    - Email addresses are considered unique, so we fail registration if the email address is alreayd taken
 @app.get("/oauth/{provider}/callback")
 async def oauth_callback(provider: str, request: Request, response: Response):
     if provider not in OAUTH_PROVIDERS:
@@ -1885,6 +1891,10 @@ async def oauth_callback(provider: str, request: Request, response: Response):
     if not sub:
         raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
     provider_sub = f"{provider}@{sub}"
+    email = user_data.get("email", "").lower()
+    # We currently mandate that email addresses are provided
+    if not email:
+        raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
 
     # Check if the user exists
     user = Users.get_user_by_oauth_sub(provider_sub)
@@ -1893,10 +1903,7 @@ async def oauth_callback(provider: str, request: Request, response: Response):
         # If the user does not exist, check if merging is enabled
         if OAUTH_MERGE_ACCOUNTS_BY_EMAIL.value:
             # Check if the user exists by email
-            email = user_data.get("email", "").lower()
-            if not email:
-                raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
-            user = Users.get_user_by_email(user_data.get("email", "").lower(), True)
+            user = Users.get_user_by_email(email)
             if user:
                 # Update the user with the new oauth sub
                 Users.update_user_oauth_sub_by_id(user.id, provider_sub)
@@ -1904,6 +1911,11 @@ async def oauth_callback(provider: str, request: Request, response: Response):
     if not user:
         # If the user does not exist, check if signups are enabled
         if ENABLE_OAUTH_SIGNUP.value:
+            # Check if an existing user with the same email already exists
+            existing_user = Users.get_user_by_email(user_data.get("email", "").lower())
+            if existing_user:
+                raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
+
             picture_url = user_data.get("picture", "")
             if picture_url:
                 # Download the profile image into a base64 string
@@ -1920,12 +1932,12 @@ async def oauth_callback(provider: str, request: Request, response: Response):
                                 guessed_mime_type = "image/jpeg"
                             picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}"
                 except Exception as e:
-                    log.error(f"Profile image download error: {e}")
+                    log.error(f"Error downloading profile image '{picture_url}': {e}")
                     picture_url = ""
             if not picture_url:
                 picture_url = "/user.png"
             user = Auths.insert_new_auth(
-                email=user_data.get("email", "").lower(),
+                email=email,
                 password=get_password_hash(
                     str(uuid.uuid4())
                 ),  # Random password, not used