Browse Source

Merge pull request #682 from explorigin/simplify-endpoint-code

Simplify endpoint role checking
Timothy Jaeryang Baek 1 year ago
parent
commit
9f3346a6ec

+ 9 - 15
backend/apps/ollama/main.py

@@ -1,4 +1,4 @@
-from fastapi import FastAPI, Request, Response, HTTPException, Depends
+from fastapi import FastAPI, Request, Response, HTTPException, Depends, status
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
 from fastapi.responses import StreamingResponse
 from fastapi.concurrency import run_in_threadpool
 from fastapi.concurrency import run_in_threadpool
@@ -10,7 +10,7 @@ from pydantic import BaseModel
 
 
 from apps.web.models.users import Users
 from apps.web.models.users import Users
 from constants import ERROR_MESSAGES
 from constants import ERROR_MESSAGES
-from utils.utils import decode_token, get_current_user
+from utils.utils import decode_token, get_current_user, get_admin_user
 from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
 from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
 
 
 app = FastAPI()
 app = FastAPI()
@@ -31,11 +31,8 @@ REQUEST_POOL = []
 
 
 
 
 @app.get("/url")
 @app.get("/url")
-async def get_ollama_api_url(user=Depends(get_current_user)):
-    if user and user.role == "admin":
-        return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
-    else:
-        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+async def get_ollama_api_url(user=Depends(get_admin_user)):
+    return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
 
 
 
 
 class UrlUpdateForm(BaseModel):
 class UrlUpdateForm(BaseModel):
@@ -44,13 +41,10 @@ class UrlUpdateForm(BaseModel):
 
 
 @app.post("/url/update")
 @app.post("/url/update")
 async def update_ollama_api_url(
 async def update_ollama_api_url(
-    form_data: UrlUpdateForm, user=Depends(get_current_user)
+    form_data: UrlUpdateForm, user=Depends(get_admin_user)
 ):
 ):
-    if user and user.role == "admin":
-        app.state.OLLAMA_API_BASE_URL = form_data.url
-        return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
-    else:
-        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+    app.state.OLLAMA_API_BASE_URL = form_data.url
+    return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
 
 
 
 
 @app.get("/cancel/{request_id}")
 @app.get("/cancel/{request_id}")
@@ -74,10 +68,10 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)):
         if path in ["pull", "delete", "push", "copy", "create"]:
         if path in ["pull", "delete", "push", "copy", "create"]:
             if user.role != "admin":
             if user.role != "admin":
                 raise HTTPException(
                 raise HTTPException(
-                    status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
+                    status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
                 )
                 )
     else:
     else:
-        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
 
 
     headers.pop("host", None)
     headers.pop("host", None)
     headers.pop("authorization", None)
     headers.pop("authorization", None)

+ 14 - 29
backend/apps/openai/main.py

@@ -9,7 +9,7 @@ from pydantic import BaseModel
 
 
 from apps.web.models.users import Users
 from apps.web.models.users import Users
 from constants import ERROR_MESSAGES
 from constants import ERROR_MESSAGES
-from utils.utils import decode_token, get_current_user
+from utils.utils import decode_token, get_current_user, get_verified_user, get_admin_user
 from config import OPENAI_API_BASE_URL, OPENAI_API_KEY, CACHE_DIR
 from config import OPENAI_API_BASE_URL, OPENAI_API_KEY, CACHE_DIR
 
 
 import hashlib
 import hashlib
@@ -37,45 +37,32 @@ class KeyUpdateForm(BaseModel):
 
 
 
 
 @app.get("/url")
 @app.get("/url")
-async def get_openai_url(user=Depends(get_current_user)):
-    if user and user.role == "admin":
-        return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}
-    else:
-        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+async def get_openai_url(user=Depends(get_admin_user)):
+    return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}
 
 
 
 
 @app.post("/url/update")
 @app.post("/url/update")
-async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_current_user)):
-    if user and user.role == "admin":
-        app.state.OPENAI_API_BASE_URL = form_data.url
-        return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}
-    else:
-        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
+    app.state.OPENAI_API_BASE_URL = form_data.url
+    return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}
+
 
 
 
 
 @app.get("/key")
 @app.get("/key")
