Browse Source

refac: move things around, uplift oauth endpoints

Jun Siang Cheah 11 months ago
parent
commit
985fdca585

+ 0 - 6
backend/apps/webui/main.py

@@ -58,12 +58,6 @@ app.add_middleware(
     allow_headers=["*"],
     allow_headers=["*"],
 )
 )
 
 
-# SessionMiddleware is used by authlib for oauth
-if len(OAUTH_PROVIDERS) > 0:
-    app.add_middleware(
-        SessionMiddleware, secret_key=WEBUI_SECRET_KEY, session_cookie="oui-session"
-    )
-
 app.include_router(auths.router, prefix="/auths", tags=["auths"])
 app.include_router(auths.router, prefix="/auths", tags=["auths"])
 app.include_router(users.router, prefix="/users", tags=["users"])
 app.include_router(users.router, prefix="/users", tags=["users"])
 app.include_router(chats.router, prefix="/chats", tags=["chats"])
 app.include_router(chats.router, prefix="/chats", tags=["chats"])

+ 21 - 2
backend/apps/webui/models/users.py

@@ -112,9 +112,16 @@ class UsersTable:
         except:
         except:
             return None
             return None
 
 
-    def get_user_by_email(self, email: str) -> Optional[UserModel]:
+    def get_user_by_email(
+        self, email: str, oauth_user: bool = False
+    ) -> Optional[UserModel]:
         try:
         try:
-            user = User.get((User.email == email, User.oauth_sub.is_null()))
+            conditions = (
+                (User.email == email, User.oauth_sub.is_null())
+                if not oauth_user
+                else (User.email == email)
+            )
+            user = User.get(conditions)
             return UserModel(**model_to_dict(user))
             return UserModel(**model_to_dict(user))
         except:
         except:
             return None
             return None
@@ -177,6 +184,18 @@ class UsersTable:
         except:
         except:
             return None
             return None
 
 
+    def update_user_oauth_sub_by_id(
+        self, id: str, oauth_sub: str
+    ) -> Optional[UserModel]:
+        try:
+            query = User.update(oauth_sub=oauth_sub).where(User.id == id)
+            query.execute()
+
+            user = User.get(User.id == id)
+            return UserModel(**model_to_dict(user))
+        except:
+            return None
+
     def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
     def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
         try:
         try:
             query = User.update(**updated).where(User.id == id)
             query = User.update(**updated).where(User.id == id)

+ 0 - 85
backend/apps/webui/routers/auths.py

@@ -1,7 +1,5 @@
 import logging
 import logging
 
 
-from authlib.integrations.starlette_client import OAuth
-from authlib.oidc.core import UserInfo
 from fastapi import Request, UploadFile, File
 from fastapi import Request, UploadFile, File
 from fastapi import Depends, HTTPException, status
 from fastapi import Depends, HTTPException, status
 
 
@@ -11,8 +9,6 @@ import re
 import uuid
 import uuid
 import csv
 import csv
 
 
