Преглед изворни кода

Merge pull request #311 from anuraagdjain/refac/auth-middleware

feat(auth): add auth middleware
Timothy Jaeryang Baek пре 1 година
родитељ
комит
c5386d05ab

+ 2 - 1
backend/.gitignore

@@ -4,4 +4,5 @@ _old
 uploads
 .ipynb_checkpoints
 *.db
-_test
+_test
+Pipfile

+ 7 - 3
backend/apps/ollama/main.py

@@ -8,7 +8,7 @@ import json
 
 from apps.web.models.users import Users
 from constants import ERROR_MESSAGES
-from utils.utils import extract_token_from_auth_header
+from utils.utils import decode_token
 from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
 
 app = Flask(__name__)
@@ -34,8 +34,12 @@ def proxy(path):
     # Basic RBAC support
     if WEBUI_AUTH:
         if "Authorization" in headers:
-            token = extract_token_from_auth_header(headers["Authorization"])
-            user = Users.get_user_by_token(token)
+            _, credentials = headers["Authorization"].split()
+            token_data = decode_token(credentials)
+            if token_data is None or "email" not in token_data:
+                return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401
+
+            user = Users.get_user_by_email(token_data["email"])
             if user:
                 # Only user and admin roles can access
                 if user.role in ["user", "admin"]:

+ 3 - 5
backend/apps/web/main.py

@@ -1,6 +1,6 @@
-from fastapi import FastAPI, Request, Depends, HTTPException
+from fastapi import FastAPI, Depends
+from fastapi.routing import APIRoute
 from fastapi.middleware.cors import CORSMiddleware
-
 from apps.web.routers import auths, users, chats, modelfiles, utils
 from config import WEBUI_VERSION, WEBUI_AUTH
 
@@ -16,13 +16,11 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-
 app.include_router(auths.router, prefix="/auths", tags=["auths"])
+
 app.include_router(users.router, prefix="/users", tags=["users"])
 app.include_router(chats.router, prefix="/chats", tags=["chats"])
 app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"])
-
-
 app.include_router(utils.router, prefix="/utils", tags=["utils"])
 
 

+ 0 - 10
backend/apps/web/models/users.py

@@ -3,8 +3,6 @@ from peewee import *
 from playhouse.shortcuts import model_to_dict
 from typing import List, Union, Optional
 import time
-
-from utils.utils import decode_token
 from utils.misc import get_gravatar_url
 
 from apps.web.internal.db import DB
@@ -85,14 +83,6 @@ class UsersTable:
         except:
             return None
 
-    def get_user_by_token(self, token: str) -> Optional[UserModel]:
-        data = decode_token(token)
-
-        if data != None and "email" in data:
-            return self.get_user_by_email(data["email"])
-        else:
-            return None
-
     def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]:
         return [
             UserModel(**model_to_dict(user))

+ 12 - 25
backend/apps/web/routers/auths.py

@@ -19,11 +19,7 @@ from apps.web.models.auths import (
 from apps.web.models.users import Users
 
 
-from utils.utils import (
-    get_password_hash,
-    bearer_scheme,
-    create_token,
-)
+from utils.utils import get_password_hash, get_current_user, create_token
 from utils.misc import get_gravatar_url
 from constants import ERROR_MESSAGES
 
@@ -36,22 +32,14 @@ router = APIRouter()
 
 
 @router.get("/", response_model=UserResponse)
-async def get_session_user(cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-    if user:
-        return {
-            "id": user.id,
-            "email": user.email,
-            "name": user.name,
-            "role": user.role,
-            "profile_image_url": user.profile_image_url,
-        }
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
-        )
+async def get_session_user(user=Depends(get_current_user)):
+    return {
+        "id": user.id,
+        "email": user.email,
+        "name": user.name,
+        "role": user.role,
+        "profile_image_url": user.profile_image_url,
+    }
 
 
 ############################
@@ -60,10 +48,9 @@ async def get_session_user(cred=Depends(bearer_scheme)):
 
 
 @router.post("/update/password", response_model=bool)
-async def update_password(form_data: UpdatePasswordForm, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    session_user = Users.get_user_by_token(token)
-
+async def update_password(
+    form_data: UpdatePasswordForm, session_user=Depends(get_current_user)
+):
     if session_user:
         user = Auths.authenticate_user(session_user.email, form_data.password)
 

+ 36 - 97
backend/apps/web/routers/chats.py

@@ -1,8 +1,7 @@
-from fastapi import Response
-from fastapi import Depends, FastAPI, HTTPException, status
+from fastapi import Depends, Request, HTTPException, status
 from datetime import datetime, timedelta
 from typing import List, Union, Optional
-
+from utils.utils import get_current_user
 from fastapi import APIRouter
 from pydantic import BaseModel
 import json
@@ -30,17 +29,10 @@ router = APIRouter()
 
 
 @router.get("/", response_model=List[ChatTitleIdResponse])
-async def get_user_chats(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        return Chats.get_chat_lists_by_user_id(user.id, skip, limit)
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
-        )
+async def get_user_chats(
+    user=Depends(get_current_user), skip: int = 0, limit: int = 50
+):
+    return Chats.get_chat_lists_by_user_id(user.id, skip, limit)
 
 
 ############################
@@ -49,20 +41,11 @@ async def get_user_chats(skip: int = 0, limit: int = 50, cred=Depends(bearer_sch
 
 
 @router.get("/all", response_model=List[ChatResponse])
-async def get_all_user_chats(cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        return [
-            ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
-            for chat in Chats.get_all_chats_by_user_id(user.id)
-        ]
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
-        )
+async def get_all_user_chats(user=Depends(get_current_user)):
+    return [
+        ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
+        for chat in Chats.get_all_chats_by_user_id(user.id)
+    ]
 
 
 ############################
@@ -71,18 +54,9 @@ async def get_all_user_chats(cred=Depends(bearer_scheme)):
 
 
 @router.post("/new", response_model=Optional[ChatResponse])
-async def create_new_chat(form_data: ChatForm, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        chat = Chats.insert_new_chat(user.id, form_data)
-        return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
-        )
+async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
+    chat = Chats.insert_new_chat(user.id, form_data)
+    return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
 
 
 ############################
@@ -91,24 +65,14 @@ async def create_new_chat(form_data: ChatForm, cred=Depends(bearer_scheme)):
 
 
 @router.get("/{id}", response_model=Optional[ChatResponse])
-async def get_chat_by_id(id: str, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        chat = Chats.get_chat_by_id_and_user_id(id, user.id)
-
-        if chat:
-            return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail=ERROR_MESSAGES.NOT_FOUND,
-            )
+async def get_chat_by_id(id: str, user=Depends(get_current_user)):
+    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+
+    if chat:
+        return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
     else:
         raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
+            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
         )
 
 
@@ -118,26 +82,19 @@ async def get_chat_by_id(id: str, cred=Depends(bearer_scheme)):
 
 
 @router.post("/{id}", response_model=Optional[ChatResponse])
-async def update_chat_by_id(id: str, form_data: ChatForm, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        chat = Chats.get_chat_by_id_and_user_id(id, user.id)
-        if chat:
-            updated_chat = {**json.loads(chat.chat), **form_data.chat}
-
-            chat = Chats.update_chat_by_id(id, updated_chat)
-            return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-            )
+async def update_chat_by_id(
+    id: str, form_data: ChatForm, user=Depends(get_current_user)
+):
+    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+    if chat:
+        updated_chat = {**json.loads(chat.chat), **form_data.chat}
+
+        chat = Chats.update_chat_by_id(id, updated_chat)
+        return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
     else:
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
         )
 
 
@@ -147,18 +104,9 @@ async def update_chat_by_id(id: str, form_data: ChatForm, cred=Depends(bearer_sc
 
 
 @router.delete("/{id}", response_model=bool)
-async def delete_chat_by_id(id: str, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        result = Chats.delete_chat_by_id_and_user_id(id, user.id)
-        return result
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
-        )
+async def delete_chat_by_id(id: str, user=Depends(get_current_user)):
+    result = Chats.delete_chat_by_id_and_user_id(id, user.id)
+    return result
 
 
 ############################
@@ -167,15 +115,6 @@ async def delete_chat_by_id(id: str, cred=Depends(bearer_scheme)):
 
 
 @router.delete("/", response_model=bool)
-async def delete_all_user_chats(cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        result = Chats.delete_chats_by_user_id(user.id)
-        return result
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
-        )
+async def delete_all_user_chats(user=Depends(get_current_user)):
+    result = Chats.delete_chats_by_user_id(user.id)
+    return result

+ 63 - 115
backend/apps/web/routers/modelfiles.py

@@ -1,4 +1,3 @@
-from fastapi import Response
 from fastapi import Depends, FastAPI, HTTPException, status
 from datetime import datetime, timedelta
 from typing import List, Union, Optional
@@ -6,8 +5,6 @@ from typing import List, Union, Optional
 from fastapi import APIRouter
 from pydantic import BaseModel
 import json
-
-from apps.web.models.users import Users
 from apps.web.models.modelfiles import (
     Modelfiles,
     ModelfileForm,
@@ -16,9 +13,7 @@ from apps.web.models.modelfiles import (
     ModelfileResponse,
 )
 
-from utils.utils import (
-    bearer_scheme,
-)
+from utils.utils import get_current_user
 from constants import ERROR_MESSAGES
 
 router = APIRouter()
@@ -29,17 +24,8 @@ router = APIRouter()
 
 
 @router.get("/", response_model=List[ModelfileResponse])
-async def get_modelfiles(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        return Modelfiles.get_modelfiles(skip, limit)
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
-        )
+async def get_modelfiles(skip: int = 0, limit: int = 50, user=Depends(get_current_user)):
+    return Modelfiles.get_modelfiles(skip, limit)
 
 
 ############################
@@ -48,36 +34,28 @@ async def get_modelfiles(skip: int = 0, limit: int = 50, cred=Depends(bearer_sch
 
 
 @router.post("/create", response_model=Optional[ModelfileResponse])
-async def create_new_modelfile(form_data: ModelfileForm, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        # Admin Only
-        if user.role == "admin":
-            modelfile = Modelfiles.insert_new_modelfile(user.id, form_data)
-
-            if modelfile:
-                return ModelfileResponse(
-                    **{
-                        **modelfile.model_dump(),
-                        "modelfile": json.loads(modelfile.modelfile),
-                    }
-                )
-            else:
-                raise HTTPException(
-                    status_code=status.HTTP_401_UNAUTHORIZED,
-                    detail=ERROR_MESSAGES.DEFAULT(),
-                )
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-            )
+async def create_new_modelfile(
+    form_data: ModelfileForm, user=Depends(get_current_user)
+):
+    if user.role != "admin":
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )
+
+    modelfile = Modelfiles.insert_new_modelfile(user.id, form_data)
+
+    if modelfile:
+        return ModelfileResponse(
+            **{
+                **modelfile.model_dump(),
+                "modelfile": json.loads(modelfile.modelfile),
+            }
+        )
     else:
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
+            detail=ERROR_MESSAGES.DEFAULT(),
         )
 
 
@@ -87,31 +65,20 @@ async def create_new_modelfile(form_data: ModelfileForm, cred=Depends(bearer_sch
 
 
 @router.post("/", response_model=Optional[ModelfileResponse])
-async def get_modelfile_by_tag_name(
-    form_data: ModelfileTagNameForm, cred=Depends(bearer_scheme)
-):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
-
-        if modelfile:
-            return ModelfileResponse(
-                **{
-                    **modelfile.model_dump(),
-                    "modelfile": json.loads(modelfile.modelfile),
-                }
-            )
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail=ERROR_MESSAGES.NOT_FOUND,
-            )
+async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm, user=Depends(get_current_user)):
+    modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
+
+    if modelfile:
+        return ModelfileResponse(
+            **{
+                **modelfile.model_dump(),
+                "modelfile": json.loads(modelfile.modelfile),
+            }
+        )
     else:
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
+            detail=ERROR_MESSAGES.NOT_FOUND,
         )
 
 
@@ -122,44 +89,34 @@ async def get_modelfile_by_tag_name(
 
 @router.post("/update", response_model=Optional[ModelfileResponse])
 async def update_modelfile_by_tag_name(
-    form_data: ModelfileUpdateForm, cred=Depends(bearer_scheme)
+    form_data: ModelfileUpdateForm, user=Depends(get_current_user)
 ):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        if user.role == "admin":
-            modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
-            if modelfile:
-                updated_modelfile = {
-                    **json.loads(modelfile.modelfile),
-                    **form_data.modelfile,
-                }
-
-                modelfile = Modelfiles.update_modelfile_by_tag_name(
-                    form_data.tag_name, updated_modelfile
-                )
-
-                return ModelfileResponse(
-                    **{
-                        **modelfile.model_dump(),
-                        "modelfile": json.loads(modelfile.modelfile),
-                    }
-                )
-            else:
-                raise HTTPException(
-                    status_code=status.HTTP_401_UNAUTHORIZED,
-                    detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-                )
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-            )
+    if user.role != "admin":
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )
+    modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
+    if modelfile:
+        updated_modelfile = {
+            **json.loads(modelfile.modelfile),
+            **form_data.modelfile,
+        }
+
+        modelfile = Modelfiles.update_modelfile_by_tag_name(
+            form_data.tag_name, updated_modelfile
+        )
+
+        return ModelfileResponse(
+            **{
+                **modelfile.model_dump(),
+                "modelfile": json.loads(modelfile.modelfile),
+            }
+        )
     else:
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
         )
 
 
@@ -170,22 +127,13 @@ async def update_modelfile_by_tag_name(
 
 @router.delete("/delete", response_model=bool)
 async def delete_modelfile_by_tag_name(
-    form_data: ModelfileTagNameForm, cred=Depends(bearer_scheme)
+    form_data: ModelfileTagNameForm, user=Depends(get_current_user)
 ):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        if user.role == "admin":
-            result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name)
-            return result
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-            )
-    else:
+    if user.role != "admin":
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
         )
+
+    result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name)
+    return result

+ 31 - 61
backend/apps/web/routers/users.py

@@ -12,11 +12,7 @@ from apps.web.models.users import UserModel, UserRoleUpdateForm, Users
 from apps.web.models.auths import Auths
 
 
-from utils.utils import (
-    get_password_hash,
-    bearer_scheme,
-    create_token,
-)
+from utils.utils import get_current_user
 from constants import ERROR_MESSAGES
 
 router = APIRouter()
@@ -27,23 +23,13 @@ router = APIRouter()
 
 
 @router.get("/", response_model=List[UserModel])
-async def get_users(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        if user.role == "admin":
-            return Users.get_users(skip, limit)
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_403_FORBIDDEN,
-                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-            )
-    else:
+async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_current_user)):
+    if user.role != "admin":
         raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
+            status_code=status.HTTP_403_FORBIDDEN,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
         )
+    return Users.get_users(skip, limit)
 
 
 ############################
@@ -52,28 +38,21 @@ async def get_users(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme))
 
 
 @router.post("/update/role", response_model=Optional[UserModel])
-async def update_user_role(form_data: UserRoleUpdateForm, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        if user.role == "admin":
-            if user.id != form_data.id:
-                return Users.update_user_role_by_id(form_data.id, form_data.role)
-            else:
-                raise HTTPException(
-                    status_code=status.HTTP_403_FORBIDDEN,
-                    detail=ERROR_MESSAGES.ACTION_PROHIBITED,
-                )
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_403_FORBIDDEN,
-                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-            )
+async def update_user_role(
+    form_data: UserRoleUpdateForm, user=Depends(get_current_user)
+):
+    if user.role != "admin":
+        raise HTTPException(
+            status_code=status.HTTP_403_FORBIDDEN,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )
+
+    if user.id != form_data.id:
+        return Users.update_user_role_by_id(form_data.id, form_data.role)
     else:
         raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