-async def get_openai_key(user=Depends(get_current_user)):
-    if user and user.role == "admin":
-        return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
-    else:
-        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+async def get_openai_key(user=Depends(get_admin_user)):
+    return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
 
 
 
 
 @app.post("/key/update")
 @app.post("/key/update")
-async def update_openai_key(form_data: KeyUpdateForm, user=Depends(get_current_user)):
-    if user and user.role == "admin":
-        app.state.OPENAI_API_KEY = form_data.key
-        return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
-    else:
-        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+async def update_openai_key(form_data: KeyUpdateForm, user=Depends(get_admin_user)):
+    app.state.OPENAI_API_KEY = form_data.key
+    return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
 
 
 
 
 @app.post("/audio/speech")
 @app.post("/audio/speech")
-async def speech(request: Request, user=Depends(get_current_user)):
+async def speech(request: Request, user=Depends(get_verified_user)):
     target_url = f"{app.state.OPENAI_API_BASE_URL}/audio/speech"
     target_url = f"{app.state.OPENAI_API_BASE_URL}/audio/speech"
 
 
-    if user.role not in ["user", "admin"]:
-        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
     if app.state.OPENAI_API_KEY == "":
     if app.state.OPENAI_API_KEY == "":
         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
 
 
@@ -133,12 +120,10 @@ async def speech(request: Request, user=Depends(get_current_user)):
 
 
 
 
 @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
 @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
-async def proxy(path: str, request: Request, user=Depends(get_current_user)):
+async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
     target_url = f"{app.state.OPENAI_API_BASE_URL}/{path}"
     target_url = f"{app.state.OPENAI_API_BASE_URL}/{path}"
     print(target_url, app.state.OPENAI_API_KEY)
     print(target_url, app.state.OPENAI_API_KEY)
 
 
-    if user.role not in ["user", "admin"]:
-        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
     if app.state.OPENAI_API_KEY == "":
     if app.state.OPENAI_API_KEY == "":
         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
 
 

+ 18 - 30
backend/apps/rag/main.py

@@ -39,7 +39,7 @@ import uuid
 import time
 import time
 
 
 from utils.misc import calculate_sha256, calculate_sha256_string
 from utils.misc import calculate_sha256, calculate_sha256_string
-from utils.utils import get_current_user
+from utils.utils import get_current_user, get_admin_user
 from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
 from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
 from constants import ERROR_MESSAGES
 from constants import ERROR_MESSAGES
 
 
@@ -354,38 +354,26 @@ def store_doc(
 
 
 
 
 @app.get("/reset/db")
 @app.get("/reset/db")
-def reset_vector_db(user=Depends(get_current_user)):
-    if user.role == "admin":
-        CHROMA_CLIENT.reset()
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_403_FORBIDDEN,
-            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-        )
+def reset_vector_db(user=Depends(get_admin_user)):
+    CHROMA_CLIENT.reset()
 
 
 
 
 @app.get("/reset")
 @app.get("/reset")
-def reset(user=Depends(get_current_user)) -> bool:
-    if user.role == "admin":
-        folder = f"{UPLOAD_DIR}"
-        for filename in os.listdir(folder):
-            file_path = os.path.join(folder, filename)
-            try:
-                if os.path.isfile(file_path) or os.path.islink(file_path):
-                    os.unlink(file_path)
-                elif os.path.isdir(file_path):
-                    shutil.rmtree(file_path)
-            except Exception as e:
-                print("Failed to delete %s. Reason: %s" % (file_path, e))
-
+def reset(user=Depends(get_admin_user)) -> bool:
+    folder = f"{UPLOAD_DIR}"
+    for filename in os.listdir(folder):
+        file_path = os.path.join(folder, filename)
         try:
         try:
-            CHROMA_CLIENT.reset()
+            if os.path.isfile(file_path) or os.path.islink(file_path):
+                os.unlink(file_path)
+            elif os.path.isdir(file_path):
+                shutil.rmtree(file_path)
         except Exception as e:
         except Exception as e:
