Browse Source

feat: secure litellm api

Timothy J. Baek 1 year ago
parent
commit
b5bd07a06a
2 changed files with 30 additions and 2 deletions
  1. 19 2
      backend/main.py
  2. 11 0
      backend/utils/utils.py

+ 19 - 2
backend/main.py

@@ -4,9 +4,10 @@ import markdown
 import time
 import time
 
 
 
 
-from fastapi import FastAPI, Request
+from fastapi import FastAPI, Request, Depends
 from fastapi.staticfiles import StaticFiles
 from fastapi.staticfiles import StaticFiles
 from fastapi import HTTPException
 from fastapi import HTTPException
+from fastapi.responses import JSONResponse
 from fastapi.middleware.wsgi import WSGIMiddleware
 from fastapi.middleware.wsgi import WSGIMiddleware
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.cors import CORSMiddleware
 from starlette.exceptions import HTTPException as StarletteHTTPException
 from starlette.exceptions import HTTPException as StarletteHTTPException
@@ -19,10 +20,11 @@ from apps.openai.main import app as openai_app
 from apps.audio.main import app as audio_app
 from apps.audio.main import app as audio_app
 from apps.images.main import app as images_app
 from apps.images.main import app as images_app
 from apps.rag.main import app as rag_app
 from apps.rag.main import app as rag_app
-
 from apps.web.main import app as webui_app
 from apps.web.main import app as webui_app
 
 
+
 from config import WEBUI_NAME, ENV, VERSION, CHANGELOG, FRONTEND_BUILD_DIR
 from config import WEBUI_NAME, ENV, VERSION, CHANGELOG, FRONTEND_BUILD_DIR
+from utils.utils import get_http_authorization_cred, get_current_user
 
 
 
 
 class SPAStaticFiles(StaticFiles):
 class SPAStaticFiles(StaticFiles):
@@ -59,6 +61,21 @@ async def check_url(request: Request, call_next):
     return response
     return response
 
 
 
 
+@litellm_app.middleware("http")
+async def auth_middleware(request: Request, call_next):
+    auth_header = request.headers.get("Authorization", "")
+
+    if ENV != "dev":
+        try:
+            user = get_current_user(get_http_authorization_cred(auth_header))
+            print(user)
+        except Exception as e:
+            return JSONResponse(status_code=400, content={"detail": str(e)})
+
+    response = await call_next(request)
+    return response
+
+
 app.mount("/api/v1", webui_app)
 app.mount("/api/v1", webui_app)
 app.mount("/litellm/api", litellm_app)
 app.mount("/litellm/api", litellm_app)
 
 

+ 11 - 0
backend/utils/utils.py

@@ -58,6 +58,17 @@ def extract_token_from_auth_header(auth_header: str):
     return auth_header[len("Bearer ") :]
     return auth_header[len("Bearer ") :]
 
 
 
 
+def get_http_authorization_cred(auth_header: str):
+    try:
+        scheme, credentials = auth_header.split(" ")
+        return {
+            "scheme": scheme,
+            "credentials": credentials,
+        }
+    except:
+        raise ValueError(ERROR_MESSAGES.INVALID_TOKEN)
+
+
 def get_current_user(
 def get_current_user(
     auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
     auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
 ):
 ):