+            status_code=status.HTTP_403_FORBIDDEN,
+            detail=ERROR_MESSAGES.ACTION_PROHIBITED,
         )
 
 
@@ -83,34 +62,25 @@ async def update_user_role(form_data: UserRoleUpdateForm, cred=Depends(bearer_sc
 
 
 @router.delete("/{user_id}", response_model=bool)
-async def delete_user_by_id(user_id: str, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        if user.role == "admin":
-            if user.id != user_id:
-                result = Auths.delete_auth_by_id(user_id)
-
-                if result:
-                    return True
-                else:
-                    raise HTTPException(
-                        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-                        detail=ERROR_MESSAGES.DELETE_USER_ERROR,
-                    )
+async def delete_user_by_id(user_id: str, user=Depends(get_current_user)):
+    if user.role == "admin":
+        if user.id != user_id:
+            result = Auths.delete_auth_by_id(user_id)
+
+            if result:
+                return True
             else:
                 raise HTTPException(
-                    status_code=status.HTTP_403_FORBIDDEN,
-                    detail=ERROR_MESSAGES.ACTION_PROHIBITED,
+                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                    detail=ERROR_MESSAGES.DELETE_USER_ERROR,
                 )
         else:
             raise HTTPException(
                 status_code=status.HTTP_403_FORBIDDEN,
-                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+                detail=ERROR_MESSAGES.ACTION_PROHIBITED,
             )
     else:
         raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
+            status_code=status.HTTP_403_FORBIDDEN,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
         )

+ 2 - 0
backend/requirements.txt

@@ -18,3 +18,5 @@ bcrypt
 
 PyJWT
 pyjwt[crypto]
+
+black

+ 18 - 14
backend/utils/utils.py

@@ -1,7 +1,9 @@
-from fastapi.security import HTTPBasicCredentials, HTTPBearer
+from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
+from fastapi import HTTPException, status, Depends
+from apps.web.models.users import Users
 from pydantic import BaseModel
 from typing import Union, Optional
-
+from constants import ERROR_MESSAGES
 from passlib.context import CryptContext
 from datetime import datetime, timedelta
 import requests
@@ -53,16 +55,18 @@ def extract_token_from_auth_header(auth_header: str):
     return auth_header[len("Bearer ") :]
 
 
-def verify_token(request):
-    try:
-        bearer = request.headers["authorization"]
-        if bearer:
-            token = bearer[len("Bearer ") :]
-            decoded = jwt.decode(
-                token, JWT_SECRET_KEY, options={"verify_signature": False}
+def get_current_user(auth_token: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
+    data = decode_token(auth_token.credentials)
+    if data != None and "email" in data:
+        user = Users.get_user_by_email(data["email"])
+        if user is None:
+            raise HTTPException(
+                status_code=status.HTTP_401_UNAUTHORIZED,
+                detail=ERROR_MESSAGES.INVALID_TOKEN,
             )
-            return decoded
-        else:
-            return None
-    except Exception as e:
-        return None
+        return user
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.UNAUTHORIZED,
+        )