ソースを参照

Endpoint role-checking was redundantly applied but FastAPI provides a nice abstraction mechanic...so I applied it. There should be no logical changes in this code; only simpler, cleaner ways for doing the same thing.

Tim Farrell 1 年間 前
コミット
08e8e922fd

+ 9 - 15
backend/apps/ollama/main.py

@@ -1,4 +1,4 @@
-from fastapi import FastAPI, Request, Response, HTTPException, Depends
+from fastapi import FastAPI, Request, Response, HTTPException, Depends, status
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
 from fastapi.concurrency import run_in_threadpool
@@ -10,7 +10,7 @@ from pydantic import BaseModel
 
 from apps.web.models.users import Users
 from constants import ERROR_MESSAGES
-from utils.utils import decode_token, get_current_user
+from utils.utils import decode_token, get_current_user, get_admin_user
 from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
 
 app = FastAPI()
@@ -31,11 +31,8 @@ REQUEST_POOL = []
 
 
 @app.get("/url")
-async def get_ollama_api_url(user=Depends(get_current_user)):
-    if user and user.role == "admin":
-        return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
-    else:
-        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+async def get_ollama_api_url(user=Depends(get_admin_user)):
+    return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
 
 
 class UrlUpdateForm(BaseModel):
@@ -44,13 +41,10 @@ class UrlUpdateForm(BaseModel):
 
 @app.post("/url/update")
 async def update_ollama_api_url(
-    form_data: UrlUpdateForm, user=Depends(get_current_user)
+    form_data: UrlUpdateForm, user=Depends(get_admin_user)
 ):
-    if user and user.role == "admin":
-        app.state.OLLAMA_API_BASE_URL = form_data.url
-        return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
-    else:
-        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+    app.state.OLLAMA_API_BASE_URL = form_data.url
+    return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
 
 
 @app.get("/cancel/{request_id}")
@@ -74,10 +68,10 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)):
         if path in ["pull", "delete", "push", "copy", "create"]:
             if user.role != "admin":
                 raise HTTPException(
-                    status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
+                    status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
                 )
     else:
-        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
 
     headers.pop("host", None)
     headers.pop("authorization", None)

+ 14 - 29
backend/apps/openai/main.py

@@ -9,7 +9,7 @@ from pydantic import BaseModel
 
 from apps.web.models.users import Users
 from constants import ERROR_MESSAGES
-from utils.utils import decode_token, get_current_user
+from utils.utils import decode_token, get_current_user, get_verified_user, get_admin_user
 from config import OPENAI_API_BASE_URL, OPENAI_API_KEY, CACHE_DIR
 
 import hashlib
@@ -37,45 +37,32 @@ class KeyUpdateForm(BaseModel):
 
 
 @app.get("/url")
-async def get_openai_url(user=Depends(get_current_user)):
-    if user and user.role == "admin":
-        return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}
-    else:
-        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+async def get_openai_url(user=Depends(get_admin_user)):
+    return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}
 
 
 @app.post("/url/update")
-async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_current_user)):
-    if user and user.role == "admin":
-        app.state.OPENAI_API_BASE_URL = form_data.url
-        return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}
-    else:
-        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
+    app.state.OPENAI_API_BASE_URL = form_data.url
+    return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}
+
 
 
 @app.get("/key")
-async def get_openai_key(user=Depends(get_current_user)):
-    if user and user.role == "admin":
-        return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
-    else:
-        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+async def get_openai_key(user=Depends(get_admin_user)):
+    return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
 
 
 @app.post("/key/update")
-async def update_openai_key(form_data: KeyUpdateForm, user=Depends(get_current_user)):
-    if user and user.role == "admin":
-        app.state.OPENAI_API_KEY = form_data.key
-        return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
-    else:
-        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+async def update_openai_key(form_data: KeyUpdateForm, user=Depends(get_admin_user)):
+    app.state.OPENAI_API_KEY = form_data.key
+    return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
 
 
 @app.post("/audio/speech")
-async def speech(request: Request, user=Depends(get_current_user)):
+async def speech(request: Request, user=Depends(get_verified_user)):
     target_url = f"{app.state.OPENAI_API_BASE_URL}/audio/speech"
 
-    if user.role not in ["user", "admin"]:
-        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
     if app.state.OPENAI_API_KEY == "":
         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
 
@@ -133,12 +120,10 @@ async def speech(request: Request, user=Depends(get_current_user)):
 
 
 @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
-async def proxy(path: str, request: Request, user=Depends(get_current_user)):
+async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
     target_url = f"{app.state.OPENAI_API_BASE_URL}/{path}"
     print(target_url, app.state.OPENAI_API_KEY)
 
-    if user.role not in ["user", "admin"]:
-        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
     if app.state.OPENAI_API_KEY == "":
         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
 

+ 18 - 30
backend/apps/rag/main.py

@@ -39,7 +39,7 @@ import uuid
 import time
 
 from utils.misc import calculate_sha256, calculate_sha256_string
-from utils.utils import get_current_user
+from utils.utils import get_current_user, get_admin_user
 from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
 from constants import ERROR_MESSAGES
 
@@ -354,38 +354,26 @@ def store_doc(
 
 
 @app.get("/reset/db")
-def reset_vector_db(user=Depends(get_current_user)):
-    if user.role == "admin":
-        CHROMA_CLIENT.reset()
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_403_FORBIDDEN,
-            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-        )
+def reset_vector_db(user=Depends(get_admin_user)):
+    CHROMA_CLIENT.reset()
 
 
 @app.get("/reset")
-def reset(user=Depends(get_current_user)) -> bool:
-    if user.role == "admin":
-        folder = f"{UPLOAD_DIR}"
-        for filename in os.listdir(folder):
-            file_path = os.path.join(folder, filename)
-            try:
-                if os.path.isfile(file_path) or os.path.islink(file_path):
-                    os.unlink(file_path)
-                elif os.path.isdir(file_path):
-                    shutil.rmtree(file_path)
-            except Exception as e:
-                print("Failed to delete %s. Reason: %s" % (file_path, e))
-
+def reset(user=Depends(get_admin_user)) -> bool:
+    folder = f"{UPLOAD_DIR}"
+    for filename in os.listdir(folder):
+        file_path = os.path.join(folder, filename)
         try:
-            CHROMA_CLIENT.reset()
+            if os.path.isfile(file_path) or os.path.islink(file_path):
+                os.unlink(file_path)
+            elif os.path.isdir(file_path):
+                shutil.rmtree(file_path)
         except Exception as e:
-            print(e)
+            print("Failed to delete %s. Reason: %s" % (file_path, e))
 
-        return True
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_403_FORBIDDEN,
-            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-        )
+    try:
+        CHROMA_CLIENT.reset()
+    except Exception as e:
+        print(e)
+
+    return True

+ 9 - 21
backend/apps/web/routers/auths.py

@@ -3,7 +3,7 @@ from fastapi import Depends, FastAPI, HTTPException, status
 from datetime import datetime, timedelta
 from typing import List, Union
 
-from fastapi import APIRouter
+from fastapi import APIRouter, status
 from pydantic import BaseModel
 import time
 import uuid
@@ -19,7 +19,7 @@ from apps.web.models.auths import (
 )
 from apps.web.models.users import Users
 
-from utils.utils import get_password_hash, get_current_user, create_token
+from utils.utils import get_password_hash, get_current_user, get_admin_user, create_token
 from utils.misc import get_gravatar_url, validate_email_format
 from constants import ERROR_MESSAGES
 
@@ -116,10 +116,10 @@ async def signin(form_data: SigninForm):
 @router.post("/signup", response_model=SigninResponse)
 async def signup(request: Request, form_data: SignupForm):
     if not request.app.state.ENABLE_SIGNUP:
-        raise HTTPException(400, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+        raise HTTPException(status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
 
     if not validate_email_format(form_data.email.lower()):
-        raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT)
+        raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT)
 
     if Users.get_user_by_email(form_data.email.lower()):
         raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
@@ -156,23 +156,11 @@ async def signup(request: Request, form_data: SignupForm):
 
 
 @router.get("/signup/enabled", response_model=bool)
-async def get_sign_up_status(request: Request, user=Depends(get_current_user)):
-    if user.role == "admin":
-        return request.app.state.ENABLE_SIGNUP
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_403_FORBIDDEN,
-            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-        )
+async def get_sign_up_status(request: Request, user=Depends(get_admin_user)):
+    return request.app.state.ENABLE_SIGNUP
 
 
 @router.get("/signup/enabled/toggle", response_model=bool)
-async def toggle_sign_up(request: Request, user=Depends(get_current_user)):
-    if user.role == "admin":
-        request.app.state.ENABLE_SIGNUP = not request.app.state.ENABLE_SIGNUP
-        return request.app.state.ENABLE_SIGNUP
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_403_FORBIDDEN,
-            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-        )
+async def toggle_sign_up(request: Request, user=Depends(get_admin_user)):
+    request.app.state.ENABLE_SIGNUP = not request.app.state.ENABLE_SIGNUP
+    return request.app.state.ENABLE_SIGNUP

+ 6 - 12
backend/apps/web/routers/chats.py

@@ -1,7 +1,7 @@
 from fastapi import Depends, Request, HTTPException, status
 from datetime import datetime, timedelta
 from typing import List, Union, Optional
-from utils.utils import get_current_user
+from utils.utils import get_current_user, get_admin_user
 from fastapi import APIRouter
 from pydantic import BaseModel
 import json
@@ -60,17 +60,11 @@ async def get_all_user_chats(user=Depends(get_current_user)):
 
 
 @router.get("/all/db", response_model=List[ChatResponse])
-async def get_all_user_chats_in_db(user=Depends(get_current_user)):
-    if user.role == "admin":
-        return [
-            ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
-            for chat in Chats.get_all_chats()
-        ]
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_403_FORBIDDEN,
-            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-        )
+async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
+    return [
+        ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
+        for chat in Chats.get_all_chats()
+    ]
 
 
 ############################

+ 9 - 20
backend/apps/web/routers/configs.py

@@ -10,7 +10,7 @@ import uuid
 
 from apps.web.models.users import Users
 
-from utils.utils import get_password_hash, get_current_user, create_token
+from utils.utils import get_password_hash, get_current_user, get_admin_user, create_token
 from utils.misc import get_gravatar_url, validate_email_format
 from constants import ERROR_MESSAGES
 
@@ -37,30 +37,19 @@ class SetDefaultSuggestionsForm(BaseModel):
 
 @router.post("/default/models", response_model=str)
 async def set_global_default_models(
-    request: Request, form_data: SetDefaultModelsForm, user=Depends(get_current_user)
+    request: Request, form_data: SetDefaultModelsForm, user=Depends(get_admin_user)
 ):
-    if user.role == "admin":
-        request.app.state.DEFAULT_MODELS = form_data.models
-        return request.app.state.DEFAULT_MODELS
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_403_FORBIDDEN,
-            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-        )
+    request.app.state.DEFAULT_MODELS = form_data.models
+    return request.app.state.DEFAULT_MODELS
+
 
 
 @router.post("/default/suggestions", response_model=List[PromptSuggestion])
 async def set_global_default_suggestions(
     request: Request,
     form_data: SetDefaultSuggestionsForm,
-    user=Depends(get_current_user),
+    user=Depends(get_admin_user),
 ):
-    if user.role == "admin":
-        data = form_data.model_dump()
-        request.app.state.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"]
-        return request.app.state.DEFAULT_PROMPT_SUGGESTIONS
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_403_FORBIDDEN,
-            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-        )
+    data = form_data.model_dump()
+    request.app.state.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"]
+    return request.app.state.DEFAULT_PROMPT_SUGGESTIONS

+ 4 - 22
backend/apps/web/routers/documents.py

@@ -14,7 +14,7 @@ from apps.web.models.documents import (
     DocumentResponse,
 )
 
-from utils.utils import get_current_user
+from utils.utils import get_current_user, get_admin_user
 from constants import ERROR_MESSAGES
 
 router = APIRouter()
@@ -44,13 +44,7 @@ async def get_documents(user=Depends(get_current_user)):
 
 
 @router.post("/create", response_model=Optional[DocumentResponse])
-async def create_new_doc(form_data: DocumentForm, user=Depends(get_current_user)):
-    if user.role != "admin":
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-        )
-
+async def create_new_doc(form_data: DocumentForm, user=Depends(get_admin_user)):
     doc = Documents.get_doc_by_name(form_data.name)
     if doc == None:
         doc = Documents.insert_new_doc(user.id, form_data)
@@ -132,14 +126,8 @@ async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_u
 
 @router.post("/name/{name}/update", response_model=Optional[DocumentResponse])
 async def update_doc_by_name(
-    name: str, form_data: DocumentUpdateForm, user=Depends(get_current_user)
+    name: str, form_data: DocumentUpdateForm, user=Depends(get_admin_user)
 ):
-    if user.role != "admin":
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-        )
-
     doc = Documents.update_doc_by_name(name, form_data)
     if doc:
         return DocumentResponse(
@@ -161,12 +149,6 @@ async def update_doc_by_name(
 
 
 @router.delete("/name/{name}/delete", response_model=bool)
-async def delete_doc_by_name(name: str, user=Depends(get_current_user)):
-    if user.role != "admin":
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-        )
-
+async def delete_doc_by_name(name: str, user=Depends(get_admin_user)):
     result = Documents.delete_doc_by_name(name)
     return result

+ 4 - 21
backend/apps/web/routers/modelfiles.py

@@ -13,7 +13,7 @@ from apps.web.models.modelfiles import (
     ModelfileResponse,
 )
 
-from utils.utils import get_current_user
+from utils.utils import get_current_user, get_admin_user
 from constants import ERROR_MESSAGES
 
 router = APIRouter()
@@ -37,13 +37,7 @@ async def get_modelfiles(skip: int = 0,
 
 @router.post("/create", response_model=Optional[ModelfileResponse])
 async def create_new_modelfile(form_data: ModelfileForm,
-                               user=Depends(get_current_user)):
-    if user.role != "admin":
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-        )
-
+                               user=Depends(get_admin_user)):
     modelfile = Modelfiles.insert_new_modelfile(user.id, form_data)
 
     if modelfile:
@@ -91,12 +85,7 @@ async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm,
 
 @router.post("/update", response_model=Optional[ModelfileResponse])
 async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm,
-                                       user=Depends(get_current_user)):
-    if user.role != "admin":
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-        )
+                                       user=Depends(get_admin_user)):
     modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
     if modelfile:
         updated_modelfile = {
@@ -127,12 +116,6 @@ async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm,
 
 @router.delete("/delete", response_model=bool)
 async def delete_modelfile_by_tag_name(form_data: ModelfileTagNameForm,
-                                       user=Depends(get_current_user)):
-    if user.role != "admin":
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-        )
-
+                                       user=Depends(get_admin_user)):
     result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name)
     return result

+ 9 - 29
backend/apps/web/routers/prompts.py

@@ -8,7 +8,7 @@ import json
 
 from apps.web.models.prompts import Prompts, PromptForm, PromptModel
 
-from utils.utils import get_current_user
+from utils.utils import get_current_user, get_admin_user
 from constants import ERROR_MESSAGES
 
 router = APIRouter()
@@ -29,29 +29,21 @@ async def get_prompts(user=Depends(get_current_user)):
 
 
 @router.post("/create", response_model=Optional[PromptModel])
-async def create_new_prompt(form_data: PromptForm, user=Depends(get_current_user)):
-    if user.role != "admin":
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-        )
-
+async def create_new_prompt(form_data: PromptForm, user=Depends(get_admin_user)):
     prompt = Prompts.get_prompt_by_command(form_data.command)
     if prompt == None:
         prompt = Prompts.insert_new_prompt(user.id, form_data)
 
         if prompt:
             return prompt
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail=ERROR_MESSAGES.DEFAULT(),
-            )
-    else:
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
-            detail=ERROR_MESSAGES.COMMAND_TAKEN,
+            detail=ERROR_MESSAGES.DEFAULT(),
         )
+    raise HTTPException(
+        status_code=status.HTTP_400_BAD_REQUEST,
+        detail=ERROR_MESSAGES.COMMAND_TAKEN,
+    )
 
 
 ############################
@@ -79,14 +71,8 @@ async def get_prompt_by_command(command: str, user=Depends(get_current_user)):
 
 @router.post("/command/{command}/update", response_model=Optional[PromptModel])
 async def update_prompt_by_command(
-    command: str, form_data: PromptForm, user=Depends(get_current_user)
+    command: str, form_data: PromptForm, user=Depends(get_admin_user)
 ):
-    if user.role != "admin":
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-        )
-
     prompt = Prompts.update_prompt_by_command(f"/{command}", form_data)
     if prompt:
         return prompt
@@ -103,12 +89,6 @@ async def update_prompt_by_command(
 
 
 @router.delete("/command/{command}/delete", response_model=bool)
-async def delete_prompt_by_command(command: str, user=Depends(get_current_user)):
-    if user.role != "admin":
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-        )
-
+async def delete_prompt_by_command(command: str, user=Depends(get_admin_user)):
     result = Prompts.delete_prompt_by_command(f"/{command}")
     return result

+ 30 - 53
backend/apps/web/routers/users.py

@@ -11,7 +11,7 @@ import uuid
 from apps.web.models.users import UserModel, UserUpdateForm, UserRoleUpdateForm, Users
 from apps.web.models.auths import Auths
 
-from utils.utils import get_current_user, get_password_hash
+from utils.utils import get_current_user, get_password_hash, get_admin_user
 from constants import ERROR_MESSAGES
 
 router = APIRouter()
@@ -22,12 +22,7 @@ router = APIRouter()
 
 
 @router.get("/", response_model=List[UserModel])
-async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_current_user)):
-    if user.role != "admin":
-        raise HTTPException(
-            status_code=status.HTTP_403_FORBIDDEN,
-            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-        )
+async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user)):
     return Users.get_users(skip, limit)
 
 
@@ -38,21 +33,15 @@ async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_current_use
 
 @router.post("/update/role", response_model=Optional[UserModel])
 async def update_user_role(
-    form_data: UserRoleUpdateForm, user=Depends(get_current_user)
+    form_data: UserRoleUpdateForm, user=Depends(get_admin_user)
 ):
-    if user.role != "admin":
-        raise HTTPException(
-            status_code=status.HTTP_403_FORBIDDEN,
-            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-        )
-
     if user.id != form_data.id:
         return Users.update_user_role_by_id(form_data.id, form_data.role)
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_403_FORBIDDEN,
-            detail=ERROR_MESSAGES.ACTION_PROHIBITED,
-        )
+
+    raise HTTPException(
+        status_code=status.HTTP_403_FORBIDDEN,
+        detail=ERROR_MESSAGES.ACTION_PROHIBITED,
+    )
 
 
 ############################
@@ -62,14 +51,8 @@ async def update_user_role(
 
 @router.post("/{user_id}/update", response_model=Optional[UserModel])
 async def update_user_by_id(
-    user_id: str, form_data: UserUpdateForm, session_user=Depends(get_current_user)
+    user_id: str, form_data: UserUpdateForm, session_user=Depends(get_admin_user)
 ):
-    if session_user.role != "admin":
-        raise HTTPException(
-            status_code=status.HTTP_403_FORBIDDEN,
-            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-        )
-
     user = Users.get_user_by_id(user_id)
 
     if user:
@@ -98,18 +81,17 @@ async def update_user_by_id(
 
         if updated_user:
             return updated_user
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                detail=ERROR_MESSAGES.DEFAULT(),
-            )
 
-    else:
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
-            detail=ERROR_MESSAGES.USER_NOT_FOUND,
+            detail=ERROR_MESSAGES.DEFAULT(),
         )
 
+    raise HTTPException(
+        status_code=status.HTTP_400_BAD_REQUEST,
+        detail=ERROR_MESSAGES.USER_NOT_FOUND,
+    )
+
 
 ############################
 # DeleteUserById
@@ -117,25 +99,20 @@ async def update_user_by_id(
 
 
 @router.delete("/{user_id}", response_model=bool)
-async def delete_user_by_id(user_id: str, user=Depends(get_current_user)):
-    if user.role == "admin":
-        if user.id != user_id:
-            result = Auths.delete_auth_by_id(user_id)
-
-            if result:
-                return True
-            else:
-                raise HTTPException(
-                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-                    detail=ERROR_MESSAGES.DELETE_USER_ERROR,
-                )
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_403_FORBIDDEN,
-                detail=ERROR_MESSAGES.ACTION_PROHIBITED,
-            )
-    else:
+async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
+    if user.id != user_id:
+        result = Auths.delete_auth_by_id(user_id)
+
+        if result:
+            return True
+
         raise HTTPException(
-            status_code=status.HTTP_403_FORBIDDEN,
-            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=ERROR_MESSAGES.DELETE_USER_ERROR,
         )
+
+    raise HTTPException(
+        status_code=status.HTTP_403_FORBIDDEN,
+        detail=ERROR_MESSAGES.ACTION_PROHIBITED,
+    )
+

+ 16 - 0
backend/utils/utils.py

@@ -73,3 +73,19 @@ def get_current_user(auth_token: HTTPAuthorizationCredentials = Depends(bearer_s
             status_code=status.HTTP_401_UNAUTHORIZED,
             detail=ERROR_MESSAGES.UNAUTHORIZED,
         )
+
+
+def get_verified_user(user: Users = Depends(get_current_user)):
+    if user.role not in {"user", "admin"}:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )
+
+
+def get_admin_user(user: Users = Depends(get_current_user)):
+    if user.role != "admin":
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )