main.py

from litellm.proxy.proxy_server import ProxyConfig, initialize
from litellm.proxy.proxy_server import app

from fastapi import FastAPI, Request, Depends, status, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.responses import StreamingResponse

import json

from utils.utils import get_http_authorization_cred, get_current_user
from config import ENV
from config import (
    MODEL_FILTER_ENABLED,
    MODEL_FILTER_LIST,
)

proxy_config = ProxyConfig()


async def config():
    router, model_list, general_settings = await proxy_config.load_config(
        router=None, config_file_path="./data/litellm/config.yaml"
    )

    await initialize(config="./data/litellm/config.yaml", telemetry=False)


async def startup():
    await config()


@app.on_event("startup")
async def on_startup():
    await startup()
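
# A sketch of what ./data/litellm/config.yaml might contain (an assumption; the file is
# not shown here, but LiteLLM's proxy config format looks roughly like this):
#
#   model_list:
#     - model_name: gpt-3.5-turbo
#       litellm_params:
#         model: openai/gpt-3.5-turbo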

app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
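
# Illustrative values for the config.py settings above (assumptions; the actual values
# are defined in config.py, which is not part of this file):
#   MODEL_FILTER_ENABLED = True
#   MODEL_FILTER_LIST = ["gpt-3.5-turbo", "gpt-4"]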


@app.middleware("http")
async def auth_middleware(request: Request, call_next):
    auth_header = request.headers.get("Authorization", "")

    # Resolve the calling user from the Authorization header; auth is skipped in dev.
    request.state.user = None

    if ENV != "dev":
        try:
            user = get_current_user(get_http_authorization_cred(auth_header))
            print(user)
            request.state.user = user
        except Exception as e:
            return JSONResponse(status_code=400, content={"detail": str(e)})

    response = await call_next(request)
    return response
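
# Example request against the proxy (illustrative; the exact credential format is an
# assumption, since get_http_authorization_cred and get_current_user live in
# utils.utils and are not shown here, and the host/port are placeholders):
#
#   curl -H "Authorization: Bearer <token>" http://localhost:8000/models
#
# When ENV == "dev", authentication is skipped and request.state.user stays None.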


class ModifyModelsResponseMiddleware(BaseHTTPMiddleware):
    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        response = await call_next(request)
        user = request.state.user

        if "/models" in request.url.path:
            if isinstance(response, StreamingResponse):
                # Read the content of the streaming response
                body = b""
                async for chunk in response.body_iterator:
                    body += chunk

                data = json.loads(body.decode("utf-8"))

                if app.state.MODEL_FILTER_ENABLED:
                    # Non-admin users only see models listed in MODEL_FILTER_LIST.
                    if user and user.role == "user":
                        data["data"] = list(
                            filter(
                                lambda model: model["id"]
                                in app.state.MODEL_FILTER_LIST,
                                data["data"],
                            )
                        )

                # Modified Flag
                data["modified"] = True
                return JSONResponse(content=data)

        return response
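
# The filter above assumes an OpenAI-compatible model list from the LiteLLM proxy,
# roughly (illustrative payload, not taken from this repository):
#
#   {"object": "list", "data": [{"id": "gpt-3.5-turbo", "object": "model"}, ...]}
#
# Only entries whose "id" appears in MODEL_FILTER_LIST are returned to users with
# role "user".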


app.add_middleware(ModifyModelsResponseMiddleware)
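

# Minimal local entry point (a sketch, not part of the original module; the host and
# port are illustrative assumptions, and any ASGI server would work):
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)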