浏览代码

Merge pull request #311 from anuraagdjain/refac/auth-middleware

feat(auth): add auth middleware
Timothy Jaeryang Baek 1 年之前
父节点
当前提交
c5386d05ab

+ 2 - 1
backend/.gitignore

@@ -4,4 +4,5 @@ _old
 uploads
 uploads
 .ipynb_checkpoints
 .ipynb_checkpoints
 *.db
 *.db
-_test
+_test
+Pipfile

+ 7 - 3
backend/apps/ollama/main.py

@@ -8,7 +8,7 @@ import json
 
 
 from apps.web.models.users import Users
 from apps.web.models.users import Users
 from constants import ERROR_MESSAGES
 from constants import ERROR_MESSAGES
-from utils.utils import extract_token_from_auth_header
+from utils.utils import decode_token
 from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
 from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
 
 
 app = Flask(__name__)
 app = Flask(__name__)
@@ -34,8 +34,12 @@ def proxy(path):
     # Basic RBAC support
     # Basic RBAC support
     if WEBUI_AUTH:
     if WEBUI_AUTH:
         if "Authorization" in headers:
         if "Authorization" in headers:
-            token = extract_token_from_auth_header(headers["Authorization"])
-            user = Users.get_user_by_token(token)
+            _, credentials = headers["Authorization"].split()
+            token_data = decode_token(credentials)
+            if token_data is None or "email" not in token_data:
+                return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401
+
+            user = Users.get_user_by_email(token_data["email"])
             if user:
             if user:
                 # Only user and admin roles can access
                 # Only user and admin roles can access
                 if user.role in ["user", "admin"]:
                 if user.role in ["user", "admin"]:

+ 3 - 5
backend/apps/web/main.py

@@ -1,6 +1,6 @@
-from fastapi import FastAPI, Request, Depends, HTTPException
+from fastapi import FastAPI, Depends
+from fastapi.routing import APIRoute
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.cors import CORSMiddleware
-
 from apps.web.routers import auths, users, chats, modelfiles, utils
 from apps.web.routers import auths, users, chats, modelfiles, utils
 from config import WEBUI_VERSION, WEBUI_AUTH
 from config import WEBUI_VERSION, WEBUI_AUTH
 
 
@@ -16,13 +16,11 @@ app.add_middleware(
     allow_headers=["*"],
     allow_headers=["*"],
 )
 )
 
 
-
 app.include_router(auths.router, prefix="/auths", tags=["auths"])
 app.include_router(auths.router, prefix="/auths", tags=["auths"])
+
 app.include_router(users.router, prefix="/users", tags=["users"])
 app.include_router(users.router, prefix="/users", tags=["users"])
 app.include_router(chats.router, prefix="/chats", tags=["chats"])
 app.include_router(chats.router, prefix="/chats", tags=["chats"])
 app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"])
 app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"])
-
-
 app.include_router(utils.router, prefix="/utils", tags=["utils"])
 app.include_router(utils.router, prefix="/utils", tags=["utils"])
 
 
 
 

+ 0 - 10
backend/apps/web/models/users.py

@@ -3,8 +3,6 @@ from peewee import *
 from playhouse.shortcuts import model_to_dict
 from playhouse.shortcuts import model_to_dict
 from typing import List, Union, Optional
 from typing import List, Union, Optional
 import time
 import time
-
-from utils.utils import decode_token
 from utils.misc import get_gravatar_url
 from utils.misc import get_gravatar_url
 
 
 from apps.web.internal.db import DB
 from apps.web.internal.db import DB
@@ -85,14 +83,6 @@ class UsersTable:
         except:
         except:
             return None
             return None
 
 
