123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249 |
- from contextlib import asynccontextmanager
- from dataclasses import asdict, dataclass
- from enum import Enum
- import re
- from typing import (
- TYPE_CHECKING,
- Any,
- AsyncGenerator,
- Dict,
- MutableMapping,
- Optional,
- cast,
- )
- import uuid
- from asgiref.typing import (
- ASGI3Application,
- ASGIReceiveCallable,
- ASGIReceiveEvent,
- ASGISendCallable,
- ASGISendEvent,
- Scope as ASGIScope,
- )
- from loguru import logger
- from starlette.requests import Request
- from open_webui.env import AUDIT_LOG_LEVEL, MAX_BODY_LOG_SIZE
- from open_webui.utils.auth import get_current_user, get_http_authorization_cred
- from open_webui.models.users import UserModel
- if TYPE_CHECKING:
- from loguru import Logger
@dataclass(frozen=True)
class AuditLogEntry:
    """
    Immutable record describing one audited HTTP request.

    Fields are grouped by the audit level they are associated with in this
    file: the first group is always required or metadata-related, the later
    groups carry request and response payload information.
    """

    # --- `Metadata` audit level ---------------------------------------
    id: str
    user: dict[str, Any]
    audit_level: str
    verb: str
    request_uri: str
    user_agent: Optional[str] = None
    source_ip: Optional[str] = None

    # --- `Request` audit level ----------------------------------------
    request_object: Any = None

    # --- `Request Response` audit level -------------------------------
    response_object: Any = None
    response_status_code: Optional[int] = None
class AuditLevel(str, Enum):
    """
    Verbosity levels for audit logging.

    Members are also `str` instances, so each one compares equal to its
    plain string value (e.g. `AuditLevel.NONE == "NONE"`).
    """

    NONE = "NONE"
    METADATA = "METADATA"
    REQUEST = "REQUEST"
    REQUEST_RESPONSE = "REQUEST_RESPONSE"
class AuditLogger:
    """
    Small wrapper around a Loguru logger dedicated to audit records.

    The wrapped logger is bound with `auditable=True` so that every record
    emitted through this class carries that marker.

    Parameters:
    logger (Logger): An instance of Loguru's logger.
    """

    def __init__(self, logger: "Logger"):
        # Keep only the bound logger; all writes go through it.
        self.logger = logger.bind(auditable=True)

    def write(
        self,
        audit_entry: AuditLogEntry,
        *,
        log_level: str = "INFO",
        extra: Optional[dict] = None,
    ):
        """
        Emit one audit entry.

        The dataclass is flattened with `asdict` and its fields are passed
        to the logger as keyword arguments alongside an empty message. A
        truthy `extra` mapping is attached under the `"extra"` key; a falsy
        one (None or empty) is omitted entirely.
        """
        optional_part = {"extra": extra} if extra else {}
        payload = {**asdict(audit_entry), **optional_part}
        self.logger.log(log_level, "", **payload)
class AuditContext:
    """
    Per-request accumulator for the HTTP request and response bodies.

    Each body buffer stops growing once it holds `max_body_size` bytes, so
    the amount of captured data per request is bounded.

    Attributes:
    request_body (bytearray): Accumulated request payload.
    response_body (bytearray): Accumulated response payload.
    max_body_size (int): Maximum number of bytes to capture per buffer.
    metadata (Dict[str, Any]): Free-form storage for additional audit metadata.
    """

    def __init__(self, max_body_size: int = MAX_BODY_LOG_SIZE):
        self.max_body_size = max_body_size
        self.metadata: Dict[str, Any] = {}
        self.request_body = bytearray()
        self.response_body = bytearray()

    def _append_capped(self, buffer: bytearray, chunk: bytes) -> None:
        """Append as much of `chunk` as still fits under `max_body_size`."""
        room = self.max_body_size - len(buffer)
        if room > 0:
            buffer.extend(chunk[:room])

    def add_request_chunk(self, chunk: bytes):
        """Record a piece of the request body, truncating at the size cap."""
        self._append_capped(self.request_body, chunk)

    def add_response_chunk(self, chunk: bytes):
        """Record a piece of the response body, truncating at the size cap."""
        self._append_capped(self.response_body, chunk)
class AuditLoggingMiddleware:
    """
    ASGI middleware that intercepts HTTP requests and responses to perform audit logging.

    Depending on the configured audit level it captures request and/or
    response bodies, then writes one structured audit entry (user, HTTP verb,
    URI, client address, user agent, bodies, status code) once the wrapped
    application has finished handling the request.
    """

    # HTTP methods that are subject to auditing; everything else is passed through.
    AUDITED_METHODS = {"PUT", "PATCH", "DELETE", "POST"}

    def __init__(
        self,
        app: ASGI3Application,
        *,
        excluded_paths: Optional[list[str]] = None,
        max_body_size: int = MAX_BODY_LOG_SIZE,
        audit_level: AuditLevel = AuditLevel.NONE,
    ) -> None:
        """
        Parameters:
        app: The wrapped ASGI application.
        excluded_paths: Resource names (regex alternatives) under `/api/` or
            `/api/v1/` that must not be audited.
        max_body_size: Maximum number of body bytes captured per request and
            per response.
        audit_level: Controls whether request and response bodies are captured.
        """
        self.app = app
        self.audit_logger = AuditLogger(logger)
        self.excluded_paths = excluded_paths or []
        self.max_body_size = max_body_size
        self.audit_level = audit_level

    async def __call__(
        self,
        scope: ASGIScope,
        receive: ASGIReceiveCallable,
        send: ASGISendCallable,
    ) -> None:
        """
        ASGI entry point.

        Non-HTTP scopes and requests that should not be audited are forwarded
        untouched. Otherwise `receive` and `send` are wrapped so that bodies
        can be captured while the wrapped application runs.
        """
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        request = Request(scope=cast(MutableMapping, scope))

        if self._should_skip_auditing(request):
            return await self.app(scope, receive, send)

        async with self._audit_context(request) as context:

            async def send_wrapper(message: ASGISendEvent) -> None:
                # Response data is only recorded at the most verbose level.
                if self.audit_level == AuditLevel.REQUEST_RESPONSE:
                    await self._capture_response(message, context)
                await send(message)

            async def receive_wrapper() -> ASGIReceiveEvent:
                message = await receive()
                if self.audit_level in (
                    AuditLevel.REQUEST,
                    AuditLevel.REQUEST_RESPONSE,
                ):
                    await self._capture_request(message, context)
                return message

            await self.app(scope, receive_wrapper, send_wrapper)

    @asynccontextmanager
    async def _audit_context(
        self, request: Request
    ) -> AsyncGenerator[AuditContext, None]:
        """
        async context manager that ensures that an audit log entry is recorded after the request is processed.

        The entry is written from a `finally` clause, so it is also attempted
        when the wrapped application raises.
        """
        # Honour the body-size limit this middleware was configured with
        # instead of silently falling back to the AuditContext default.
        context = AuditContext(max_body_size=self.max_body_size)
        try:
            yield context
        finally:
            await self._log_audit_entry(request, context)

    async def _get_authenticated_user(self, request: Request) -> UserModel:
        """
        Resolve the user from the request's Authorization header.

        Raises:
        ValueError: If the request carries no Authorization header.
        """
        auth_header = request.headers.get("Authorization")
        # Explicit check rather than `assert`, which is stripped under `-O`.
        if not auth_header:
            raise ValueError("Missing Authorization header")
        return get_current_user(
            request, None, get_http_authorization_cred(auth_header)
        )

    def _should_skip_auditing(self, request: Request) -> bool:
        """
        Return True when the request must not be audited.

        A request is skipped when its method is not in `AUDITED_METHODS`,
        when the global `AUDIT_LOG_LEVEL` is "NONE", when it has no
        authorization header, or when its path matches an excluded resource.
        """
        # NOTE(review): this consults the global AUDIT_LOG_LEVEL rather than
        # self.audit_level; presumably both come from the same setting - confirm.
        if (
            request.method not in self.AUDITED_METHODS
            or AUDIT_LOG_LEVEL == "NONE"
            or not request.headers.get("authorization")
        ):
            return True

        # Drop empty entries: an empty alternative would turn the pattern into
        # `^/api(?:/v1)?/()\b`, which matches every `/api/<word>` path and
        # would therefore disable auditing entirely.
        excluded = [path for path in self.excluded_paths if path]
        if not excluded:
            return False

        # match either /api/<resource>/...(for the endpoint /api/chat case) or /api/v1/<resource>/...
        # Entries are inserted verbatim, i.e. they are treated as regex fragments.
        pattern = re.compile(r"^/api(?:/v1)?/(" + "|".join(excluded) + r")\b")
        return pattern.match(request.url.path) is not None

    async def _capture_request(self, message: ASGIReceiveEvent, context: AuditContext):
        """Append the body of an `http.request` message to the audit context."""
        if message["type"] == "http.request":
            body = message.get("body", b"")
            context.add_request_chunk(body)

    async def _capture_response(self, message: ASGISendEvent, context: AuditContext):
        """Record the response status code and body chunks in the audit context."""
        if message["type"] == "http.response.start":
            context.metadata["response_status_code"] = message["status"]
        elif message["type"] == "http.response.body":
            body = message.get("body", b"")
            context.add_response_chunk(body)

    async def _log_audit_entry(self, request: Request, context: AuditContext):
        """
        Build and write the audit entry for a finished request.

        This is deliberately best-effort: any failure is logged as an error
        and swallowed so that auditing never breaks the request itself.
        """
        try:
            user = await self._get_authenticated_user(request)

            entry = AuditLogEntry(
                id=str(uuid.uuid4()),
                user=user.model_dump(include={"id", "name", "email", "role"}),
                audit_level=self.audit_level.value,
                verb=request.method,
                request_uri=str(request.url),
                response_status_code=context.metadata.get("response_status_code", None),
                source_ip=request.client.host if request.client else None,
                user_agent=request.headers.get("user-agent"),
                # Bodies are captured as raw bytes; undecodable bytes are
                # replaced rather than raising.
                request_object=context.request_body.decode("utf-8", errors="replace"),
                response_object=context.response_body.decode("utf-8", errors="replace"),
            )

            self.audit_logger.write(entry)
        except Exception as e:
            logger.error(f"Failed to log audit entry: {str(e)}")
|