-            print(e)
+            print("Failed to delete %s. Reason: %s" % (file_path, e))
 
 
-        return True
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_403_FORBIDDEN,
-            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-        )
+    try:
+        CHROMA_CLIENT.reset()
+    except Exception as e:
+        print(e)
+
+    return True

+ 9 - 21
backend/apps/web/routers/auths.py

@@ -3,7 +3,7 @@ from fastapi import Depends, FastAPI, HTTPException, status
 from datetime import datetime, timedelta
 from datetime import datetime, timedelta
 from typing import List, Union
 from typing import List, Union
 
 
-from fastapi import APIRouter
+from fastapi import APIRouter, status
 from pydantic import BaseModel
 from pydantic import BaseModel
 import time
 import time
 import uuid
 import uuid
@@ -19,7 +19,7 @@ from apps.web.models.auths import (
 )
 )
 from apps.web.models.users import Users
 from apps.web.models.users import Users
 
 
-from utils.utils import get_password_hash, get_current_user, create_token
+from utils.utils import get_password_hash, get_current_user, get_admin_user, create_token
 from utils.misc import get_gravatar_url, validate_email_format
 from utils.misc import get_gravatar_url, validate_email_format
 from constants import ERROR_MESSAGES
 from constants import ERROR_MESSAGES
 
 
@@ -116,10 +116,10 @@ async def signin(form_data: SigninForm):
 @router.post("/signup", response_model=SigninResponse)
 @router.post("/signup", response_model=SigninResponse)
 async def signup(request: Request, form_data: SignupForm):
 async def signup(request: Request, form_data: SignupForm):
     if not request.app.state.ENABLE_SIGNUP:
     if not request.app.state.ENABLE_SIGNUP:
-        raise HTTPException(400, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+        raise HTTPException(status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
 
 
     if not validate_email_format(form_data.email.lower()):
     if not validate_email_format(form_data.email.lower()):
-        raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT)
+        raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT)
 
 
     if Users.get_user_by_email(form_data.email.lower()):
     if Users.get_user_by_email(form_data.email.lower()):
         raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
         raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
@@ -156,23 +156,11 @@ async def signup(request: Request, form_data: SignupForm):
 
 
 
 
 @router.get("/signup/enabled", response_model=bool)
 @router.get("/signup/enabled", response_model=bool)
-async def get_sign_up_status(request: Request, user=Depends(get_current_user)):
-    if user.role == "admin":
-        return request.app.state.ENABLE_SIGNUP
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_403_FORBIDDEN,
-            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-        )
+async def get_sign_up_status(request: Request, user=Depends(get_admin_user)):
+    return request.app.state.ENABLE_SIGNUP
 
 
 
 
 @router.get("/signup/enabled/toggle", response_model=bool)
 @router.get("/signup/enabled/toggle", response_model=bool)
-async def toggle_sign_up(request: Request, user=Depends(get_current_user)):
-    if user.role == "admin":
-        request.app.state.ENABLE_SIGNUP = not request.app.state.ENABLE_SIGNUP
-        return request.app.state.ENABLE_SIGNUP
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_403_FORBIDDEN,
-            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-        )
+async def toggle_sign_up(request: Request, user=Depends(get_admin_user)):
+    request.app.state.ENABLE_SIGNUP = not request.app.state.ENABLE_SIGNUP
+    return request.app.state.ENABLE_SIGNUP

+ 6 - 12
backend/apps/web/routers/chats.py

@@ -1,7 +1,7 @@
 from fastapi import Depends, Request, HTTPException, status
 from fastapi import Depends, Request, HTTPException, status
 from datetime import datetime, timedelta
 from datetime import datetime, timedelta
 from typing import List, Union, Optional
 from typing import List, Union, Optional
-from utils.utils import get_current_user
+from utils.utils import get_current_user, get_admin_user
 from fastapi import APIRouter
 from fastapi import APIRouter
 from pydantic import BaseModel
 from pydantic import BaseModel
 import json
 import json
@@ -60,17 +60,11 @@ async def get_all_user_chats(user=Depends(get_current_user)):
 
 
 
 
 @router.get("/all/db", response_model=List[ChatResponse])
 @router.get("/all/db", response_model=List[ChatResponse])
-async def get_all_user_chats_in_db(user=Depends(get_current_user)):
-    if user.role == "admin":
-        return [
-            ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
-            for chat in Chats.get_all_chats()
-        ]
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_403_FORBIDDEN,
-            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-        )
+async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
+    return [
+        ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
+        for chat in Chats.get_all_chats()
+    ]
 
 
 
 
 ############################
 ############################

+ 9 - 20
backend/apps/web/routers/configs.py

@@ -10,7 +10,7 @@ import uuid
 
 
 from apps.web.models.users import Users
 from apps.web.models.users import Users
 
 
-from utils.utils import get_password_hash, get_current_user, create_token
+from utils.utils import get_password_hash, get_current_user, get_admin_user, create_token
 from utils.misc import get_gravatar_url, validate_email_format
 from utils.misc import get_gravatar_url, validate_email_format
 from constants import ERROR_MESSAGES
 from constants import ERROR_MESSAGES
 
 
@@ -37,30 +37,19 @@ class SetDefaultSuggestionsForm(BaseModel):
 
 
 @router.post("/default/models", response_model=str)
 @router.post("/default/models", response_model=str)
 async def set_global_default_models(
 async def set_global_default_models(
-    request: Request, form_data: SetDefaultModelsForm, user=Depends(get_current_user)
+    request: Request, form_data: SetDefaultModelsForm, user=Depends(get_admin_user)
 ):
 ):
-    if user.role == "admin":
-        request.app.state.DEFAULT_MODELS = form_data.models
-        return request.app.state.DEFAULT_MODELS
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_403_FORBIDDEN,
-            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-        )
+    request.app.state.DEFAULT_MODELS = form_data.models
+    return request.app.state.DEFAULT_MODELS
+
 
 
 
 
 @router.post("/default/suggestions", response_model=List[PromptSuggestion])
 @router.post("/default/suggestions", response_model=List[PromptSuggestion])
 async def set_global_default_suggestions(
 async def set_global_default_suggestions(
     request: Request,
     request: Request,
     form_data: SetDefaultSuggestionsForm,
     form_data: SetDefaultSuggestionsForm,
-    user=Depends(get_current_user),
+    user=Depends(get_admin_user),
 ):
 ):
-    if user.role == "admin":
-        data = form_data.model_dump()
-        request.app.state.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"]
-        return request.app.state.DEFAULT_PROMPT_SUGGESTIONS
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_403_FORBIDDEN,
-            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-        )
+    data = form_data.model_dump()
+    request.app.state.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"]
+    return request.app.state.DEFAULT_PROMPT_SUGGESTIONS

+ 4 - 22
backend/apps/web/routers/documents.py

@@ -14,7 +14,7 @@ from apps.web.models.documents import (
     DocumentResponse,
     DocumentResponse,
 )
 )
 
 
-from utils.utils import get_current_user
+from utils.utils import get_current_user, get_admin_user
 from constants import ERROR_MESSAGES
 from constants import ERROR_MESSAGES
 
 
 router = APIRouter()
 router = APIRouter()
@@ -44,13 +44,7 @@ async def get_documents(user=Depends(get_current_user)):
 
 
 
 
 @router.post("/create", response_model=Optional[DocumentResponse])
 @router.post("/create", response_model=Optional[DocumentResponse])
-async def create_new_doc(form_data: DocumentForm, user=Depends(get_current_user)):
-    if user.role != "admin":
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-        )
-
+async def create_new_doc(form_data: DocumentForm, user=Depends(get_admin_user)):
     doc = Documents.get_doc_by_name(form_data.name)
     doc = Documents.get_doc_by_name(form_data.name)
     if doc == None:
     if doc == None:
         doc = Documents.insert_new_doc(user.id, form_data)
         doc = Documents.insert_new_doc(user.id, form_data)
