Преглед изворни кода

refac: move things around, uplift oauth endpoints

Jun Siang Cheah пре 11 месеци
родитељ
комит
985fdca585

+ 0 - 6
backend/apps/webui/main.py

@@ -58,12 +58,6 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-# SessionMiddleware is used by authlib for oauth
-if len(OAUTH_PROVIDERS) > 0:
-    app.add_middleware(
-        SessionMiddleware, secret_key=WEBUI_SECRET_KEY, session_cookie="oui-session"
-    )
-
 app.include_router(auths.router, prefix="/auths", tags=["auths"])
 app.include_router(users.router, prefix="/users", tags=["users"])
 app.include_router(chats.router, prefix="/chats", tags=["chats"])

+ 21 - 2
backend/apps/webui/models/users.py

@@ -112,9 +112,16 @@ class UsersTable:
         except:
             return None
 
-    def get_user_by_email(self, email: str) -> Optional[UserModel]:
+    def get_user_by_email(
+        self, email: str, oauth_user: bool = False
+    ) -> Optional[UserModel]:
         try:
-            user = User.get((User.email == email, User.oauth_sub.is_null()))
+            conditions = (
+                (User.email == email, User.oauth_sub.is_null())
+                if not oauth_user
+                else (User.email == email)
+            )
+            user = User.get(conditions)
             return UserModel(**model_to_dict(user))
         except:
             return None
@@ -177,6 +184,18 @@ class UsersTable:
         except:
             return None
 
+    def update_user_oauth_sub_by_id(
+        self, id: str, oauth_sub: str
+    ) -> Optional[UserModel]:
+        try:
+            query = User.update(oauth_sub=oauth_sub).where(User.id == id)
+            query.execute()
+
+            user = User.get(User.id == id)
+            return UserModel(**model_to_dict(user))
+        except:
+            return None
+
     def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
         try:
             query = User.update(**updated).where(User.id == id)

+ 0 - 85
backend/apps/webui/routers/auths.py

@@ -1,7 +1,5 @@
 import logging
 
-from authlib.integrations.starlette_client import OAuth
-from authlib.oidc.core import UserInfo
 from fastapi import Request, UploadFile, File
 from fastapi import Depends, HTTPException, status
 
@@ -11,8 +9,6 @@ import re
 import uuid
 import csv
 
-from starlette.responses import RedirectResponse
-
 from apps.webui.models.auths import (
     SigninForm,
     SignupForm,
@@ -39,8 +35,6 @@ from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
 from config import (
     WEBUI_AUTH,
     WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
-    OAUTH_PROVIDERS,
-    ENABLE_OAUTH_SIGNUP,
 )
 
 router = APIRouter()
@@ -381,82 +375,3 @@ async def get_api_key(user=Depends(get_current_user)):
         }
     else:
         raise HTTPException(404, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
-
-
-############################
-# OAuth Login & Callback
-############################
-
-oauth = OAuth()
-
-for provider_name, provider_config in OAUTH_PROVIDERS.items():
-    oauth.register(
-        name=provider_name,
-        client_id=provider_config["client_id"],
-        client_secret=provider_config["client_secret"],
-        server_metadata_url=provider_config["server_metadata_url"],
-        client_kwargs={
-            "scope": provider_config["scope"],
-        },
-    )
-
-
-@router.get("/oauth/{provider}/login")
-async def oauth_login(provider: str, request: Request):
-    if provider not in OAUTH_PROVIDERS:
-        raise HTTPException(404)
-    redirect_uri = request.url_for("oauth_callback", provider=provider)
-    return await oauth.create_client(provider).authorize_redirect(request, redirect_uri)
-
-
-@router.get("/oauth/{provider}/callback")
-async def oauth_callback(provider: str, request: Request):
-    if provider not in OAUTH_PROVIDERS:
-        raise HTTPException(404)
-    client = oauth.create_client(provider)
-    token = await client.authorize_access_token(request)
-    user_data: UserInfo = token["userinfo"]
-
-    sub = user_data.get("sub")
-    if not sub:
-        raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
-    provider_sub = f"{provider}@{sub}"
-
-    # Check if the user exists
-    user = Users.get_user_by_oauth_sub(provider_sub)
-
-    if not user:
-        # If the user does not exist, create a new user if signup is enabled
-        if ENABLE_OAUTH_SIGNUP.value:
-            user = Auths.insert_new_auth(
-                email=user_data.get("email", "").lower(),
-                password=get_password_hash(
-                    str(uuid.uuid4())
-                ),  # Random password, not used
-                name=user_data.get("name", "User"),
-                profile_image_url=user_data.get("picture", "/user.png"),
-                role=request.app.state.config.DEFAULT_USER_ROLE,
-                oauth_sub=provider_sub,
-            )
-
-            if request.app.state.config.WEBHOOK_URL:
-                post_webhook(
-                    request.app.state.config.WEBHOOK_URL,
-                    WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
-                    {
-                        "action": "signup",
-                        "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
-                        "user": user.model_dump_json(exclude_none=True),
-                    },
-                )
-        else:
-            raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
-
-    jwt_token = create_token(
-        data={"id": user.id},
-        expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
-    )
-
-    # Redirect back to the frontend with the JWT token
-    redirect_url = f"{request.base_url}auth#token={jwt_token}"
-    return RedirectResponse(url=redirect_url)

+ 118 - 4
backend/main.py

@@ -1,4 +1,8 @@
+import uuid
 from contextlib import asynccontextmanager
+
+from authlib.integrations.starlette_client import OAuth
+from authlib.oidc.core import UserInfo
 from bs4 import BeautifulSoup
 import json
 import markdown
@@ -17,7 +21,8 @@ from fastapi.middleware.wsgi import WSGIMiddleware
 from fastapi.middleware.cors import CORSMiddleware
 from starlette.exceptions import HTTPException as StarletteHTTPException
 from starlette.middleware.base import BaseHTTPMiddleware
-from starlette.responses import StreamingResponse, Response
+from starlette.middleware.sessions import SessionMiddleware
+from starlette.responses import StreamingResponse, Response, RedirectResponse
 
 from apps.ollama.main import app as ollama_app, get_all_models as get_ollama_models
 from apps.openai.main import app as openai_app, get_all_models as get_openai_models
@@ -31,8 +36,16 @@ import asyncio
 from pydantic import BaseModel
 from typing import List, Optional
 
-from apps.webui.models.models import Models, ModelModel
-from utils.utils import get_admin_user, get_verified_user
+from apps.webui.models.auths import Auths
+from apps.webui.models.models import Models
+from apps.webui.models.users import Users
+from utils.misc import parse_duration
+from utils.utils import (
+    get_admin_user,
+    get_verified_user,
+    get_password_hash,
+    create_token,
+)
 from apps.rag.utils import rag_messages
 
 from config import (
@@ -56,8 +69,12 @@ from config import (
     ENABLE_ADMIN_EXPORT,
     AppConfig,
     OAUTH_PROVIDERS,
+    ENABLE_OAUTH_SIGNUP,
+    OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
+    WEBUI_SECRET_KEY,
 )
-from constants import ERROR_MESSAGES
+from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
+from utils.webhook import post_webhook
 
 logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
 log = logging.getLogger(__name__)
@@ -453,6 +470,103 @@ async def get_app_latest_release_version():
         )
 
 
+############################
+# OAuth Login & Callback
+############################
+
+oauth = OAuth()
+
+for provider_name, provider_config in OAUTH_PROVIDERS.items():
+    oauth.register(
+        name=provider_name,
+        client_id=provider_config["client_id"],
+        client_secret=provider_config["client_secret"],
+        server_metadata_url=provider_config["server_metadata_url"],
+        client_kwargs={
+            "scope": provider_config["scope"],
+        },
+    )
+
+# SessionMiddleware is used by authlib for oauth
+if len(OAUTH_PROVIDERS) > 0:
+    app.add_middleware(
+        SessionMiddleware, secret_key=WEBUI_SECRET_KEY, session_cookie="oui-session"
+    )
+
+
+@app.get("/oauth/{provider}/login")
+async def oauth_login(provider: str, request: Request):
+    if provider not in OAUTH_PROVIDERS:
+        raise HTTPException(404)
+    redirect_uri = request.url_for("oauth_callback", provider=provider)
+    return await oauth.create_client(provider).authorize_redirect(request, redirect_uri)
+
+
+@app.get("/oauth/{provider}/callback")
+async def oauth_callback(provider: str, request: Request):
+    if provider not in OAUTH_PROVIDERS:
+        raise HTTPException(404)
+    client = oauth.create_client(provider)
+    token = await client.authorize_access_token(request)
+    user_data: UserInfo = token["userinfo"]
+
+    sub = user_data.get("sub")
+    if not sub:
+        raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
+    provider_sub = f"{provider}@{sub}"
+
+    # Check if the user exists
+    user = Users.get_user_by_oauth_sub(provider_sub)
+
+    if not user:
+        # If the user does not exist, check if merging is enabled
+        if OAUTH_MERGE_ACCOUNTS_BY_EMAIL:
+            # Check if the user exists by email
+            email = user_data.get("email", "").lower()
+            if not email:
+                raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
+            user = Users.get_user_by_email(user_data.get("email", "").lower(), True)
+            if user:
+                # Update the user with the new oauth sub
+                Users.update_user_oauth_sub_by_id(user.id, provider_sub)
+
+    if not user:
+        # If the user does not exist, check if signups are enabled
+        if ENABLE_OAUTH_SIGNUP.value:
+            user = Auths.insert_new_auth(
+                email=user_data.get("email", "").lower(),
+                password=get_password_hash(
+                    str(uuid.uuid4())
+                ),  # Random password, not used
+                name=user_data.get("name", "User"),
+                profile_image_url=user_data.get("picture", "/user.png"),
+                role=webui_app.state.config.DEFAULT_USER_ROLE,
+                oauth_sub=provider_sub,
+            )
+
+            if webui_app.state.config.WEBHOOK_URL:
+                post_webhook(
+                    webui_app.state.config.WEBHOOK_URL,
+                    WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
+                    {
+                        "action": "signup",
+                        "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
+                        "user": user.model_dump_json(exclude_none=True),
+                    },
+                )
+        else:
+            raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
+
+    jwt_token = create_token(
+        data={"id": user.id},
+        expires_delta=parse_duration(webui_app.state.config.JWT_EXPIRES_IN),
+    )
+
+    # Redirect back to the frontend with the JWT token
+    redirect_url = f"{request.base_url}auth#token={jwt_token}"
+    return RedirectResponse(url=redirect_url)
+
+
 @app.get("/manifest.json")
 async def get_manifest_json():
     return {

+ 3 - 3
src/routes/auth/+page.svelte

@@ -259,7 +259,7 @@
 								<button
 									class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition"
 									on:click={() => {
-										window.location.href = `${WEBUI_API_BASE_URL}/auths/oauth/google/login`;
+										window.location.href = `${WEBUI_BASE_URL}/oauth/google/login`;
 									}}
 								>
 									<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 48 48" class="size-6 mr-3">
@@ -284,7 +284,7 @@
 								<button
 									class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition"
 									on:click={() => {
-										window.location.href = `${WEBUI_API_BASE_URL}/auths/oauth/microsoft/login`;
+										window.location.href = `${WEBUI_BASE_URL}/oauth/microsoft/login`;
 									}}
 								>
 									<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 21 21" class="size-6 mr-3">
@@ -309,7 +309,7 @@
 								<button
 									class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition"
 									on:click={() => {
-										window.location.href = `${WEBUI_API_BASE_URL}/auths/oauth/oidc/login`;
+										window.location.href = `${WEBUI_BASE_URL}/oauth/oidc/login`;
 									}}
 								>
 									<svg