# main.py — LiteLLM proxy wrapper: loads the proxy config at startup and
# installs auth + model-filter middleware on the shared FastAPI app.
  1. from litellm.proxy.proxy_server import ProxyConfig, initialize
  2. from litellm.proxy.proxy_server import app
  3. from fastapi import FastAPI, Request, Depends, status, Response
  4. from fastapi.responses import JSONResponse
  5. from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
  6. from starlette.responses import StreamingResponse
  7. import json
  8. from utils.utils import get_http_authorization_cred, get_current_user
  9. from config import ENV
  10. from config import (
  11. MODEL_FILTER_ENABLED,
  12. MODEL_FILTER_LIST,
  13. )
  14. proxy_config = ProxyConfig()
  15. async def config():
  16. router, model_list, general_settings = await proxy_config.load_config(
  17. router=None, config_file_path="./data/litellm/config.yaml"
  18. )
  19. await initialize(config="./data/litellm/config.yaml", telemetry=False)
  20. async def startup():
  21. await config()
  22. @app.on_event("startup")
  23. async def on_startup():
  24. await startup()
  25. app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
  26. app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
  27. @app.middleware("http")
  28. async def auth_middleware(request: Request, call_next):
  29. auth_header = request.headers.get("Authorization", "")
  30. request.state.user = None
  31. try:
  32. user = get_current_user(get_http_authorization_cred(auth_header))
  33. print(user)
  34. request.state.user = user
  35. except Exception as e:
  36. return JSONResponse(status_code=400, content={"detail": str(e)})
  37. response = await call_next(request)
  38. return response
  39. class ModifyModelsResponseMiddleware(BaseHTTPMiddleware):
  40. async def dispatch(
  41. self, request: Request, call_next: RequestResponseEndpoint
  42. ) -> Response:
  43. response = await call_next(request)
  44. user = request.state.user
  45. if "/models" in request.url.path:
  46. if isinstance(response, StreamingResponse):
  47. # Read the content of the streaming response
  48. body = b""
  49. async for chunk in response.body_iterator:
  50. body += chunk
  51. data = json.loads(body.decode("utf-8"))
  52. if app.state.MODEL_FILTER_ENABLED:
  53. if user and user.role == "user":
  54. data["data"] = list(
  55. filter(
  56. lambda model: model["id"]
  57. in app.state.MODEL_FILTER_LIST,
  58. data["data"],
  59. )
  60. )
  61. # Modified Flag
  62. data["modified"] = True
  63. return JSONResponse(content=data)
  64. return response
  65. app.add_middleware(ModifyModelsResponseMiddleware)