-    def get_user_by_token(self, token: str) -> Optional[UserModel]:
-        data = decode_token(token)
-
-        if data != None and "email" in data:
-            return self.get_user_by_email(data["email"])
-        else:
-            return None
-
     def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]:
     def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]:
         return [
         return [
             UserModel(**model_to_dict(user))
             UserModel(**model_to_dict(user))

+ 12 - 25
backend/apps/web/routers/auths.py

@@ -19,11 +19,7 @@ from apps.web.models.auths import (
 from apps.web.models.users import Users
 from apps.web.models.users import Users
 
 
 
 
-from utils.utils import (
-    get_password_hash,
-    bearer_scheme,
-    create_token,
-)
+from utils.utils import get_password_hash, get_current_user, create_token
 from utils.misc import get_gravatar_url
 from utils.misc import get_gravatar_url
 from constants import ERROR_MESSAGES
 from constants import ERROR_MESSAGES
 
 
@@ -36,22 +32,14 @@ router = APIRouter()
 
 
 
 
 @router.get("/", response_model=UserResponse)
 @router.get("/", response_model=UserResponse)
-async def get_session_user(cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-    if user:
-        return {
-            "id": user.id,
-            "email": user.email,
-            "name": user.name,
-            "role": user.role,
-            "profile_image_url": user.profile_image_url,
-        }
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
-        )
+async def get_session_user(user=Depends(get_current_user)):
+    return {
+        "id": user.id,
+        "email": user.email,
+        "name": user.name,
+        "role": user.role,
+        "profile_image_url": user.profile_image_url,
+    }
 
 
 
 
 ############################
 ############################
@@ -60,10 +48,9 @@ async def get_session_user(cred=Depends(bearer_scheme)):
 
 
 
 
 @router.post("/update/password", response_model=bool)
 @router.post("/update/password", response_model=bool)
-async def update_password(form_data: UpdatePasswordForm, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    session_user = Users.get_user_by_token(token)
-
+async def update_password(
+    form_data: UpdatePasswordForm, session_user=Depends(get_current_user)
+):
     if session_user:
     if session_user:
         user = Auths.authenticate_user(session_user.email, form_data.password)
         user = Auths.authenticate_user(session_user.email, form_data.password)
 
 

+ 36 - 97
backend/apps/web/routers/chats.py

@@ -1,8 +1,7 @@
-from fastapi import Response
-from fastapi import Depends, FastAPI, HTTPException, status
+from fastapi import Depends, Request, HTTPException, status
 from datetime import datetime, timedelta
 from datetime import datetime, timedelta
 from typing import List, Union, Optional
 from typing import List, Union, Optional
-
+from utils.utils import get_current_user
 from fastapi import APIRouter
 from fastapi import APIRouter
 from pydantic import BaseModel
 from pydantic import BaseModel
 import json
 import json
@@ -30,17 +29,10 @@ router = APIRouter()
 
 
 
 
 @router.get("/", response_model=List[ChatTitleIdResponse])
 @router.get("/", response_model=List[ChatTitleIdResponse])
-async def get_user_chats(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        return Chats.get_chat_lists_by_user_id(user.id, skip, limit)
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
-        )
+async def get_user_chats(
+    user=Depends(get_current_user), skip: int = 0, limit: int = 50
+):
+    return Chats.get_chat_lists_by_user_id(user.id, skip, limit)
 
 
 
 
 ############################
 ############################
@@ -49,20 +41,11 @@ async def get_user_chats(skip: int = 0, limit: int = 50, cred=Depends(bearer_sch
 
 
 
 
 @router.get("/all", response_model=List[ChatResponse])
 @router.get("/all", response_model=List[ChatResponse])
-async def get_all_user_chats(cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        return [
-            ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
-            for chat in Chats.get_all_chats_by_user_id(user.id)
-        ]
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
-        )
+async def get_all_user_chats(user=Depends(get_current_user)):
+    return [
+        ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
+        for chat in Chats.get_all_chats_by_user_id(user.id)
+    ]
 
 
 
 
 ############################
 ############################
@@ -71,18 +54,9 @@ async def get_all_user_chats(cred=Depends(bearer_scheme)):
 
 
 
 
 @router.post("/new", response_model=Optional[ChatResponse])
 @router.post("/new", response_model=Optional[ChatResponse])
-async def create_new_chat(form_data: ChatForm, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        chat = Chats.insert_new_chat(user.id, form_data)
-        return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
-        )
+async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
+    chat = Chats.insert_new_chat(user.id, form_data)
+    return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
 
 
 
 
 ############################
 ############################
@@ -91,24 +65,14 @@ async def create_new_chat(form_data: ChatForm, cred=Depends(bearer_scheme)):
 
 
 
 
 @router.get("/{id}", response_model=Optional[ChatResponse])
 @router.get("/{id}", response_model=Optional[ChatResponse])
-async def get_chat_by_id(id: str, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        chat = Chats.get_chat_by_id_and_user_id(id, user.id)
-
-        if chat:
-            return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail=ERROR_MESSAGES.NOT_FOUND,
-            )
+async def get_chat_by_id(id: str, user=Depends(get_current_user)):
+    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+
+    if chat:
+        return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
     else:
     else:
         raise HTTPException(
         raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
+            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
         )
         )
 
 
 
 
@@ -118,26 +82,19 @@ async def get_chat_by_id(id: str, cred=Depends(bearer_scheme)):
 
 
 
 
 @router.post("/{id}", response_model=Optional[ChatResponse])
 @router.post("/{id}", response_model=Optional[ChatResponse])
-async def update_chat_by_id(id: str, form_data: ChatForm, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        chat = Chats.get_chat_by_id_and_user_id(id, user.id)
-        if chat:
-            updated_chat = {**json.loads(chat.chat), **form_data.chat}
-
-            chat = Chats.update_chat_by_id(id, updated_chat)
-            return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-            )
+async def update_chat_by_id(
+    id: str, form_data: ChatForm, user=Depends(get_current_user)
+):
+    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+    if chat:
+        updated_chat = {**json.loads(chat.chat), **form_data.chat}
+
+        chat = Chats.update_chat_by_id(id, updated_chat)
+        return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
     else:
     else:
         raise HTTPException(
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
             status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
         )
         )
 
 
 
 
@@ -147,18 +104,9 @@ async def update_chat_by_id(id: str, form_data: ChatForm, cred=Depends(bearer_sc
 
 
 
 
 @router.delete("/{id}", response_model=bool)
 @router.delete("/{id}", response_model=bool)
-async def delete_chat_by_id(id: str, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        result = Chats.delete_chat_by_id_and_user_id(id, user.id)
-        return result
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
-        )
+async def delete_chat_by_id(id: str, user=Depends(get_current_user)):
+    result = Chats.delete_chat_by_id_and_user_id(id, user.id)
+    return result
 
 
 
 
 ############################
 ############################
@@ -167,15 +115,6 @@ async def delete_chat_by_id(id: str, cred=Depends(bearer_scheme)):
 
 
 
 
 @router.delete("/", response_model=bool)
 @router.delete("/", response_model=bool)
-async def delete_all_user_chats(cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        result = Chats.delete_chats_by_user_id(user.id)
-        return result
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
-        )
+async def delete_all_user_chats(user=Depends(get_current_user)):
+    result = Chats.delete_chats_by_user_id(user.id)
+    return result

+ 63 - 115
backend/apps/web/routers/modelfiles.py

@@ -1,4 +1,3 @@
-from fastapi import Response
 from fastapi import Depends, FastAPI, HTTPException, status
 from fastapi import Depends, FastAPI, HTTPException, status
 from datetime import datetime, timedelta
 from datetime import datetime, timedelta
 from typing import List, Union, Optional
 from typing import List, Union, Optional
@@ -6,8 +5,6 @@ from typing import List, Union, Optional
 from fastapi import APIRouter
 from fastapi import APIRouter
 from pydantic import BaseModel
 from pydantic import BaseModel
 import json
 import json
-
-from apps.web.models.users import Users
 from apps.web.models.modelfiles import (
 from apps.web.models.modelfiles import (
     Modelfiles,
     Modelfiles,
     ModelfileForm,
     ModelfileForm,
@@ -16,9 +13,7 @@ from apps.web.models.modelfiles import (
     ModelfileResponse,
     ModelfileResponse,
 )
 )
 
 
-from utils.utils import (
-    bearer_scheme,
-)
+from utils.utils import get_current_user
 from constants import ERROR_MESSAGES
 from constants import ERROR_MESSAGES
 
 
 router = APIRouter()
 router = APIRouter()
@@ -29,17 +24,8 @@ router = APIRouter()
 
 
 
 
 @router.get("/", response_model=List[ModelfileResponse])
 @router.get("/", response_model=List[ModelfileResponse])
-async def get_modelfiles(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        return Modelfiles.get_modelfiles(skip, limit)
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
-        )
+async def get_modelfiles(skip: int = 0, limit: int = 50, user=Depends(get_current_user)):
+    return Modelfiles.get_modelfiles(skip, limit)
 
 
 
 
 ############################
 ############################
@@ -48,36 +34,28 @@ async def get_modelfiles(skip: int = 0, limit: int = 50, cred=Depends(bearer_sch
 
 
 
 
 @router.post("/create", response_model=Optional[ModelfileResponse])
 @router.post("/create", response_model=Optional[ModelfileResponse])
-async def create_new_modelfile(form_data: ModelfileForm, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        # Admin Only
-        if user.role == "admin":
-            modelfile = Modelfiles.insert_new_modelfile(user.id, form_data)
-
-            if modelfile:
-                return ModelfileResponse(
-                    **{
-                        **modelfile.model_dump(),
-                        "modelfile": json.loads(modelfile.modelfile),
-                    }
-                )
-            else:
-                raise HTTPException(
-                    status_code=status.HTTP_401_UNAUTHORIZED,
-                    detail=ERROR_MESSAGES.DEFAULT(),
-                )
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-            )
+async def create_new_modelfile(
+    form_data: ModelfileForm, user=Depends(get_current_user)
+):
+    if user.role != "admin":
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )
+
+    modelfile = Modelfiles.insert_new_modelfile(user.id, form_data)
+
+    if modelfile:
+        return ModelfileResponse(
+            **{
+                **modelfile.model_dump(),
+                "modelfile": json.loads(modelfile.modelfile),
+            }
+        )
     else:
     else:
         raise HTTPException(
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
             status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
+            detail=ERROR_MESSAGES.DEFAULT(),
         )
         )
 
 
 
 
@@ -87,31 +65,20 @@ async def create_new_modelfile(form_data: ModelfileForm, cred=Depends(bearer_sch
 
 
 
 
 @router.post("/", response_model=Optional[ModelfileResponse])
 @router.post("/", response_model=Optional[ModelfileResponse])
-async def get_modelfile_by_tag_name(
-    form_data: ModelfileTagNameForm, cred=Depends(bearer_scheme)
-):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
-
-        if modelfile:
-            return ModelfileResponse(
-                **{
-                    **modelfile.model_dump(),
-                    "modelfile": json.loads(modelfile.modelfile),
-                }
-            )
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail=ERROR_MESSAGES.NOT_FOUND,
-            )
+async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm, user=Depends(get_current_user)):
+    modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
+
+    if modelfile:
+        return ModelfileResponse(
+            **{
+                **modelfile.model_dump(),
+                "modelfile": json.loads(modelfile.modelfile),
+            }
+        )
     else:
     else:
         raise HTTPException(
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
             status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
+            detail=ERROR_MESSAGES.NOT_FOUND,
         )
         )
 
 
 
 
@@ -122,44 +89,34 @@ async def get_modelfile_by_tag_name(
 
 
 @router.post("/update", response_model=Optional[ModelfileResponse])
 @router.post("/update", response_model=Optional[ModelfileResponse])
 async def update_modelfile_by_tag_name(
 async def update_modelfile_by_tag_name(
-    form_data: ModelfileUpdateForm, cred=Depends(bearer_scheme)
+    form_data: ModelfileUpdateForm, user=Depends(get_current_user)
 ):
 ):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        if user.role == "admin":
-            modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
-            if modelfile:
-                updated_modelfile = {
-                    **json.loads(modelfile.modelfile),
-                    **form_data.modelfile,
-                }
-
-                modelfile = Modelfiles.update_modelfile_by_tag_name(
-                    form_data.tag_name, updated_modelfile
-                )
-
-                return ModelfileResponse(
-                    **{
-                        **modelfile.model_dump(),
-                        "modelfile": json.loads(modelfile.modelfile),
-                    }
-                )
-            else:
-                raise HTTPException(
-                    status_code=status.HTTP_401_UNAUTHORIZED,
-                    detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-                )
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-            )
+    if user.role != "admin":
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )
+    modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
+    if modelfile:
+        updated_modelfile = {
+            **json.loads(modelfile.modelfile),
+            **form_data.modelfile,
+        }
+
+        modelfile = Modelfiles.update_modelfile_by_tag_name(
+            form_data.tag_name, updated_modelfile
+        )
+
+        return ModelfileResponse(
+            **{
+                **modelfile.model_dump(),
+                "modelfile": json.loads(modelfile.modelfile),
+            }
+        )
     else:
     else:
         raise HTTPException(
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
             status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
         )
         )
 
 
 
 
@@ -170,22 +127,13 @@ async def update_modelfile_by_tag_name(
 
 
 @router.delete("/delete", response_model=bool)
 @router.delete("/delete", response_model=bool)
 async def delete_modelfile_by_tag_name(
 async def delete_modelfile_by_tag_name(
-    form_data: ModelfileTagNameForm, cred=Depends(bearer_scheme)
+    form_data: ModelfileTagNameForm, user=Depends(get_current_user)
 ):
 ):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        if user.role == "admin":
-            result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name)
-            return result
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-            )
-    else:
+    if user.role != "admin":
         raise HTTPException(
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
             status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
         )
         )
+
+    result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name)
+    return result

+ 31 - 61
backend/apps/web/routers/users.py

@@ -12,11 +12,7 @@ from apps.web.models.users import UserModel, UserRoleUpdateForm, Users
 from apps.web.models.auths import Auths
 from apps.web.models.auths import Auths
 
 
 
 
-from utils.utils import (
-    get_password_hash,
-    bearer_scheme,
-    create_token,
-)
+from utils.utils import get_current_user
 from constants import ERROR_MESSAGES
 from constants import ERROR_MESSAGES
 
 
 router = APIRouter()
 router = APIRouter()
@@ -27,23 +23,13 @@ router = APIRouter()
 
 
 
 
 @router.get("/", response_model=List[UserModel])
 @router.get("/", response_model=List[UserModel])
-async def get_users(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        if user.role == "admin":
-            return Users.get_users(skip, limit)
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_403_FORBIDDEN,
-                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-            )
-    else:
+async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_current_user)):
+    if user.role != "admin":
         raise HTTPException(
         raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
+            status_code=status.HTTP_403_FORBIDDEN,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
         )
         )
