audit.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. from contextlib import asynccontextmanager
  2. from dataclasses import asdict, dataclass
  3. from enum import Enum
  4. import re
  5. from typing import (
  6. TYPE_CHECKING,
  7. Any,
  8. AsyncGenerator,
  9. Dict,
  10. MutableMapping,
  11. Optional,
  12. cast,
  13. )
  14. import uuid
  15. from asgiref.typing import (
  16. ASGI3Application,
  17. ASGIReceiveCallable,
  18. ASGIReceiveEvent,
  19. ASGISendCallable,
  20. ASGISendEvent,
  21. Scope as ASGIScope,
  22. )
  23. from loguru import logger
  24. from starlette.requests import Request
  25. from open_webui.env import AUDIT_LOG_LEVEL, MAX_BODY_LOG_SIZE
  26. from open_webui.utils.auth import get_current_user, get_http_authorization_cred
  27. from open_webui.models.users import UserModel
  28. if TYPE_CHECKING:
  29. from loguru import Logger
  30. @dataclass(frozen=True)
  31. class AuditLogEntry:
  32. # `Metadata` audit level properties
  33. id: str
  34. user: dict[str, Any]
  35. audit_level: str
  36. verb: str
  37. request_uri: str
  38. user_agent: Optional[str] = None
  39. source_ip: Optional[str] = None
  40. # `Request` audit level properties
  41. request_object: Any = None
  42. # `Request Response` level
  43. response_object: Any = None
  44. response_status_code: Optional[int] = None
  45. class AuditLevel(str, Enum):
  46. NONE = "NONE"
  47. METADATA = "METADATA"
  48. REQUEST = "REQUEST"
  49. REQUEST_RESPONSE = "REQUEST_RESPONSE"
  50. class AuditLogger:
  51. """
  52. A helper class that encapsulates audit logging functionality. It uses Loguru’s logger with an auditable binding to ensure that audit log entries are filtered correctly.
  53. Parameters:
  54. logger (Logger): An instance of Loguru’s logger.
  55. """
  56. def __init__(self, logger: "Logger"):
  57. self.logger = logger.bind(auditable=True)
  58. def write(
  59. self,
  60. audit_entry: AuditLogEntry,
  61. *,
  62. log_level: str = "INFO",
  63. extra: Optional[dict] = None,
  64. ):
  65. entry = asdict(audit_entry)
  66. if extra:
  67. entry["extra"] = extra
  68. self.logger.log(
  69. log_level,
  70. "",
  71. **entry,
  72. )
  73. class AuditContext:
  74. """
  75. Captures and aggregates the HTTP request and response bodies during the processing of a request. It ensures that only a configurable maximum amount of data is stored to prevent excessive memory usage.
  76. Attributes:
  77. request_body (bytearray): Accumulated request payload.
  78. response_body (bytearray): Accumulated response payload.
  79. max_body_size (int): Maximum number of bytes to capture.
  80. metadata (Dict[str, Any]): A dictionary to store additional audit metadata (user, http verb, user agent, etc.).
  81. """
  82. def __init__(self, max_body_size: int = MAX_BODY_LOG_SIZE):
  83. self.request_body = bytearray()
  84. self.response_body = bytearray()
  85. self.max_body_size = max_body_size
  86. self.metadata: Dict[str, Any] = {}
  87. def add_request_chunk(self, chunk: bytes):
  88. if len(self.request_body) < self.max_body_size:
  89. self.request_body.extend(
  90. chunk[: self.max_body_size - len(self.request_body)]
  91. )
  92. def add_response_chunk(self, chunk: bytes):
  93. if len(self.response_body) < self.max_body_size:
  94. self.response_body.extend(
  95. chunk[: self.max_body_size - len(self.response_body)]
  96. )
  97. class AuditLoggingMiddleware:
  98. """
  99. ASGI middleware that intercepts HTTP requests and responses to perform audit logging. It captures request/response bodies (depending on audit level), headers, HTTP methods, and user information, then logs a structured audit entry at the end of the request cycle.
  100. """
  101. AUDITED_METHODS = {"PUT", "PATCH", "DELETE", "POST"}
  102. def __init__(
  103. self,
  104. app: ASGI3Application,
  105. *,
  106. excluded_paths: Optional[list[str]] = None,
  107. max_body_size: int = MAX_BODY_LOG_SIZE,
  108. audit_level: AuditLevel = AuditLevel.NONE,
  109. ) -> None:
  110. self.app = app
  111. self.audit_logger = AuditLogger(logger)
  112. self.excluded_paths = excluded_paths or []
  113. self.max_body_size = max_body_size
  114. self.audit_level = audit_level
  115. async def __call__(
  116. self,
  117. scope: ASGIScope,
  118. receive: ASGIReceiveCallable,
  119. send: ASGISendCallable,
  120. ) -> None:
  121. if scope["type"] != "http":
  122. return await self.app(scope, receive, send)
  123. request = Request(scope=cast(MutableMapping, scope))
  124. if self._should_skip_auditing(request):
  125. return await self.app(scope, receive, send)
  126. async with self._audit_context(request) as context:
  127. async def send_wrapper(message: ASGISendEvent) -> None:
  128. if self.audit_level == AuditLevel.REQUEST_RESPONSE:
  129. await self._capture_response(message, context)
  130. await send(message)
  131. original_receive = receive
  132. async def receive_wrapper() -> ASGIReceiveEvent:
  133. nonlocal original_receive
  134. message = await original_receive()
  135. if self.audit_level in (
  136. AuditLevel.REQUEST,
  137. AuditLevel.REQUEST_RESPONSE,
  138. ):
  139. await self._capture_request(message, context)
  140. return message
  141. await self.app(scope, receive_wrapper, send_wrapper)
  142. @asynccontextmanager
  143. async def _audit_context(
  144. self, request: Request
  145. ) -> AsyncGenerator[AuditContext, None]:
  146. """
  147. async context manager that ensures that an audit log entry is recorded after the request is processed.
  148. """
  149. context = AuditContext()
  150. try:
  151. yield context
  152. finally:
  153. await self._log_audit_entry(request, context)
  154. async def _get_authenticated_user(self, request: Request) -> UserModel:
  155. auth_header = request.headers.get("Authorization")
  156. assert auth_header
  157. user = get_current_user(request, None, get_http_authorization_cred(auth_header))
  158. return user
  159. def _should_skip_auditing(self, request: Request) -> bool:
  160. if (
  161. request.method not in {"POST", "PUT", "PATCH", "DELETE"}
  162. or AUDIT_LOG_LEVEL == "NONE"
  163. or not request.headers.get("authorization")
  164. ):
  165. return True
  166. # match either /api/<resource>/...(for the endpoint /api/chat case) or /api/v1/<resource>/...
  167. pattern = re.compile(
  168. r"^/api(?:/v1)?/(" + "|".join(self.excluded_paths) + r")\b"
  169. )
  170. if pattern.match(request.url.path):
  171. return True
  172. return False
  173. async def _capture_request(self, message: ASGIReceiveEvent, context: AuditContext):
  174. if message["type"] == "http.request":
  175. body = message.get("body", b"")
  176. context.add_request_chunk(body)
  177. async def _capture_response(self, message: ASGISendEvent, context: AuditContext):
  178. if message["type"] == "http.response.start":
  179. context.metadata["response_status_code"] = message["status"]
  180. elif message["type"] == "http.response.body":
  181. body = message.get("body", b"")
  182. context.add_response_chunk(body)
  183. async def _log_audit_entry(self, request: Request, context: AuditContext):
  184. try:
  185. user = await self._get_authenticated_user(request)
  186. entry = AuditLogEntry(
  187. id=str(uuid.uuid4()),
  188. user=user.model_dump(include={"id", "name", "email", "role"}),
  189. audit_level=self.audit_level.value,
  190. verb=request.method,
  191. request_uri=str(request.url),
  192. response_status_code=context.metadata.get("response_status_code", None),
  193. source_ip=request.client.host if request.client else None,
  194. user_agent=request.headers.get("user-agent"),
  195. request_object=context.request_body.decode("utf-8", errors="replace"),
  196. response_object=context.response_body.decode("utf-8", errors="replace"),
  197. )
  198. self.audit_logger.write(entry)
  199. except Exception as e:
  200. logger.error(f"Failed to log audit entry: {str(e)}")