@@ -132,14 +126,8 @@ async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_u
 
 
 @router.post("/name/{name}/update", response_model=Optional[DocumentResponse])
 @router.post("/name/{name}/update", response_model=Optional[DocumentResponse])
 async def update_doc_by_name(
 async def update_doc_by_name(
-    name: str, form_data: DocumentUpdateForm, user=Depends(get_current_user)
+    name: str, form_data: DocumentUpdateForm, user=Depends(get_admin_user)
 ):
 ):
-    if user.role != "admin":
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-        )
-
     doc = Documents.update_doc_by_name(name, form_data)
     doc = Documents.update_doc_by_name(name, form_data)
     if doc:
     if doc:
         return DocumentResponse(
         return DocumentResponse(
@@ -161,12 +149,6 @@ async def update_doc_by_name(
 
 
 
 
 @router.delete("/name/{name}/delete", response_model=bool)
 @router.delete("/name/{name}/delete", response_model=bool)
-async def delete_doc_by_name(name: str, user=Depends(get_current_user)):
-    if user.role != "admin":
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-        )
-
+async def delete_doc_by_name(name: str, user=Depends(get_admin_user)):
     result = Documents.delete_doc_by_name(name)
     result = Documents.delete_doc_by_name(name)
     return result
     return result

+ 4 - 21
backend/apps/web/routers/modelfiles.py

@@ -13,7 +13,7 @@ from apps.web.models.modelfiles import (
     ModelfileResponse,
     ModelfileResponse,
 )
 )
 
 
-from utils.utils import get_current_user
+from utils.utils import get_current_user, get_admin_user
 from constants import ERROR_MESSAGES
 from constants import ERROR_MESSAGES
 
 
 router = APIRouter()
 router = APIRouter()
@@ -37,13 +37,7 @@ async def get_modelfiles(skip: int = 0,
 
 
 @router.post("/create", response_model=Optional[ModelfileResponse])
 @router.post("/create", response_model=Optional[ModelfileResponse])
 async def create_new_modelfile(form_data: ModelfileForm,
 async def create_new_modelfile(form_data: ModelfileForm,
-                               user=Depends(get_current_user)):
-    if user.role != "admin":
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-        )
-
+                               user=Depends(get_admin_user)):
     modelfile = Modelfiles.insert_new_modelfile(user.id, form_data)
     modelfile = Modelfiles.insert_new_modelfile(user.id, form_data)
 
 
     if modelfile:
     if modelfile:
@@ -91,12 +85,7 @@ async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm,
 
 
 @router.post("/update", response_model=Optional[ModelfileResponse])
 @router.post("/update", response_model=Optional[ModelfileResponse])
 async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm,
 async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm,
-                                       user=Depends(get_current_user)):
-    if user.role != "admin":
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-        )
+                                       user=Depends(get_admin_user)):
     modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
     modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
     if modelfile:
     if modelfile:
         updated_modelfile = {
         updated_modelfile = {
@@ -127,12 +116,6 @@ async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm,
 
 
 @router.delete("/delete", response_model=bool)
 @router.delete("/delete", response_model=bool)
 async def delete_modelfile_by_tag_name(form_data: ModelfileTagNameForm,
 async def delete_modelfile_by_tag_name(form_data: ModelfileTagNameForm,
-                                       user=Depends(get_current_user)):
-    if user.role != "admin":
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-        )
-
+                                       user=Depends(get_admin_user)):
     result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name)
     result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name)
     return result
     return result

+ 9 - 29
backend/apps/web/routers/prompts.py

@@ -8,7 +8,7 @@ import json
 
 
 from apps.web.models.prompts import Prompts, PromptForm, PromptModel
 from apps.web.models.prompts import Prompts, PromptForm, PromptModel
 
 
-from utils.utils import get_current_user
+from utils.utils import get_current_user, get_admin_user
 from constants import ERROR_MESSAGES
 from constants import ERROR_MESSAGES
 
 
 router = APIRouter()
 router = APIRouter()
@@ -29,29 +29,21 @@ async def get_prompts(user=Depends(get_current_user)):
 
 
 
 
 @router.post("/create", response_model=Optional[PromptModel])
 @router.post("/create", response_model=Optional[PromptModel])
-async def create_new_prompt(form_data: PromptForm, user=Depends(get_current_user)):
-    if user.role != "admin":
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-        )
-
+async def create_new_prompt(form_data: PromptForm, user=Depends(get_admin_user)):
     prompt = Prompts.get_prompt_by_command(form_data.command)
     prompt = Prompts.get_prompt_by_command(form_data.command)
     if prompt == None:
     if prompt == None:
         prompt = Prompts.insert_new_prompt(user.id, form_data)
         prompt = Prompts.insert_new_prompt(user.id, form_data)
 
 
         if prompt:
         if prompt:
             return prompt
             return prompt
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail=ERROR_MESSAGES.DEFAULT(),
-            )
-    else:
         raise HTTPException(
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             status_code=status.HTTP_400_BAD_REQUEST,
-            detail=ERROR_MESSAGES.COMMAND_TAKEN,
+            detail=ERROR_MESSAGES.DEFAULT(),
         )
         )
+    raise HTTPException(
+        status_code=status.HTTP_400_BAD_REQUEST,
+        detail=ERROR_MESSAGES.COMMAND_TAKEN,
+    )
 
 
 
 
 ############################
 ############################
@@ -79,14 +71,8 @@ async def get_prompt_by_command(command: str, user=Depends(get_current_user)):
 
 
 @router.post("/command/{command}/update", response_model=Optional[PromptModel])
 @router.post("/command/{command}/update", response_model=Optional[PromptModel])
 async def update_prompt_by_command(
 async def update_prompt_by_command(
-    command: str, form_data: PromptForm, user=Depends(get_current_user)
+    command: str, form_data: PromptForm, user=Depends(get_admin_user)
 ):
 ):
-    if user.role != "admin":
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-        )
-
     prompt = Prompts.update_prompt_by_command(f"/{command}", form_data)
     prompt = Prompts.update_prompt_by_command(f"/{command}", form_data)
     if prompt:
     if prompt:
         return prompt
         return prompt
@@ -103,12 +89,6 @@ async def update_prompt_by_command(
 
 
 
 
 @router.delete("/command/{command}/delete", response_model=bool)
 @router.delete("/command/{command}/delete", response_model=bool)
-async def delete_prompt_by_command(command: str, user=Depends(get_current_user)):
-    if user.role != "admin":
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-        )
-
+async def delete_prompt_by_command(command: str, user=Depends(get_admin_user)):
     result = Prompts.delete_prompt_by_command(f"/{command}")
     result = Prompts.delete_prompt_by_command(f"/{command}")
     return result
     return result

+ 30 - 53
backend/apps/web/routers/users.py

@@ -11,7 +11,7 @@ import uuid
 from apps.web.models.users import UserModel, UserUpdateForm, UserRoleUpdateForm, Users
 from apps.web.models.users import UserModel, UserUpdateForm, UserRoleUpdateForm, Users
 from apps.web.models.auths import Auths
 from apps.web.models.auths import Auths
 
 
-from utils.utils import get_current_user, get_password_hash
+from utils.utils import get_current_user, get_password_hash, get_admin_user
 from constants import ERROR_MESSAGES
 from constants import ERROR_MESSAGES
 
 
 router = APIRouter()
 router = APIRouter()
@@ -22,12 +22,7 @@ router = APIRouter()
 
 
 
 
 @router.get("/", response_model=List[UserModel])
 @router.get("/", response_model=List[UserModel])
-async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_current_user)):
-    if user.role != "admin":
-        raise HTTPException(
-            status_code=status.HTTP_403_FORBIDDEN,
-            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-        )
+async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user)):
     return Users.get_users(skip, limit)
     return Users.get_users(skip, limit)
 
 
 
 
@@ -38,21 +33,15 @@ async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_current_use
 
 
 @router.post("/update/role", response_model=Optional[UserModel])
 @router.post("/update/role", response_model=Optional[UserModel])
 async def update_user_role(
 async def update_user_role(
-    form_data: UserRoleUpdateForm, user=Depends(get_current_user)
+    form_data: UserRoleUpdateForm, user=Depends(get_admin_user)
 ):
 ):
