|
@@ -1,5 +1,7 @@
|
|
|
import logging
|
|
|
|
|
|
+from authlib.integrations.starlette_client import OAuth
|
|
|
+from authlib.oidc.core import UserInfo
|
|
|
from fastapi import Request, UploadFile, File
|
|
|
from fastapi import Depends, HTTPException, status
|
|
|
|
|
@@ -9,6 +11,7 @@ import re
|
|
|
import uuid
|
|
|
import csv
|
|
|
|
|
|
+from starlette.responses import RedirectResponse
|
|
|
|
|
|
from apps.webui.models.auths import (
|
|
|
SigninForm,
|
|
@@ -33,7 +36,12 @@ from utils.utils import (
|
|
|
from utils.misc import parse_duration, validate_email_format
|
|
|
from utils.webhook import post_webhook
|
|
|
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
|
|
|
-from config import WEBUI_AUTH, WEBUI_AUTH_TRUSTED_EMAIL_HEADER
|
|
|
+from config import (
|
|
|
+ WEBUI_AUTH,
|
|
|
+ WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
|
|
|
+ OAUTH_PROVIDERS,
|
|
|
+ ENABLE_OAUTH_SIGNUP,
|
|
|
+)
|
|
|
|
|
|
router = APIRouter()
|
|
|
|
|
@@ -373,3 +381,82 @@ async def get_api_key(user=Depends(get_current_user)):
|
|
|
}
|
|
|
else:
|
|
|
raise HTTPException(404, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
|
|
|
+
|
|
|
+
|
|
|
+############################
|
|
|
+# OAuth Login & Callback
|
|
|
+############################
|
|
|
+
|
|
|
+oauth = OAuth()
|
|
|
+
|
|
|
+for provider_name, provider_config in OAUTH_PROVIDERS.items():
|
|
|
+ oauth.register(
|
|
|
+ name=provider_name,
|
|
|
+ client_id=provider_config["client_id"],
|
|
|
+ client_secret=provider_config["client_secret"],
|
|
|
+ server_metadata_url=provider_config["server_metadata_url"],
|
|
|
+ client_kwargs={
|
|
|
+ "scope": provider_config["scope"],
|
|
|
+ },
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+@router.get("/oauth/{provider}/login")
|
|
|
+async def oauth_login(provider: str, request: Request):
|
|
|
+ if provider not in OAUTH_PROVIDERS:
|
|
|
+ raise HTTPException(404)
|
|
|
+ redirect_uri = request.url_for("oauth_callback", provider=provider)
|
|
|
+ return await oauth.create_client(provider).authorize_redirect(request, redirect_uri)
|
|
|
+
|
|
|
+
|
|
|
+@router.get("/oauth/{provider}/callback")
|
|
|
+async def oauth_callback(provider: str, request: Request):
|
|
|
+ if provider not in OAUTH_PROVIDERS:
|
|
|
+ raise HTTPException(404)
|
|
|
+ client = oauth.create_client(provider)
|
|
|
+ token = await client.authorize_access_token(request)
|
|
|
+ user_data: UserInfo = token["userinfo"]
|
|
|
+
|
|
|
+ sub = user_data.get("sub")
|
|
|
+ if not sub:
|
|
|
+ raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
|
|
+ provider_sub = f"{provider}@{sub}"
|
|
|
+
|
|
|
+ # Check if the user exists
|
|
|
+ user = Users.get_user_by_oauth_sub(provider_sub)
|
|
|
+
|
|
|
+ if not user:
|
|
|
+ # If the user does not exist, create a new user if signup is enabled
|
|
|
+ if ENABLE_OAUTH_SIGNUP.value:
|
|
|
+ user = Auths.insert_new_auth(
|
|
|
+ email=user_data.get("email", "").lower(),
|
|
|
+ password=get_password_hash(
|
|
|
+ str(uuid.uuid4())
|
|
|
+ ), # Random password, not used
|
|
|
+ name=user_data.get("name", "User"),
|
|
|
+ profile_image_url=user_data.get("picture", "/user.png"),
|
|
|
+ role=request.app.state.config.DEFAULT_USER_ROLE,
|
|
|
+ oauth_sub=provider_sub,
|
|
|
+ )
|
|
|
+
|
|
|
+ if request.app.state.config.WEBHOOK_URL:
|
|
|
+ post_webhook(
|
|
|
+ request.app.state.config.WEBHOOK_URL,
|
|
|
+ WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
|
|
|
+ {
|
|
|
+ "action": "signup",
|
|
|
+ "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
|
|
|
+ "user": user.model_dump_json(exclude_none=True),
|
|
|
+ },
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
|
|
+
|
|
|
+ jwt_token = create_token(
|
|
|
+ data={"id": user.id},
|
|
|
+ expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
|
|
|
+ )
|
|
|
+
|
|
|
+ # Redirect back to the frontend with the JWT token
|
|
|
+ redirect_url = f"{request.base_url}auth#token={jwt_token}"
|
|
|
+ return RedirectResponse(url=redirect_url)
|