-from starlette.responses import RedirectResponse
-
 from apps.webui.models.auths import (
 from apps.webui.models.auths import (
     SigninForm,
     SigninForm,
     SignupForm,
     SignupForm,
@@ -39,8 +35,6 @@ from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
 from config import (
 from config import (
     WEBUI_AUTH,
     WEBUI_AUTH,
     WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
     WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
-    OAUTH_PROVIDERS,
-    ENABLE_OAUTH_SIGNUP,
 )
 )
 
 
 router = APIRouter()
 router = APIRouter()
@@ -381,82 +375,3 @@ async def get_api_key(user=Depends(get_current_user)):
         }
         }
     else:
     else:
         raise HTTPException(404, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
         raise HTTPException(404, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
-
-
-############################
-# OAuth Login & Callback
-############################
-
-oauth = OAuth()
-
-for provider_name, provider_config in OAUTH_PROVIDERS.items():
-    oauth.register(
-        name=provider_name,
-        client_id=provider_config["client_id"],
-        client_secret=provider_config["client_secret"],
-        server_metadata_url=provider_config["server_metadata_url"],
-        client_kwargs={
-            "scope": provider_config["scope"],
-        },
-    )
-
-
-@router.get("/oauth/{provider}/login")
-async def oauth_login(provider: str, request: Request):
-    if provider not in OAUTH_PROVIDERS:
-        raise HTTPException(404)
-    redirect_uri = request.url_for("oauth_callback", provider=provider)
-    return await oauth.create_client(provider).authorize_redirect(request, redirect_uri)
-
-
-@router.get("/oauth/{provider}/callback")
-async def oauth_callback(provider: str, request: Request):
-    if provider not in OAUTH_PROVIDERS:
-        raise HTTPException(404)
-    client = oauth.create_client(provider)
-    token = await client.authorize_access_token(request)
-    user_data: UserInfo = token["userinfo"]
-
-    sub = user_data.get("sub")
-    if not sub:
-        raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
-    provider_sub = f"{provider}@{sub}"
-
-    # Check if the user exists
-    user = Users.get_user_by_oauth_sub(provider_sub)
-
-    if not user:
-        # If the user does not exist, create a new user if signup is enabled
-        if ENABLE_OAUTH_SIGNUP.value:
-            user = Auths.insert_new_auth(
-                email=user_data.get("email", "").lower(),
-                password=get_password_hash(
-                    str(uuid.uuid4())
-                ),  # Random password, not used
-                name=user_data.get("name", "User"),
-                profile_image_url=user_data.get("picture", "/user.png"),
-                role=request.app.state.config.DEFAULT_USER_ROLE,
-                oauth_sub=provider_sub,
-            )
-
-            if request.app.state.config.WEBHOOK_URL:
-                post_webhook(
-                    request.app.state.config.WEBHOOK_URL,
-                    WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
-                    {
-                        "action": "signup",
-                        "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
-                        "user": user.model_dump_json(exclude_none=True),
-                    },
-                )
-        else:
-            raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
-
-    jwt_token = create_token(
-        data={"id": user.id},
-        expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
-    )
-
-    # Redirect back to the frontend with the JWT token
-    redirect_url = f"{request.base_url}auth#token={jwt_token}"
-    return RedirectResponse(url=redirect_url)

+ 118 - 4
backend/main.py

@@ -1,4 +1,8 @@
+import uuid
 from contextlib import asynccontextmanager
 from contextlib import asynccontextmanager
+
+from authlib.integrations.starlette_client import OAuth
+from authlib.oidc.core import UserInfo
 from bs4 import BeautifulSoup
 from bs4 import BeautifulSoup
 import json
 import json
 import markdown
 import markdown
@@ -17,7 +21,8 @@ from fastapi.middleware.wsgi import WSGIMiddleware
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.cors import CORSMiddleware
 from starlette.exceptions import HTTPException as StarletteHTTPException
 from starlette.exceptions import HTTPException as StarletteHTTPException
 from starlette.middleware.base import BaseHTTPMiddleware
 from starlette.middleware.base import BaseHTTPMiddleware
-from starlette.responses import StreamingResponse, Response
+from starlette.middleware.sessions import SessionMiddleware
+from starlette.responses import StreamingResponse, Response, RedirectResponse
 
 
 from apps.ollama.main import app as ollama_app, get_all_models as get_ollama_models
 from apps.ollama.main import app as ollama_app, get_all_models as get_ollama_models
 from apps.openai.main import app as openai_app, get_all_models as get_openai_models
 from apps.openai.main import app as openai_app, get_all_models as get_openai_models
@@ -31,8 +36,16 @@ import asyncio
 from pydantic import BaseModel
 from pydantic import BaseModel
 from typing import List, Optional
 from typing import List, Optional
 
 
-from apps.webui.models.models import Models, ModelModel
-from utils.utils import get_admin_user, get_verified_user
+from apps.webui.models.auths import Auths
+from apps.webui.models.models import Models
+from apps.webui.models.users import Users
+from utils.misc import parse_duration
+from utils.utils import (
+    get_admin_user,
+    get_verified_user,
+    get_password_hash,
+    create_token,
+)
 from apps.rag.utils import rag_messages
 from apps.rag.utils import rag_messages
 
 
 from config import (
 from config import (
@@ -56,8 +69,12 @@ from config import (
     ENABLE_ADMIN_EXPORT,
     ENABLE_ADMIN_EXPORT,
     AppConfig,
     AppConfig,
     OAUTH_PROVIDERS,
     OAUTH_PROVIDERS,
+    ENABLE_OAUTH_SIGNUP,
+    OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
+    WEBUI_SECRET_KEY,
 )
 )
-from constants import ERROR_MESSAGES
+from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
+from utils.webhook import post_webhook
 
 
 logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
 logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)
@@ -453,6 +470,103 @@ async def get_app_latest_release_version():
         )
         )
 
 
 
 
