|
@@ -1,4 +1,9 @@
|
|
|
|
+import base64
|
|
|
|
+import uuid
|
|
from contextlib import asynccontextmanager
|
|
from contextlib import asynccontextmanager
|
|
|
|
+
|
|
|
|
+from authlib.integrations.starlette_client import OAuth
|
|
|
|
+from authlib.oidc.core import UserInfo
|
|
from bs4 import BeautifulSoup
|
|
from bs4 import BeautifulSoup
|
|
import json
|
|
import json
|
|
import markdown
|
|
import markdown
|
|
@@ -24,7 +29,8 @@ from fastapi.middleware.wsgi import WSGIMiddleware
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from starlette.exceptions import HTTPException as StarletteHTTPException
|
|
from starlette.exceptions import HTTPException as StarletteHTTPException
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
-from starlette.responses import StreamingResponse, Response
|
|
|
|
|
|
+from starlette.middleware.sessions import SessionMiddleware
|
|
|
|
+from starlette.responses import StreamingResponse, Response, RedirectResponse
|
|
|
|
|
|
|
|
|
|
from apps.socket.main import app as socket_app
|
|
from apps.socket.main import app as socket_app
|
|
@@ -53,9 +59,11 @@ from apps.webui.main import (
|
|
from pydantic import BaseModel
|
|
from pydantic import BaseModel
|
|
from typing import List, Optional, Iterator, Generator, Union
|
|
from typing import List, Optional, Iterator, Generator, Union
|
|
|
|
|
|
|
|
+from apps.webui.models.auths import Auths
|
|
from apps.webui.models.models import Models, ModelModel
|
|
from apps.webui.models.models import Models, ModelModel
|
|
from apps.webui.models.tools import Tools
|
|
from apps.webui.models.tools import Tools
|
|
from apps.webui.models.functions import Functions
|
|
from apps.webui.models.functions import Functions
|
|
|
|
+from apps.webui.models.users import Users
|
|
|
|
|
|
from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id
|
|
from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id
|
|
|
|
|
|
@@ -64,6 +72,8 @@ from utils.utils import (
|
|
get_verified_user,
|
|
get_verified_user,
|
|
get_current_user,
|
|
get_current_user,
|
|
get_http_authorization_cred,
|
|
get_http_authorization_cred,
|
|
|
|
+ get_password_hash,
|
|
|
|
+ create_token,
|
|
)
|
|
)
|
|
from utils.task import (
|
|
from utils.task import (
|
|
title_generation_template,
|
|
title_generation_template,
|
|
@@ -74,6 +84,7 @@ from utils.misc import (
|
|
get_last_user_message,
|
|
get_last_user_message,
|
|
add_or_update_system_message,
|
|
add_or_update_system_message,
|
|
stream_message_template,
|
|
stream_message_template,
|
|
|
|
+ parse_duration,
|
|
)
|
|
)
|
|
|
|
|
|
from apps.rag.utils import get_rag_context, rag_template
|
|
from apps.rag.utils import get_rag_context, rag_template
|
|
@@ -106,9 +117,16 @@ from config import (
|
|
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
|
|
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
|
|
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
|
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
|
SAFE_MODE,
|
|
SAFE_MODE,
|
|
|
|
+ OAUTH_PROVIDERS,
|
|
|
|
+ ENABLE_OAUTH_SIGNUP,
|
|
|
|
+ OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
|
|
|
|
+ WEBUI_SECRET_KEY,
|
|
|
|
+ WEBUI_SESSION_COOKIE_SAME_SITE,
|
|
|
|
+ WEBUI_SESSION_COOKIE_SECURE,
|
|
AppConfig,
|
|
AppConfig,
|
|
)
|
|
)
|
|
-from constants import ERROR_MESSAGES
|
|
|
|
|
|
+from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
|
|
|
|
+from utils.webhook import post_webhook
|
|
|
|
|
|
if SAFE_MODE:
|
|
if SAFE_MODE:
|
|
print("SAFE MODE ENABLED")
|
|
print("SAFE MODE ENABLED")
|
|
@@ -1725,6 +1743,12 @@ async def get_app_config():
|
|
"engine": audio_app.state.config.STT_ENGINE,
|
|
"engine": audio_app.state.config.STT_ENGINE,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
|
|
+ "oauth": {
|
|
|
|
+ "providers": {
|
|
|
|
+ name: config.get("name", name)
|
|
|
|
+ for name, config in OAUTH_PROVIDERS.items()
|
|
|
|
+ }
|
|
|
|
+ },
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -1806,6 +1830,154 @@ async def get_app_latest_release_version():
|
|
)
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
+############################
|
|
|
|
+# OAuth Login & Callback
|
|
|
|
+############################
|
|
|
|
+
|
|
|
|
+oauth = OAuth()
|
|
|
|
+
|
|
|
|
+for provider_name, provider_config in OAUTH_PROVIDERS.items():
|
|
|
|
+ oauth.register(
|
|
|
|
+ name=provider_name,
|
|
|
|
+ client_id=provider_config["client_id"],
|
|
|
|
+ client_secret=provider_config["client_secret"],
|
|
|
|
+ server_metadata_url=provider_config["server_metadata_url"],
|
|
|
|
+ client_kwargs={
|
|
|
|
+ "scope": provider_config["scope"],
|
|
|
|
+ },
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+# SessionMiddleware is used by authlib for oauth
|
|
|
|
+if len(OAUTH_PROVIDERS) > 0:
|
|
|
|
+ app.add_middleware(
|
|
|
|
+ SessionMiddleware,
|
|
|
|
+ secret_key=WEBUI_SECRET_KEY,
|
|
|
|
+ session_cookie="oui-session",
|
|
|
|
+ same_site=WEBUI_SESSION_COOKIE_SAME_SITE,
|
|
|
|
+ https_only=WEBUI_SESSION_COOKIE_SECURE,
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+@app.get("/oauth/{provider}/login")
|
|
|
|
+async def oauth_login(provider: str, request: Request):
|
|
|
|
+ if provider not in OAUTH_PROVIDERS:
|
|
|
|
+ raise HTTPException(404)
|
|
|
|
+ redirect_uri = request.url_for("oauth_callback", provider=provider)
|
|
|
|
+ return await oauth.create_client(provider).authorize_redirect(request, redirect_uri)
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+# OAuth login logic is as follows:
|
|
|
|
+# 1. Attempt to find a user with matching subject ID, tied to the provider
|
|
|
|
+# 2. If OAUTH_MERGE_ACCOUNTS_BY_EMAIL is true, find a user with the email address provided via OAuth
|
|
|
|
+# - This is considered insecure in general, as OAuth providers do not always verify email addresses
|
|
|
|
+# 3. If there is no user, and ENABLE_OAUTH_SIGNUP is true, create a user
|
|
|
|
+# - Email addresses are considered unique, so we fail registration if the email address is alreayd taken
|
|
|
|
+@app.get("/oauth/{provider}/callback")
|
|
|
|
+async def oauth_callback(provider: str, request: Request, response: Response):
|
|
|
|
+ if provider not in OAUTH_PROVIDERS:
|
|
|
|
+ raise HTTPException(404)
|
|
|
|
+ client = oauth.create_client(provider)
|
|
|
|
+ try:
|
|
|
|
+ token = await client.authorize_access_token(request)
|
|
|
|
+ except Exception as e:
|
|
|
|
+ log.warning(f"OAuth callback error: {e}")
|
|
|
|
+ raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
|
|
|
+ user_data: UserInfo = token["userinfo"]
|
|
|
|
+
|
|
|
|
+ sub = user_data.get("sub")
|
|
|
|
+ if not sub:
|
|
|
|
+ log.warning(f"OAuth callback failed, sub is missing: {user_data}")
|
|
|
|
+ raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
|
|
|
+ provider_sub = f"{provider}@{sub}"
|
|
|
|
+ email = user_data.get("email", "").lower()
|
|
|
|
+ # We currently mandate that email addresses are provided
|
|
|
|
+ if not email:
|
|
|
|
+ log.warning(f"OAuth callback failed, email is missing: {user_data}")
|
|
|
|
+ raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
|
|
|
+
|
|
|
|
+ # Check if the user exists
|
|
|
|
+ user = Users.get_user_by_oauth_sub(provider_sub)
|
|
|
|
+
|
|
|
|
+ if not user:
|
|
|
|
+ # If the user does not exist, check if merging is enabled
|
|
|
|
+ if OAUTH_MERGE_ACCOUNTS_BY_EMAIL.value:
|
|
|
|
+ # Check if the user exists by email
|
|
|
|
+ user = Users.get_user_by_email(email)
|
|
|
|
+ if user:
|
|
|
|
+ # Update the user with the new oauth sub
|
|
|
|
+ Users.update_user_oauth_sub_by_id(user.id, provider_sub)
|
|
|
|
+
|
|
|
|
+ if not user:
|
|
|
|
+ # If the user does not exist, check if signups are enabled
|
|
|
|
+ if ENABLE_OAUTH_SIGNUP.value:
|
|
|
|
+ # Check if an existing user with the same email already exists
|
|
|
|
+ existing_user = Users.get_user_by_email(user_data.get("email", "").lower())
|
|
|
|
+ if existing_user:
|
|
|
|
+ raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
|
|
|
+
|
|
|
|
+ picture_url = user_data.get("picture", "")
|
|
|
|
+ if picture_url:
|
|
|
|
+ # Download the profile image into a base64 string
|
|
|
|
+ try:
|
|
|
|
+ async with aiohttp.ClientSession() as session:
|
|
|
|
+ async with session.get(picture_url) as resp:
|
|
|
|
+ picture = await resp.read()
|
|
|
|
+ base64_encoded_picture = base64.b64encode(picture).decode(
|
|
|
|
+ "utf-8"
|
|
|
|
+ )
|
|
|
|
+ guessed_mime_type = mimetypes.guess_type(picture_url)[0]
|
|
|
|
+ if guessed_mime_type is None:
|
|
|
|
+ # assume JPG, browsers are tolerant enough of image formats
|
|
|
|
+ guessed_mime_type = "image/jpeg"
|
|
|
|
+ picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}"
|
|
|
|
+ except Exception as e:
|
|
|
|
+ log.error(f"Error downloading profile image '{picture_url}': {e}")
|
|
|
|
+ picture_url = ""
|
|
|
|
+ if not picture_url:
|
|
|
|
+ picture_url = "/user.png"
|
|
|
|
+ user = Auths.insert_new_auth(
|
|
|
|
+ email=email,
|
|
|
|
+ password=get_password_hash(
|
|
|
|
+ str(uuid.uuid4())
|
|
|
|
+ ), # Random password, not used
|
|
|
|
+ name=user_data.get("name", "User"),
|
|
|
|
+ profile_image_url=picture_url,
|
|
|
|
+ role=webui_app.state.config.DEFAULT_USER_ROLE,
|
|
|
|
+ oauth_sub=provider_sub,
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ if webui_app.state.config.WEBHOOK_URL:
|
|
|
|
+ post_webhook(
|
|
|
|
+ webui_app.state.config.WEBHOOK_URL,
|
|
|
|
+ WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
|
|
|
|
+ {
|
|
|
|
+ "action": "signup",
|
|
|
|
+ "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
|
|
|
|
+ "user": user.model_dump_json(exclude_none=True),
|
|
|
|
+ },
|
|
|
|
+ )
|
|
|
|
+ else:
|
|
|
|
+ raise HTTPException(
|
|
|
|
+ status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ jwt_token = create_token(
|
|
|
|
+ data={"id": user.id},
|
|
|
|
+ expires_delta=parse_duration(webui_app.state.config.JWT_EXPIRES_IN),
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ # Set the cookie token
|
|
|
|
+ response.set_cookie(
|
|
|
|
+ key="token",
|
|
|
|
+ value=token,
|
|
|
|
+ httponly=True, # Ensures the cookie is not accessible via JavaScript
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ # Redirect back to the frontend with the JWT token
|
|
|
|
+ redirect_url = f"{request.base_url}auth#token={jwt_token}"
|
|
|
|
+ return RedirectResponse(url=redirect_url)
|
|
|
|
+
|
|
|
|
+
|
|
@app.get("/manifest.json")
|
|
@app.get("/manifest.json")
|
|
async def get_manifest_json():
|
|
async def get_manifest_json():
|
|
return {
|
|
return {
|