Browse Source

feat: secure litellm api

Timothy J. Baek 1 năm trước cách đây
mục cha
commit
b5bd07a06a
2 tập tin đã thay đổi với 30 bổ sung2 xóa
  1. 19 2
      backend/main.py
  2. 11 0
      backend/utils/utils.py

+ 19 - 2
backend/main.py

@@ -4,9 +4,10 @@ import markdown
 import time
 
 
-from fastapi import FastAPI, Request
+from fastapi import FastAPI, Request, Depends
 from fastapi.staticfiles import StaticFiles
 from fastapi import HTTPException
+from fastapi.responses import JSONResponse
 from fastapi.middleware.wsgi import WSGIMiddleware
 from fastapi.middleware.cors import CORSMiddleware
 from starlette.exceptions import HTTPException as StarletteHTTPException
@@ -19,10 +20,11 @@ from apps.openai.main import app as openai_app
 from apps.audio.main import app as audio_app
 from apps.images.main import app as images_app
 from apps.rag.main import app as rag_app
-
 from apps.web.main import app as webui_app
 
+
 from config import WEBUI_NAME, ENV, VERSION, CHANGELOG, FRONTEND_BUILD_DIR
+from utils.utils import get_http_authorization_cred, get_current_user
 
 
 class SPAStaticFiles(StaticFiles):
@@ -59,6 +61,21 @@ async def check_url(request: Request, call_next):
     return response
 
 
+@litellm_app.middleware("http")
+async def auth_middleware(request: Request, call_next):
+    auth_header = request.headers.get("Authorization", "")
+
+    if ENV != "dev":
+        try:
+            user = get_current_user(get_http_authorization_cred(auth_header))
+            print(user)
+        except Exception as e:
+            return JSONResponse(status_code=400, content={"detail": str(e)})
+
+    response = await call_next(request)
+    return response
+
+
 app.mount("/api/v1", webui_app)
 app.mount("/litellm/api", litellm_app)
 

+ 11 - 0
backend/utils/utils.py

@@ -58,6 +58,17 @@ def extract_token_from_auth_header(auth_header: str):
     return auth_header[len("Bearer ") :]
 
 
+def get_http_authorization_cred(auth_header: str):
+    try:
+        scheme, credentials = auth_header.split(" ")
+        return {
+            "scheme": scheme,
+            "credentials": credentials,
+        }
+    except:
+        raise ValueError(ERROR_MESSAGES.INVALID_TOKEN)
+
+
 def get_current_user(
     auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
 ):