+############################
+# OAuth Login & Callback
+############################
+
+oauth = OAuth()
+
+for provider_name, provider_config in OAUTH_PROVIDERS.items():
+    oauth.register(
+        name=provider_name,
+        client_id=provider_config["client_id"],
+        client_secret=provider_config["client_secret"],
+        server_metadata_url=provider_config["server_metadata_url"],
+        client_kwargs={
+            "scope": provider_config["scope"],
+        },
+    )
+
+# SessionMiddleware is used by authlib for oauth
+if len(OAUTH_PROVIDERS) > 0:
+    app.add_middleware(
+        SessionMiddleware, secret_key=WEBUI_SECRET_KEY, session_cookie="oui-session"
+    )
+
+
+@app.get("/oauth/{provider}/login")
+async def oauth_login(provider: str, request: Request):
+    if provider not in OAUTH_PROVIDERS:
+        raise HTTPException(404)
+    redirect_uri = request.url_for("oauth_callback", provider=provider)
+    return await oauth.create_client(provider).authorize_redirect(request, redirect_uri)
+
+
+@app.get("/oauth/{provider}/callback")
+async def oauth_callback(provider: str, request: Request):
+    if provider not in OAUTH_PROVIDERS:
+        raise HTTPException(404)
+    client = oauth.create_client(provider)
+    token = await client.authorize_access_token(request)
+    user_data: UserInfo = token["userinfo"]
+
+    sub = user_data.get("sub")
+    if not sub:
+        raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
+    provider_sub = f"{provider}@{sub}"
+
+    # Check if the user exists
+    user = Users.get_user_by_oauth_sub(provider_sub)
+
+    if not user:
+        # If the user does not exist, check if merging is enabled
+        if OAUTH_MERGE_ACCOUNTS_BY_EMAIL:
+            # Check if the user exists by email
+            email = user_data.get("email", "").lower()
+            if not email:
+                raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
+            user = Users.get_user_by_email(user_data.get("email", "").lower(), True)
+            if user:
+                # Update the user with the new oauth sub
+                Users.update_user_oauth_sub_by_id(user.id, provider_sub)
+
+    if not user:
+        # If the user does not exist, check if signups are enabled
+        if ENABLE_OAUTH_SIGNUP.value:
+            user = Auths.insert_new_auth(
+                email=user_data.get("email", "").lower(),
+                password=get_password_hash(
+                    str(uuid.uuid4())
+                ),  # Random password, not used
+                name=user_data.get("name", "User"),
+                profile_image_url=user_data.get("picture", "/user.png"),
+                role=webui_app.state.config.DEFAULT_USER_ROLE,
+                oauth_sub=provider_sub,
+            )
+
+            if webui_app.state.config.WEBHOOK_URL:
+                post_webhook(
+                    webui_app.state.config.WEBHOOK_URL,
+                    WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
+                    {
+                        "action": "signup",
+                        "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
+                        "user": user.model_dump_json(exclude_none=True),
+                    },
+                )
+        else:
+            raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
+
+    jwt_token = create_token(
+        data={"id": user.id},
+        expires_delta=parse_duration(webui_app.state.config.JWT_EXPIRES_IN),
+    )
+
+    # Redirect back to the frontend with the JWT token
+    redirect_url = f"{request.base_url}auth#token={jwt_token}"
+    return RedirectResponse(url=redirect_url)
+
+
 @app.get("/manifest.json")
 @app.get("/manifest.json")
 async def get_manifest_json():
 async def get_manifest_json():
     return {
     return {

+ 3 - 3
src/routes/auth/+page.svelte

@@ -259,7 +259,7 @@
 								<button
 								<button
 									class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition"
 									class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition"
 									on:click={() => {
 									on:click={() => {
-										window.location.href = `${WEBUI_API_BASE_URL}/auths/oauth/google/login`;
+										window.location.href = `${WEBUI_BASE_URL}/oauth/google/login`;
 									}}
 									}}
 								>
 								>
 									<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 48 48" class="size-6 mr-3">
 									<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 48 48" class="size-6 mr-3">
@@ -284,7 +284,7 @@
 								<button
 								<button
 									class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition"
 									class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition"
 									on:click={() => {
 									on:click={() => {
-										window.location.href = `${WEBUI_API_BASE_URL}/auths/oauth/microsoft/login`;
+										window.location.href = `${WEBUI_BASE_URL}/oauth/microsoft/login`;
 									}}
 									}}
 								>
 								>
 									<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 21 21" class="size-6 mr-3">
 									<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 21 21" class="size-6 mr-3">
@@ -309,7 +309,7 @@
 								<button
 								<button
 									class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition"
 									class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition"
 									on:click={() => {
 									on:click={() => {
-										window.location.href = `${WEBUI_API_BASE_URL}/auths/oauth/oidc/login`;
+										window.location.href = `${WEBUI_BASE_URL}/oauth/oidc/login`;
 									}}
 									}}
 								>
 								>
 									<svg
 									<svg