浏览代码

feat(auth): add auth middleware

- refactored chat routes to use request.user instead of doing authentication in every route
Anuraag Jain 1 年之前
父节点
当前提交
a01b112f7f
共有 5 个文件被更改,包括 63 次插入89 次删除
  1. 2 1
      backend/.gitignore
  2. 6 4
      backend/apps/web/main.py
  3. 27 0
      backend/apps/web/middlewares/auth.py
  4. 23 79
      backend/apps/web/routers/chats.py
  5. 5 5
      backend/utils/utils.py

+ 2 - 1
backend/.gitignore

@@ -4,4 +4,5 @@ _old
 uploads
 uploads
 .ipynb_checkpoints
 .ipynb_checkpoints
 *.db
 *.db
-_test
+_test
+Pipfile

+ 6 - 4
backend/apps/web/main.py

@@ -1,8 +1,9 @@
-from fastapi import FastAPI, Request, Depends, HTTPException
+from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.cors import CORSMiddleware
-
+from starlette.middleware.authentication import AuthenticationMiddleware
 from apps.web.routers import auths, users, chats, modelfiles, utils
 from apps.web.routers import auths, users, chats, modelfiles, utils
 from config import WEBUI_VERSION, WEBUI_AUTH
 from config import WEBUI_VERSION, WEBUI_AUTH
+from apps.web.middlewares.auth import BearerTokenAuthBackend, on_auth_error
 
 
 app = FastAPI()
 app = FastAPI()
 
 
