浏览代码

refac: use dependencies to verify token

- feat: added new util to get the current user when needed. Middleware was adding authentication logic to all the routes. let's revisit if we can move the non-auth endpoints to a separate route.
- refac: update the routes to use new helpers for verification and retrieving user
- chore: added black for local formatting of py code
Anuraag Jain 1 年之前
父节点
当前提交
bdd153d8f5

+ 7 - 3
backend/apps/ollama/main.py

@@ -8,7 +8,7 @@ import json
 
 from apps.web.models.users import Users
 from constants import ERROR_MESSAGES
-from utils.utils import extract_token_from_auth_header
+from utils.utils import decode_token
 from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
 
 app = Flask(__name__)
@@ -34,8 +34,12 @@ def proxy(path):
     # Basic RBAC support
     if WEBUI_AUTH:
         if "Authorization" in headers:
-            token = extract_token_from_auth_header(headers["Authorization"])
-            user = Users.get_user_by_token(token)
+            _, credentials = headers["Authorization"].split()
+            token_data = decode_token(credentials)
+            if token_data is None or "email" not in token_data:
+                return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401
+
+            user = Users.get_user_by_email(token_data["email"])
             if user:
                 # Only user and admin roles can access
                 if user.role in ["user", "admin"]:

+ 21 - 8
backend/apps/web/main.py

@@ -1,9 +1,10 @@
-from fastapi import FastAPI
+from fastapi import FastAPI, Depends
+from fastapi.routing import APIRoute
 from fastapi.middleware.cors import CORSMiddleware
 from starlette.middleware.authentication import AuthenticationMiddleware
 from apps.web.routers import auths, users, chats, modelfiles, utils
 from config import WEBUI_VERSION, WEBUI_AUTH
-from apps.web.middlewares.auth import BearerTokenAuthBackend, on_auth_error
+from utils.utils import verify_auth_token
 
 app = FastAPI()
 
@@ -17,14 +18,26 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-
 app.include_router(auths.router, prefix="/auths", tags=["auths"])
 
-app.add_middleware(AuthenticationMiddleware, backend=BearerTokenAuthBackend(), on_error=on_auth_error)
-
-app.include_router(users.router, prefix="/users", tags=["users"])
-app.include_router(chats.router, prefix="/chats", tags=["chats"])
-app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"])
+app.include_router(
+    users.router,
+    prefix="/users",
+    tags=["users"],
+    dependencies=[Depends(verify_auth_token)],
+)
+app.include_router(
+    chats.router,
+    prefix="/chats",
+    tags=["chats"],
+    dependencies=[Depends(verify_auth_token)],
+)
+app.include_router(
+    modelfiles.router,
+    prefix="/modelfiles",
+    tags=["modelfiles"],
+    dependencies=[Depends(verify_auth_token)],
+)
 app.include_router(utils.router, prefix="/utils", tags=["utils"])
 
 

+ 0 - 27
backend/apps/web/middlewares/auth.py

@@ -1,27 +0,0 @@
-from apps.web.models.users import Users
-from fastapi import Request, status
-from starlette.authentication import (
-    AuthCredentials, AuthenticationBackend, AuthenticationError, 
-)
-from starlette.requests import HTTPConnection
-from utils.utils import verify_token
-from starlette.responses import JSONResponse
-from constants import ERROR_MESSAGES
-
-class BearerTokenAuthBackend(AuthenticationBackend):
-
-    async def authenticate(self, conn: HTTPConnection):
-        if "Authorization" not in conn.headers:
-            return
-        data = verify_token(conn)
-        if data != None and 'email' in data:
-            user = Users.get_user_by_email(data['email'])
-            if user is None:
-                raise AuthenticationError('Invalid credentials') 
-            return AuthCredentials([user.role]), user
-        else:
-            raise AuthenticationError('Invalid credentials') 
-
-def on_auth_error(request: Request, exc: Exception):
-    print('Authentication failed: ', exc)
-    return JSONResponse({"detail": ERROR_MESSAGES.INVALID_TOKEN}, status_code=status.HTTP_401_UNAUTHORIZED)

+ 0 - 10
backend/apps/web/models/users.py

@@ -3,8 +3,6 @@ from peewee import *
 from playhouse.shortcuts import model_to_dict
 from typing import List, Union, Optional
 import time
-
-from utils.utils import decode_token
 from utils.misc import get_gravatar_url
 
 from apps.web.internal.db import DB
@@ -83,14 +81,6 @@ class UsersTable:
         except:
             return None
 
-    def get_user_by_token(self, token: str) -> Optional[UserModel]:
-        data = decode_token(token)
-
-        if data != None and "email" in data:
-            return self.get_user_by_email(data["email"])
-        else:
-            return None
-
     def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]:
         return [
             UserModel(**model_to_dict(user))

+ 9 - 17
backend/apps/web/routers/auths.py

@@ -20,7 +20,7 @@ from apps.web.models.users import Users
 
 from utils.utils import (
     get_password_hash,
-    bearer_scheme,
+    get_current_user,
     create_token,
 )
 from utils.misc import get_gravatar_url
@@ -35,22 +35,14 @@ router = APIRouter()
 
 
 @router.get("/", response_model=UserResponse)
-async def get_session_user(cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-    if user:
-        return {
-            "id": user.id,
-            "email": user.email,
-            "name": user.name,
-            "role": user.role,
-            "profile_image_url": user.profile_image_url,
-        }
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
-        )
+async def get_session_user(user=Depends(get_current_user)):
+    return {
+        "id": user.id,
+        "email": user.email,
+        "name": user.name,
+        "role": user.role,
+        "profile_image_url": user.profile_image_url,
+    }
 
 
 ############################

+ 30 - 26
backend/apps/web/routers/chats.py

@@ -1,8 +1,7 @@
-
 from fastapi import Depends, Request, HTTPException, status
 from datetime import datetime, timedelta
 from typing import List, Union, Optional
-
+from utils.utils import get_current_user
 from fastapi import APIRouter
 from pydantic import BaseModel
 import json
@@ -30,8 +29,10 @@ router = APIRouter()
 
 
 @router.get("/", response_model=List[ChatTitleIdResponse])
-async def get_user_chats(request:Request, skip: int = 0, limit: int = 50):
-    return Chats.get_chat_lists_by_user_id(request.user.id, skip, limit)
+async def get_user_chats(
+    user=Depends(get_current_user), skip: int = 0, limit: int = 50
+):
+    return Chats.get_chat_lists_by_user_id(user.id, skip, limit)
 
 
 ############################
@@ -40,11 +41,11 @@ async def get_user_chats(request:Request, skip: int = 0, limit: int = 50):
 
 
 @router.get("/all", response_model=List[ChatResponse])
-async def get_all_user_chats(request:Request,):
+async def get_all_user_chats(user=Depends(get_current_user)):
     return [
-            ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
-            for chat in Chats.get_all_chats_by_user_id(request.user.id)
-        ]
+        ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
+        for chat in Chats.get_all_chats_by_user_id(user.id)
+    ]
 
 
 ############################
@@ -53,8 +54,8 @@ async def get_all_user_chats(request:Request,):
 
 
 @router.post("/new", response_model=Optional[ChatResponse])
-async def create_new_chat(form_data: ChatForm,request:Request):
-    chat = Chats.insert_new_chat(request.user.id, form_data)
+async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
+    chat = Chats.insert_new_chat(user.id, form_data)
     return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
 
 
@@ -64,14 +65,15 @@ async def create_new_chat(form_data: ChatForm,request:Request):
 
 
 @router.get("/{id}", response_model=Optional[ChatResponse])
-async def get_chat_by_id(id: str, request:Request):
-    chat = Chats.get_chat_by_id_and_user_id(id, request.user.id)
+async def get_chat_by_id(id: str, user=Depends(get_current_user)):
+    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
 
     if chat:
         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
     else:
-        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
-                            detail=ERROR_MESSAGES.NOT_FOUND)
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
+        )
 
 
 ############################
@@ -80,18 +82,20 @@ async def get_chat_by_id(id: str, request:Request):
 
 
 @router.post("/{id}", response_model=Optional[ChatResponse])
-async def update_chat_by_id(id: str, form_data: ChatForm, request:Request):
-    chat = Chats.get_chat_by_id_and_user_id(id, request.user.id)
+async def update_chat_by_id(
+    id: str, form_data: ChatForm, user=Depends(get_current_user)
+):
+    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
     if chat:
-            updated_chat = {**json.loads(chat.chat), **form_data.chat}
+        updated_chat = {**json.loads(chat.chat), **form_data.chat}
 
-            chat = Chats.update_chat_by_id(id, updated_chat)
-            return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
+        chat = Chats.update_chat_by_id(id, updated_chat)
+        return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
     else:
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-            )
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )
 
 
 ############################
@@ -100,6 +104,6 @@ async def update_chat_by_id(id: str, form_data: ChatForm, request:Request):
 
 
 @router.delete("/{id}", response_model=bool)