+    return Users.get_users(skip, limit)
 
 
 
 
 ############################
 ############################
@@ -52,28 +38,21 @@ async def get_users(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme))
 
 
 
 
 @router.post("/update/role", response_model=Optional[UserModel])
 @router.post("/update/role", response_model=Optional[UserModel])
-async def update_user_role(form_data: UserRoleUpdateForm, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        if user.role == "admin":
-            if user.id != form_data.id:
-                return Users.update_user_role_by_id(form_data.id, form_data.role)
-            else:
-                raise HTTPException(
-                    status_code=status.HTTP_403_FORBIDDEN,
-                    detail=ERROR_MESSAGES.ACTION_PROHIBITED,
-                )
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_403_FORBIDDEN,
-                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-            )
+async def update_user_role(
+    form_data: UserRoleUpdateForm, user=Depends(get_current_user)
+):
+    if user.role != "admin":
+        raise HTTPException(
+            status_code=status.HTTP_403_FORBIDDEN,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )
+
+    if user.id != form_data.id:
+        return Users.update_user_role_by_id(form_data.id, form_data.role)
     else:
     else:
         raise HTTPException(
         raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
+            status_code=status.HTTP_403_FORBIDDEN,
+            detail=ERROR_MESSAGES.ACTION_PROHIBITED,
         )
         )
 
 
 
 
@@ -83,34 +62,25 @@ async def update_user_role(form_data: UserRoleUpdateForm, cred=Depends(bearer_sc
 
 
 
 
 @router.delete("/{user_id}", response_model=bool)
 @router.delete("/{user_id}", response_model=bool)
-async def delete_user_by_id(user_id: str, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        if user.role == "admin":
-            if user.id != user_id:
-                result = Auths.delete_auth_by_id(user_id)
-
-                if result:
-                    return True
-                else:
-                    raise HTTPException(
-                        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-                        detail=ERROR_MESSAGES.DELETE_USER_ERROR,
-                    )
+async def delete_user_by_id(user_id: str, user=Depends(get_current_user)):
+    if user.role == "admin":
+        if user.id != user_id:
+            result = Auths.delete_auth_by_id(user_id)
+
+            if result:
+                return True
             else:
             else:
                 raise HTTPException(
                 raise HTTPException(
-                    status_code=status.HTTP_403_FORBIDDEN,
-                    detail=ERROR_MESSAGES.ACTION_PROHIBITED,
+                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                    detail=ERROR_MESSAGES.DELETE_USER_ERROR,
                 )
                 )
         else:
         else:
             raise HTTPException(
             raise HTTPException(
                 status_code=status.HTTP_403_FORBIDDEN,
                 status_code=status.HTTP_403_FORBIDDEN,
-                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+                detail=ERROR_MESSAGES.ACTION_PROHIBITED,
             )
             )
     else:
     else:
         raise HTTPException(
         raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
+            status_code=status.HTTP_403_FORBIDDEN,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
         )
         )

+ 2 - 0
backend/requirements.txt

@@ -18,3 +18,5 @@ bcrypt
 
 
 PyJWT
 PyJWT
 pyjwt[crypto]
 pyjwt[crypto]
+
+black

+ 18 - 14
backend/utils/utils.py

@@ -1,7 +1,9 @@
-from fastapi.security import HTTPBasicCredentials, HTTPBearer
+from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
+from fastapi import HTTPException, status, Depends
+from apps.web.models.users import Users
 from pydantic import BaseModel
 from pydantic import BaseModel
 from typing import Union, Optional
 from typing import Union, Optional
-
+from constants import ERROR_MESSAGES
 from passlib.context import CryptContext
 from passlib.context import CryptContext
 from datetime import datetime, timedelta
 from datetime import datetime, timedelta
 import requests
 import requests
@@ -53,16 +55,18 @@ def extract_token_from_auth_header(auth_header: str):
     return auth_header[len("Bearer ") :]
     return auth_header[len("Bearer ") :]
 
 
 
 
-def verify_token(request):
-    try:
-        bearer = request.headers["authorization"]
-        if bearer:
-            token = bearer[len("Bearer ") :]
-            decoded = jwt.decode(
-                token, JWT_SECRET_KEY, options={"verify_signature": False}
+def get_current_user(auth_token: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
+    data = decode_token(auth_token.credentials)
+    if data != None and "email" in data:
+        user = Users.get_user_by_email(data["email"])
+        if user is None:
+            raise HTTPException(
+                status_code=status.HTTP_401_UNAUTHORIZED,
+                detail=ERROR_MESSAGES.INVALID_TOKEN,
             )
             )
-            return decoded
-        else:
-            return None
-    except Exception as e:
-        return None
+        return user
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.UNAUTHORIZED,
+        )