Browse Source

refac: remove the verify_token and use get-current user for auth+user

Anuraag Jain 1 year ago
parent
commit
77323d9b25

+ 3 - 19
backend/apps/web/main.py

@@ -3,7 +3,6 @@ from fastapi.routing import APIRoute
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.cors import CORSMiddleware
 from apps.web.routers import auths, users, chats, modelfiles, utils
 from apps.web.routers import auths, users, chats, modelfiles, utils
 from config import WEBUI_VERSION, WEBUI_AUTH
 from config import WEBUI_VERSION, WEBUI_AUTH
-from utils.utils import verify_auth_token
 
 
 app = FastAPI()
 app = FastAPI()
 
 
@@ -19,24 +18,9 @@ app.add_middleware(
 
 
 app.include_router(auths.router, prefix="/auths", tags=["auths"])
 app.include_router(auths.router, prefix="/auths", tags=["auths"])
 
 
-app.include_router(
-    users.router,
-    prefix="/users",
-    tags=["users"],
-    dependencies=[Depends(verify_auth_token)],
-)
-app.include_router(
-    chats.router,
-    prefix="/chats",
-    tags=["chats"],
-    dependencies=[Depends(verify_auth_token)],
-)
-app.include_router(
-    modelfiles.router,
-    prefix="/modelfiles",
-    tags=["modelfiles"],
-    dependencies=[Depends(verify_auth_token)],
-)
+app.include_router(users.router, prefix="/users", tags=["users"])
+app.include_router(chats.router, prefix="/chats", tags=["chats"])
+app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"])
 app.include_router(utils.router, prefix="/utils", tags=["utils"])
 app.include_router(utils.router, prefix="/utils", tags=["utils"])
 
 
 
 

+ 3 - 10
backend/apps/web/routers/auths.py

@@ -19,12 +19,7 @@ from apps.web.models.auths import (
 from apps.web.models.users import Users
 from apps.web.models.users import Users
 
 
 
 
-from utils.utils import (
-    get_password_hash,
-    get_current_user,
-    create_token,
-    verify_auth_token,
-)
+from utils.utils import get_password_hash, get_current_user, create_token
 from utils.misc import get_gravatar_url
 from utils.misc import get_gravatar_url
 from constants import ERROR_MESSAGES
 from constants import ERROR_MESSAGES
 
 
@@ -36,7 +31,7 @@ router = APIRouter()
 ############################
 ############################
 
 
 
 
-@router.get("/", response_model=UserResponse, dependencies=[Depends(verify_auth_token)])
+@router.get("/", response_model=UserResponse)
 async def get_session_user(user=Depends(get_current_user)):
 async def get_session_user(user=Depends(get_current_user)):
     return {
     return {
         "id": user.id,
         "id": user.id,
@@ -52,9 +47,7 @@ async def get_session_user(user=Depends(get_current_user)):
 ############################
 ############################
 
 
 
 
-@router.post(
-    "/update/password", response_model=bool, dependencies=[Depends(verify_auth_token)]
-)
+@router.post("/update/password", response_model=bool)
 async def update_password(
 async def update_password(
     form_data: UpdatePasswordForm, session_user=Depends(get_current_user)
     form_data: UpdatePasswordForm, session_user=Depends(get_current_user)
 ):
 ):

+ 1 - 0
backend/apps/web/routers/chats.py

@@ -108,6 +108,7 @@ async def delete_chat_by_id(id: str, user=Depends(get_current_user)):
     result = Chats.delete_chat_by_id_and_user_id(id, user.id)
     result = Chats.delete_chat_by_id_and_user_id(id, user.id)
     return result
     return result
 
 
+
 ############################
 ############################
 # DeleteAllChats
 # DeleteAllChats
 ############################
 ############################

+ 3 - 5
backend/apps/web/routers/modelfiles.py

@@ -5,8 +5,6 @@ from typing import List, Union, Optional
 from fastapi import APIRouter
 from fastapi import APIRouter
 from pydantic import BaseModel
 from pydantic import BaseModel
 import json
 import json
-
-from apps.web.models.users import Users
 from apps.web.models.modelfiles import (
 from apps.web.models.modelfiles import (
     Modelfiles,
     Modelfiles,
     ModelfileForm,
     ModelfileForm,
@@ -15,7 +13,7 @@ from apps.web.models.modelfiles import (
     ModelfileResponse,
     ModelfileResponse,
 )
 )
 
 
-from utils.utils import bearer_scheme, get_current_user
+from utils.utils import get_current_user
 from constants import ERROR_MESSAGES
 from constants import ERROR_MESSAGES
 
 
 router = APIRouter()
 router = APIRouter()
@@ -26,7 +24,7 @@ router = APIRouter()
 
 
 
 
 @router.get("/", response_model=List[ModelfileResponse])
 @router.get("/", response_model=List[ModelfileResponse])
-async def get_modelfiles(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)):
+async def get_modelfiles(skip: int = 0, limit: int = 50, user=Depends(get_current_user)):
     return Modelfiles.get_modelfiles(skip, limit)
     return Modelfiles.get_modelfiles(skip, limit)
 
 
 
 
@@ -67,7 +65,7 @@ async def create_new_modelfile(
 
 
 
 
 @router.post("/", response_model=Optional[ModelfileResponse])
 @router.post("/", response_model=Optional[ModelfileResponse])
-async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm):
+async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm, user=Depends(get_current_user)):
     modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
     modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
 
 
     if modelfile:
     if modelfile:

+ 2 - 7
backend/utils/utils.py

@@ -55,7 +55,7 @@ def extract_token_from_auth_header(auth_header: str):
     return auth_header[len("Bearer ") :]
     return auth_header[len("Bearer ") :]
 
 
 
 
-def verify_auth_token(auth_token: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
+def get_current_user(auth_token: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
     data = decode_token(auth_token.credentials)
     data = decode_token(auth_token.credentials)
     if data != None and "email" in data:
     if data != None and "email" in data:
         user = Users.get_user_by_email(data["email"])
         user = Users.get_user_by_email(data["email"])
@@ -64,14 +64,9 @@ def verify_auth_token(auth_token: HTTPAuthorizationCredentials = Depends(HTTPBea
                 status_code=status.HTTP_401_UNAUTHORIZED,
                 status_code=status.HTTP_401_UNAUTHORIZED,
                 detail=ERROR_MESSAGES.INVALID_TOKEN,
                 detail=ERROR_MESSAGES.INVALID_TOKEN,
             )
             )
-        return
+        return user
     else:
     else:
         raise HTTPException(
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
             status_code=status.HTTP_401_UNAUTHORIZED,
             detail=ERROR_MESSAGES.UNAUTHORIZED,
             detail=ERROR_MESSAGES.UNAUTHORIZED,
         )
         )
-
-
-def get_current_user(auth_token: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
-    data = decode_token(auth_token.credentials)
-    return Users.get_user_by_email(data["email"])