-async def delete_chat_by_id(id: str, request: Request):
-    result = Chats.delete_chat_by_id_and_user_id(id, request.user.id)
-    return result
+async def delete_chat_by_id(id: str, user=Depends(get_current_user)):
+    result = Chats.delete_chat_by_id_and_user_id(id, user.id)
+    return result

+ 62 - 112
backend/apps/web/routers/modelfiles.py

@@ -1,4 +1,3 @@
-from fastapi import Response
 from fastapi import Depends, FastAPI, HTTPException, status
 from datetime import datetime, timedelta
 from typing import List, Union, Optional
@@ -16,9 +15,7 @@ from apps.web.models.modelfiles import (
     ModelfileResponse,
 )
 
-from utils.utils import (
-    bearer_scheme,
-)
+from utils.utils import bearer_scheme, get_current_user
 from constants import ERROR_MESSAGES
 
 router = APIRouter()
@@ -30,16 +27,7 @@ router = APIRouter()
 
 @router.get("/", response_model=List[ModelfileResponse])
 async def get_modelfiles(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        return Modelfiles.get_modelfiles(skip, limit)
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
-        )
+    return Modelfiles.get_modelfiles(skip, limit)
 
 
 ############################
@@ -48,36 +36,28 @@ async def get_modelfiles(skip: int = 0, limit: int = 50, cred=Depends(bearer_sch
 
 
 @router.post("/create", response_model=Optional[ModelfileResponse])
-async def create_new_modelfile(form_data: ModelfileForm, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        # Admin Only
-        if user.role == "admin":
-            modelfile = Modelfiles.insert_new_modelfile(user.id, form_data)
-
-            if modelfile:
-                return ModelfileResponse(
-                    **{
-                        **modelfile.model_dump(),
-                        "modelfile": json.loads(modelfile.modelfile),
-                    }
-                )
-            else:
-                raise HTTPException(
-                    status_code=status.HTTP_401_UNAUTHORIZED,
-                    detail=ERROR_MESSAGES.DEFAULT(),
-                )
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-            )
+async def create_new_modelfile(
+    form_data: ModelfileForm, user=Depends(get_current_user)
+):
+    if user.role != "admin":
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )
+
+    modelfile = Modelfiles.insert_new_modelfile(user.id, form_data)
+
+    if modelfile:
+        return ModelfileResponse(
+            **{
+                **modelfile.model_dump(),
+                "modelfile": json.loads(modelfile.modelfile),
+            }
+        )
     else:
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
+            detail=ERROR_MESSAGES.DEFAULT(),
         )
 
 
@@ -87,31 +67,20 @@ async def create_new_modelfile(form_data: ModelfileForm, cred=Depends(bearer_sch
 
 
 @router.post("/", response_model=Optional[ModelfileResponse])
-async def get_modelfile_by_tag_name(
-    form_data: ModelfileTagNameForm, cred=Depends(bearer_scheme)
-):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
-
-        if modelfile:
-            return ModelfileResponse(
-                **{
-                    **modelfile.model_dump(),
-                    "modelfile": json.loads(modelfile.modelfile),
-                }
-            )
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail=ERROR_MESSAGES.NOT_FOUND,
-            )
+async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm):
+    modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
+
+    if modelfile:
+        return ModelfileResponse(
+            **{
+                **modelfile.model_dump(),
+                "modelfile": json.loads(modelfile.modelfile),
+            }
+        )
     else:
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
+            detail=ERROR_MESSAGES.NOT_FOUND,
         )
 
 
@@ -122,44 +91,34 @@ async def get_modelfile_by_tag_name(
 
 @router.post("/update", response_model=Optional[ModelfileResponse])
 async def update_modelfile_by_tag_name(
-    form_data: ModelfileUpdateForm, cred=Depends(bearer_scheme)
+    form_data: ModelfileUpdateForm, user=Depends(get_current_user)
 ):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        if user.role == "admin":
-            modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
-            if modelfile:
-                updated_modelfile = {
-                    **json.loads(modelfile.modelfile),
-                    **form_data.modelfile,
-                }
-
-                modelfile = Modelfiles.update_modelfile_by_tag_name(
-                    form_data.tag_name, updated_modelfile
-                )
-
-                return ModelfileResponse(
-                    **{
-                        **modelfile.model_dump(),
-                        "modelfile": json.loads(modelfile.modelfile),
-                    }
-                )
-            else:
-                raise HTTPException(
-                    status_code=status.HTTP_401_UNAUTHORIZED,
-                    detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-                )
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-            )
+    if user.role != "admin":
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )
+    modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
+    if modelfile:
+        updated_modelfile = {
+            **json.loads(modelfile.modelfile),
+            **form_data.modelfile,
+        }
+
+        modelfile = Modelfiles.update_modelfile_by_tag_name(
+            form_data.tag_name, updated_modelfile
+        )
+
+        return ModelfileResponse(
+            **{
+                **modelfile.model_dump(),
+                "modelfile": json.loads(modelfile.modelfile),
+            }
+        )
     else:
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
         )
 
 
