# main.py
import logging
from litellm.proxy.proxy_server import ProxyConfig, initialize
from litellm.proxy.proxy_server import app
from fastapi import FastAPI, Request, Depends, status, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.responses import StreamingResponse
import json
from utils.utils import get_http_authorization_cred, get_current_user
from config import SRC_LOG_LEVELS, ENV

# Module logger; verbosity comes from the app-wide per-subsystem level map.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["LITELLM"])

# Model-filter settings, imported after the logger is configured.
from config import (
    MODEL_FILTER_ENABLED,
    MODEL_FILTER_LIST,
)

# Single ProxyConfig instance used to load the LiteLLM proxy config at startup.
proxy_config = ProxyConfig()
  18. async def config():
  19. router, model_list, general_settings = await proxy_config.load_config(
  20. router=None, config_file_path="./data/litellm/config.yaml"
  21. )
  22. await initialize(config="./data/litellm/config.yaml", telemetry=False)
  23. async def startup():
  24. await config()
  25. @app.on_event("startup")
  26. async def on_startup():
  27. await startup()
# Expose the model-filter settings on app.state so middleware (see
# ModifyModelsResponseMiddleware below) can read them per request.
app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
  30. @app.middleware("http")
  31. async def auth_middleware(request: Request, call_next):
  32. auth_header = request.headers.get("Authorization", "")
  33. request.state.user = None
  34. try:
  35. user = get_current_user(get_http_authorization_cred(auth_header))
  36. log.debug(f"user: {user}")
  37. request.state.user = user
  38. except Exception as e:
  39. return JSONResponse(status_code=400, content={"detail": str(e)})
  40. response = await call_next(request)
  41. return response
  42. class ModifyModelsResponseMiddleware(BaseHTTPMiddleware):
  43. async def dispatch(
  44. self, request: Request, call_next: RequestResponseEndpoint
  45. ) -> Response:
  46. response = await call_next(request)
  47. user = request.state.user
  48. if "/models" in request.url.path:
  49. if isinstance(response, StreamingResponse):
  50. # Read the content of the streaming response
  51. body = b""
  52. async for chunk in response.body_iterator:
  53. body += chunk
  54. data = json.loads(body.decode("utf-8"))
  55. if app.state.MODEL_FILTER_ENABLED:
  56. if user and user.role == "user":
  57. data["data"] = list(
  58. filter(
  59. lambda model: model["id"]
  60. in app.state.MODEL_FILTER_LIST,
  61. data["data"],
  62. )
  63. )
  64. # Modified Flag
  65. data["modified"] = True
  66. return JSONResponse(content=data)
  67. return response
  68. app.add_middleware(ModifyModelsResponseMiddleware)