Просмотр исходного кода

feat: litellm model filter support

Timothy J. Baek 1 год назад
Родитель
Commit
93c90dc186
2 измененных файлов с 65 добавлено и 2 удалено
  1. 64 1
      backend/apps/litellm/main.py
  2. 1 1
      backend/config.py

+ 64 - 1
backend/apps/litellm/main.py

@@ -1,11 +1,23 @@
 from litellm.proxy.proxy_server import ProxyConfig, initialize
 from litellm.proxy.proxy_server import app
 
-from fastapi import FastAPI, Request, Depends, status
+from fastapi import FastAPI, Request, Depends, status, Response
 from fastapi.responses import JSONResponse
+
+from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
+from starlette.responses import StreamingResponse
+import json
+
 from utils.utils import get_http_authorization_cred, get_current_user
 from config import ENV
 
+
+from config import (
+    MODEL_FILTER_ENABLED,
+    MODEL_FILTER_LIST,
+)
+
+
 proxy_config = ProxyConfig()
 
 
@@ -26,16 +38,67 @@ async def on_startup():
     await startup()
 
 
+app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
+app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
+
+
 @app.middleware("http")
 async def auth_middleware(request: Request, call_next):
     auth_header = request.headers.get("Authorization", "")
+    request.state.user = None
 
     if ENV != "dev":
         try:
             user = get_current_user(get_http_authorization_cred(auth_header))
             print(user)
+            request.state.user = user
         except Exception as e:
             return JSONResponse(status_code=400, content={"detail": str(e)})
 
     response = await call_next(request)
     return response
+
+
+class ModifyModelsResponseMiddleware(BaseHTTPMiddleware):
+    async def dispatch(
+        self, request: Request, call_next: RequestResponseEndpoint
+    ) -> Response:
+
+        response = await call_next(request)
+        user = request.state.user
+
+        # Check if the request is for the `/models` route
+        if "/models" in request.url.path:
+            # Ensure the response is a StreamingResponse
+            if isinstance(response, StreamingResponse):
+                # Read the content of the streaming response
+                body = b""
+                async for chunk in response.body_iterator:
+                    body += chunk
+
+                # Modify the content as needed
+                data = json.loads(body.decode("utf-8"))
+
+                print(data)
+
+                if app.state.MODEL_FILTER_ENABLED:
+                    if user and user.role == "user":
+                        data["data"] = list(
+                            filter(
+                                lambda model: model["id"]
+                                in app.state.MODEL_FILTER_LIST,
+                                data["data"],
+                            )
+                        )
+
+                # Example modification: Add a new key-value pair
+                data["modified"] = True
+
+                # Return a new JSON response with the modified content
+                return JSONResponse(content=data)
+
+        return response
+
+
+# Add the middleware to the app
+app.add_middleware(ModifyModelsResponseMiddleware)

+ 1 - 1
backend/config.py

@@ -298,7 +298,7 @@ USER_PERMISSIONS_CHAT_DELETION = (
 USER_PERMISSIONS = {"chat": {"deletion": USER_PERMISSIONS_CHAT_DELETION}}
 
 
-MODEL_FILTER_ENABLED = os.environ.get("MODEL_FILTER_ENABLED", False)
+MODEL_FILTER_ENABLED = os.environ.get("MODEL_FILTER_ENABLED", "False").lower() == "true"
 MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "")
 MODEL_FILTER_LIST = [model.strip() for model in MODEL_FILTER_LIST.split(";")]