@@ -170,22 +129,13 @@ async def update_modelfile_by_tag_name(
 
 @router.delete("/delete", response_model=bool)
 async def delete_modelfile_by_tag_name(
-    form_data: ModelfileTagNameForm, cred=Depends(bearer_scheme)
+    form_data: ModelfileTagNameForm, user=Depends(get_current_user)
 ):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        if user.role == "admin":
-            result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name)
-            return result
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-            )
-    else:
+    if user.role != "admin":
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
         )
+
+    result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name)
+    return result

+ 18 - 39
backend/apps/web/routers/users.py

@@ -10,11 +10,7 @@ import uuid
 
 from apps.web.models.users import UserModel, UserRoleUpdateForm, Users
 
-from utils.utils import (
-    get_password_hash,
-    bearer_scheme,
-    create_token,
-)
+from utils.utils import get_current_user
 from constants import ERROR_MESSAGES
 
 router = APIRouter()
@@ -25,23 +21,13 @@ router = APIRouter()
 
 
 @router.get("/", response_model=List[UserModel])
-async def get_users(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        if user.role == "admin":
-            return Users.get_users(skip, limit)
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_403_FORBIDDEN,
-                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-            )
-    else:
+async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_current_user)):
+    if user.role != "admin":
         raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
+            status_code=status.HTTP_403_FORBIDDEN,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
         )
+    return Users.get_users(skip, limit)
 
 
 ############################
@@ -50,26 +36,19 @@ async def get_users(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme))
 
 
 @router.post("/update/role", response_model=Optional[UserModel])
-async def update_user_role(form_data: UserRoleUpdateForm, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
+async def update_user_role(
+    form_data: UserRoleUpdateForm, user=Depends(get_current_user)
+):
+    if user.role != "admin":
+        raise HTTPException(
+            status_code=status.HTTP_403_FORBIDDEN,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )
 
-    if user:
-        if user.role == "admin":
-            if user.id != form_data.id:
-                return Users.update_user_role_by_id(form_data.id, form_data.role)
-            else:
-                raise HTTPException(
-                    status_code=status.HTTP_403_FORBIDDEN,
-                    detail=ERROR_MESSAGES.ACTION_PROHIBITED,
-                )
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_403_FORBIDDEN,
-                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-            )
+    if user.id != form_data.id:
+        return Users.update_user_role_by_id(form_data.id, form_data.role)
     else:
         raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
+            status_code=status.HTTP_403_FORBIDDEN,
+            detail=ERROR_MESSAGES.ACTION_PROHIBITED,
         )

+ 2 - 0
backend/requirements.txt

@@ -18,3 +18,5 @@ bcrypt
 
 PyJWT
 pyjwt[crypto]
+
+black

+ 23 - 14
backend/utils/utils.py

@@ -1,7 +1,9 @@
-from fastapi.security import HTTPBasicCredentials, HTTPBearer
+from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
+from fastapi import HTTPException, status, Depends
+from apps.web.models.users import Users
 from pydantic import BaseModel
 from typing import Union, Optional
-
+from constants import ERROR_MESSAGES
 from passlib.context import CryptContext
 from datetime import datetime, timedelta
 import requests
@@ -53,16 +55,23 @@ def extract_token_from_auth_header(auth_header: str):
     return auth_header[len("Bearer ") :]
 
 
-def verify_token(request):
-    try:
-        authorization = request.headers["authorization"]
-        if authorization:
-            _, token = authorization.split()
-            decoded_token = jwt.decode(
-                token, JWT_SECRET_KEY, options={"verify_signature": False}
+def verify_auth_token(auth_token: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
+    data = decode_token(auth_token.credentials)
+    if data != None and "email" in data:
+        user = Users.get_user_by_email(data["email"])
+        if user is None:
+            raise HTTPException(
+                status_code=status.HTTP_401_UNAUTHORIZED,
+                detail=ERROR_MESSAGES.INVALID_TOKEN,
             )
-            return decoded_token
-        else:
-            return None
-    except Exception as e:
-        return None
+        return
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.UNAUTHORIZED,
+        )
+
+
+def get_current_user(auth_token: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
+    data = decode_token(auth_token.credentials)
+    return Users.get_user_by_email(data["email"])