@@ -18,11 +19,12 @@ app.add_middleware(
 
 
 
 
 app.include_router(auths.router, prefix="/auths", tags=["auths"])
 app.include_router(auths.router, prefix="/auths", tags=["auths"])
+
+app.add_middleware(AuthenticationMiddleware, backend=BearerTokenAuthBackend(), on_error=on_auth_error)
+
 app.include_router(users.router, prefix="/users", tags=["users"])
 app.include_router(users.router, prefix="/users", tags=["users"])
 app.include_router(chats.router, prefix="/chats", tags=["chats"])
 app.include_router(chats.router, prefix="/chats", tags=["chats"])
 app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"])
 app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"])
-
-
 app.include_router(utils.router, prefix="/utils", tags=["utils"])
 app.include_router(utils.router, prefix="/utils", tags=["utils"])
 
 
 
 

+ 27 - 0
backend/apps/web/middlewares/auth.py

@@ -0,0 +1,27 @@
+from apps.web.models.users import Users
+from fastapi import Request, status
+from starlette.authentication import (
+    AuthCredentials, AuthenticationBackend, AuthenticationError, 
+)
+from starlette.requests import HTTPConnection
+from utils.utils import verify_token
+from starlette.responses import JSONResponse
+from constants import ERROR_MESSAGES
+
+class BearerTokenAuthBackend(AuthenticationBackend):
+
+    async def authenticate(self, conn: HTTPConnection):
+        if "Authorization" not in conn.headers:
+            return
+        data = verify_token(conn)
+        if data != None and 'email' in data:
+            user = Users.get_user_by_email(data['email'])
+            if user is None:
+                raise AuthenticationError('Invalid credentials') 
+            return AuthCredentials([user.role]), user
+        else:
+            raise AuthenticationError('Invalid credentials') 
+
+def on_auth_error(request: Request, exc: Exception):
+    print('Authentication failed: ', exc)
+    return JSONResponse({"detail": ERROR_MESSAGES.INVALID_TOKEN}, status_code=status.HTTP_401_UNAUTHORIZED)

+ 23 - 79
backend/apps/web/routers/chats.py

@@ -1,5 +1,5 @@
-from fastapi import Response
-from fastapi import Depends, FastAPI, HTTPException, status
+
+from fastapi import Depends, Request, HTTPException, status
 from datetime import datetime, timedelta
 from datetime import datetime, timedelta
 from typing import List, Union, Optional
 from typing import List, Union, Optional
 
 
@@ -30,17 +30,8 @@ router = APIRouter()
 
 
 
 
 @router.get("/", response_model=List[ChatTitleIdResponse])
 @router.get("/", response_model=List[ChatTitleIdResponse])
-async def get_user_chats(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        return Chats.get_chat_lists_by_user_id(user.id, skip, limit)
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
-        )
+async def get_user_chats(request:Request, skip: int = 0, limit: int = 50):
+    return Chats.get_chat_lists_by_user_id(request.user.id, skip, limit)
 
 
 
 
 ############################
 ############################
@@ -49,20 +40,11 @@ async def get_user_chats(skip: int = 0, limit: int = 50, cred=Depends(bearer_sch
 
 
 
 
 @router.get("/all", response_model=List[ChatResponse])
 @router.get("/all", response_model=List[ChatResponse])
-async def get_all_user_chats(cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        return [
+async def get_all_user_chats(request:Request,):
+    return [
             ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
             ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
-            for chat in Chats.get_all_chats_by_user_id(user.id)
+            for chat in Chats.get_all_chats_by_user_id(request.user.id)
         ]
         ]
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
-        )
 
 
 
 
 ############################
 ############################
@@ -71,18 +53,9 @@ async def get_all_user_chats(cred=Depends(bearer_scheme)):
 
 
 
 
 @router.post("/new", response_model=Optional[ChatResponse])
 @router.post("/new", response_model=Optional[ChatResponse])
-async def create_new_chat(form_data: ChatForm, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        chat = Chats.insert_new_chat(user.id, form_data)
-        return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
-        )
+async def create_new_chat(form_data: ChatForm,request:Request):
+    chat = Chats.insert_new_chat(request.user.id, form_data)
+    return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
 
 
 
 
 ############################
 ############################
@@ -91,25 +64,14 @@ async def create_new_chat(form_data: ChatForm, cred=Depends(bearer_scheme)):
 
 
 
 
 @router.get("/{id}", response_model=Optional[ChatResponse])
 @router.get("/{id}", response_model=Optional[ChatResponse])
-async def get_chat_by_id(id: str, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
+async def get_chat_by_id(id: str, request:Request):
+    chat = Chats.get_chat_by_id_and_user_id(id, request.user.id)
 
 
-    if user:
-        chat = Chats.get_chat_by_id_and_user_id(id, user.id)
-
-        if chat:
-            return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail=ERROR_MESSAGES.NOT_FOUND,
-            )
+    if chat:
+        return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
     else:
     else:
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
-        )
+        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
+                            detail=ERROR_MESSAGES.NOT_FOUND)
 
 
 
 
 ############################
 ############################
@@ -118,27 +80,18 @@ async def get_chat_by_id(id: str, cred=Depends(bearer_scheme)):
 
 
 
 
 @router.post("/{id}", response_model=Optional[ChatResponse])
 @router.post("/{id}", response_model=Optional[ChatResponse])
-async def update_chat_by_id(id: str, form_data: ChatForm, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        chat = Chats.get_chat_by_id_and_user_id(id, user.id)
-        if chat:
+async def update_chat_by_id(id: str, form_data: ChatForm, request:Request):
+    chat = Chats.get_chat_by_id_and_user_id(id, request.user.id)
+    if chat:
             updated_chat = {**json.loads(chat.chat), **form_data.chat}
             updated_chat = {**json.loads(chat.chat), **form_data.chat}
 
 
             chat = Chats.update_chat_by_id(id, updated_chat)
             chat = Chats.update_chat_by_id(id, updated_chat)
             return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
             return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
-        else:
+    else:
             raise HTTPException(
             raise HTTPException(
                 status_code=status.HTTP_401_UNAUTHORIZED,
                 status_code=status.HTTP_401_UNAUTHORIZED,
                 detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
                 detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
             )
             )
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
-        )
 
 
 
 
 ############################
 ############################
@@ -147,15 +100,6 @@ async def update_chat_by_id(id: str, form_data: ChatForm, cred=Depends(bearer_sc
 
 
 
 
 @router.delete("/{id}", response_model=bool)
 @router.delete("/{id}", response_model=bool)
-async def delete_chat_by_id(id: str, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        result = Chats.delete_chat_by_id_and_user_id(id, user.id)
-        return result
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
-        )
+async def delete_chat_by_id(id: str, request: Request):
+    result = Chats.delete_chat_by_id_and_user_id(id, request.user.id)
+    return result

+ 5 - 5
backend/utils/utils.py

@@ -55,13 +55,13 @@ def extract_token_from_auth_header(auth_header: str):
 
 
 def verify_token(request):
 def verify_token(request):
     try:
     try:
-        bearer = request.headers["authorization"]
-        if bearer:
-            token = bearer[len("Bearer ") :]
-            decoded = jwt.decode(
+        authorization = request.headers["authorization"]
+        if authorization:
+            _, token = authorization.split()
+            decoded_token = jwt.decode(
                 token, JWT_SECRET_KEY, options={"verify_signature": False}
                 token, JWT_SECRET_KEY, options={"verify_signature": False}
             )
             )
-            return decoded
+            return decoded_token
         else:
         else:
             return None
             return None
     except Exception as e:
     except Exception as e: