# main.py
  1. from litellm.proxy.proxy_server import ProxyConfig, initialize
  2. from litellm.proxy.proxy_server import app
  3. from fastapi import FastAPI, Request, Depends, status, Response
  4. from fastapi.responses import JSONResponse
  5. from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
  6. from starlette.responses import StreamingResponse
  7. import json
  8. from utils.utils import get_http_authorization_cred, get_current_user
  9. from config import ENV
  10. from config import (
  11. MODEL_FILTER_ENABLED,
  12. MODEL_FILTER_LIST,
  13. )
  14. proxy_config = ProxyConfig()
  15. async def config():
  16. router, model_list, general_settings = await proxy_config.load_config(
  17. router=None, config_file_path="./data/litellm/config.yaml"
  18. )
  19. await initialize(config="./data/litellm/config.yaml", telemetry=False)
  20. async def startup():
  21. await config()
  22. @app.on_event("startup")
  23. async def on_startup():
  24. await startup()
  25. app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
  26. app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
  27. @app.middleware("http")
  28. async def auth_middleware(request: Request, call_next):
  29. auth_header = request.headers.get("Authorization", "")
  30. request.state.user = None
  31. if ENV != "dev":
  32. try:
  33. user = get_current_user(get_http_authorization_cred(auth_header))
  34. print(user)
  35. request.state.user = user
  36. except Exception as e:
  37. return JSONResponse(status_code=400, content={"detail": str(e)})
  38. response = await call_next(request)
  39. return response
  40. class ModifyModelsResponseMiddleware(BaseHTTPMiddleware):
  41. async def dispatch(
  42. self, request: Request, call_next: RequestResponseEndpoint
  43. ) -> Response:
  44. response = await call_next(request)
  45. user = request.state.user
  46. # Check if the request is for the `/models` route
  47. if "/models" in request.url.path:
  48. # Ensure the response is a StreamingResponse
  49. if isinstance(response, StreamingResponse):
  50. # Read the content of the streaming response
  51. body = b""
  52. async for chunk in response.body_iterator:
  53. body += chunk
  54. # Modify the content as needed
  55. data = json.loads(body.decode("utf-8"))
  56. print(data)
  57. if app.state.MODEL_FILTER_ENABLED:
  58. if user and user.role == "user":
  59. data["data"] = list(
  60. filter(
  61. lambda model: model["id"]
  62. in app.state.MODEL_FILTER_LIST,
  63. data["data"],
  64. )
  65. )
  66. # Example modification: Add a new key-value pair
  67. data["modified"] = True
  68. # Return a new JSON response with the modified content
  69. return JSONResponse(content=data)
  70. return response
  71. # Add the middleware to the app
  72. app.add_middleware(ModifyModelsResponseMiddleware)