-    if user.role != "admin":
-        raise HTTPException(
-            status_code=status.HTTP_403_FORBIDDEN,
-            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-        )
-
     if user.id != form_data.id:
     if user.id != form_data.id:
         return Users.update_user_role_by_id(form_data.id, form_data.role)
         return Users.update_user_role_by_id(form_data.id, form_data.role)
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_403_FORBIDDEN,
-            detail=ERROR_MESSAGES.ACTION_PROHIBITED,
-        )
+
+    raise HTTPException(
+        status_code=status.HTTP_403_FORBIDDEN,
+        detail=ERROR_MESSAGES.ACTION_PROHIBITED,
+    )
 
 
 
 
 ############################
 ############################
@@ -62,14 +51,8 @@ async def update_user_role(
 
 
 @router.post("/{user_id}/update", response_model=Optional[UserModel])
 @router.post("/{user_id}/update", response_model=Optional[UserModel])
 async def update_user_by_id(
 async def update_user_by_id(
-    user_id: str, form_data: UserUpdateForm, session_user=Depends(get_current_user)
+    user_id: str, form_data: UserUpdateForm, session_user=Depends(get_admin_user)
 ):
 ):
-    if session_user.role != "admin":
-        raise HTTPException(
-            status_code=status.HTTP_403_FORBIDDEN,
-            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-        )
-
     user = Users.get_user_by_id(user_id)
     user = Users.get_user_by_id(user_id)
 
 
     if user:
     if user:
@@ -98,18 +81,17 @@ async def update_user_by_id(
 
 
         if updated_user:
         if updated_user:
             return updated_user
             return updated_user
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                detail=ERROR_MESSAGES.DEFAULT(),
-            )
 
 
-    else:
         raise HTTPException(
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             status_code=status.HTTP_400_BAD_REQUEST,
-            detail=ERROR_MESSAGES.USER_NOT_FOUND,
+            detail=ERROR_MESSAGES.DEFAULT(),
         )
         )
 
 
+    raise HTTPException(
+        status_code=status.HTTP_400_BAD_REQUEST,
+        detail=ERROR_MESSAGES.USER_NOT_FOUND,
+    )
+
 
 
 ############################
 ############################
 # DeleteUserById
 # DeleteUserById
@@ -117,25 +99,20 @@ async def update_user_by_id(
 
 
 
 
 @router.delete("/{user_id}", response_model=bool)
 @router.delete("/{user_id}", response_model=bool)
-async def delete_user_by_id(user_id: str, user=Depends(get_current_user)):
-    if user.role == "admin":
-        if user.id != user_id:
-            result = Auths.delete_auth_by_id(user_id)
-
-            if result:
-                return True
-            else:
-                raise HTTPException(
-                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-                    detail=ERROR_MESSAGES.DELETE_USER_ERROR,
-                )
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_403_FORBIDDEN,
-                detail=ERROR_MESSAGES.ACTION_PROHIBITED,
-            )
-    else:
+async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
+    if user.id != user_id:
+        result = Auths.delete_auth_by_id(user_id)
+
+        if result:
+            return True
+
         raise HTTPException(
         raise HTTPException(
-            status_code=status.HTTP_403_FORBIDDEN,
-            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=ERROR_MESSAGES.DELETE_USER_ERROR,
         )
         )
+
+    raise HTTPException(
+        status_code=status.HTTP_403_FORBIDDEN,
+        detail=ERROR_MESSAGES.ACTION_PROHIBITED,
+    )
+

+ 16 - 0
backend/utils/utils.py

@@ -73,3 +73,19 @@ def get_current_user(auth_token: HTTPAuthorizationCredentials = Depends(bearer_s
             status_code=status.HTTP_401_UNAUTHORIZED,
             status_code=status.HTTP_401_UNAUTHORIZED,
             detail=ERROR_MESSAGES.UNAUTHORIZED,
             detail=ERROR_MESSAGES.UNAUTHORIZED,
         )
         )
+
+
+def get_verified_user(user: Users = Depends(get_current_user)):
+    if user.role not in {"user", "admin"}:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )
+
+
+def get_admin_user(user: Users = Depends(get_current_user)):
+    if user.role != "admin":
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )