Browse Source

Merge pull request #10860 from open-webui/audit-log-dev

feat: add audit logging feature
Timothy Jaeryang Baek 2 months ago
parent
commit
1ae702b8a6

+ 22 - 0
backend/open_webui/env.py

@@ -419,3 +419,25 @@ OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true"
 
 if OFFLINE_MODE:
     os.environ["HF_HUB_OFFLINE"] = "1"
+
+####################################
+# AUDIT LOGGING
+####################################
+ENABLE_AUDIT_LOGS = os.getenv("ENABLE_AUDIT_LOGS", "false").lower() == "true"
+# Where to store log file
+AUDIT_LOGS_FILE_PATH = f"{DATA_DIR}/audit.log"
+# Maximum size of a file before rotating into a new log file
+AUDIT_LOG_FILE_ROTATION_SIZE = os.getenv("AUDIT_LOG_FILE_ROTATION_SIZE", "10MB")
+# METADATA | REQUEST | REQUEST_RESPONSE
+AUDIT_LOG_LEVEL = os.getenv("AUDIT_LOG_LEVEL", "REQUEST_RESPONSE").upper()
+try:
+    MAX_BODY_LOG_SIZE = int(os.environ.get("MAX_BODY_LOG_SIZE") or 2048)
+except ValueError:
+    MAX_BODY_LOG_SIZE = 2048
+
+# Comma separated list for urls to exclude from audit
+AUDIT_EXCLUDED_PATHS = os.getenv("AUDIT_EXCLUDED_PATHS", "/chats,/chat,/folders").split(
+    ","
+)
+AUDIT_EXCLUDED_PATHS = [path.strip() for path in AUDIT_EXCLUDED_PATHS]
+AUDIT_EXCLUDED_PATHS = [path.lstrip("/") for path in AUDIT_EXCLUDED_PATHS]

+ 20 - 0
backend/open_webui/main.py

@@ -45,6 +45,9 @@ from starlette.middleware.sessions import SessionMiddleware
 from starlette.responses import Response, StreamingResponse
 
 
+from open_webui.utils import logger
+from open_webui.utils.audit import AuditLevel, AuditLoggingMiddleware
+from open_webui.utils.logger import start_logger
 from open_webui.socket.main import (
     app as socket_app,
     periodic_usage_pool_cleanup,
@@ -304,8 +307,11 @@ from open_webui.config import (
     reset_config,
 )
 from open_webui.env import (
+    AUDIT_EXCLUDED_PATHS,
+    AUDIT_LOG_LEVEL,
     CHANGELOG,
     GLOBAL_LOG_LEVEL,
+    MAX_BODY_LOG_SIZE,
     SAFE_MODE,
     SRC_LOG_LEVELS,
     VERSION,
@@ -390,6 +396,7 @@ https://github.com/open-webui/open-webui
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
+    start_logger()
     if RESET_CONFIG_ON_START:
         reset_config()
 
@@ -891,6 +898,19 @@ app.include_router(
 app.include_router(utils.router, prefix="/api/v1/utils", tags=["utils"])
 
 
+try:
+    audit_level = AuditLevel(AUDIT_LOG_LEVEL)
+except ValueError as e:
+    logger.error(f"Invalid audit level: {AUDIT_LOG_LEVEL}. Error: {e}")
+    audit_level = AuditLevel.NONE
+
+if audit_level != AuditLevel.NONE:
+    app.add_middleware(
+        AuditLoggingMiddleware,
+        audit_level=audit_level,
+        excluded_paths=AUDIT_EXCLUDED_PATHS,
+        max_body_size=MAX_BODY_LOG_SIZE,
+    )
 ##################################
 #
 # Chat Endpoints

+ 249 - 0
backend/open_webui/utils/audit.py

@@ -0,0 +1,249 @@
+from contextlib import asynccontextmanager
+from dataclasses import asdict, dataclass
+from enum import Enum
+import re
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    AsyncGenerator,
+    Dict,
+    MutableMapping,
+    Optional,
+    cast,
+)
+import uuid
+
+from asgiref.typing import (
+    ASGI3Application,
+    ASGIReceiveCallable,
+    ASGIReceiveEvent,
+    ASGISendCallable,
+    ASGISendEvent,
+    Scope as ASGIScope,
+)
+from loguru import logger
+from starlette.requests import Request
+
+from open_webui.env import AUDIT_LOG_LEVEL, MAX_BODY_LOG_SIZE
+from open_webui.utils.auth import get_current_user, get_http_authorization_cred
+from open_webui.models.users import UserModel
+
+
+if TYPE_CHECKING:
+    from loguru import Logger
+
+
+@dataclass(frozen=True)
+class AuditLogEntry:
+    # `Metadata` audit level properties
+    id: str
+    user: dict[str, Any]
+    audit_level: str
+    verb: str
+    request_uri: str
+    user_agent: Optional[str] = None
+    source_ip: Optional[str] = None
+    # `Request` audit level properties
+    request_object: Any = None
+    # `Request Response` level
+    response_object: Any = None
+    response_status_code: Optional[int] = None
+
+
+class AuditLevel(str, Enum):
+    NONE = "NONE"
+    METADATA = "METADATA"
+    REQUEST = "REQUEST"
+    REQUEST_RESPONSE = "REQUEST_RESPONSE"
+
+
+class AuditLogger:
+    """
+    A helper class that encapsulates audit logging functionality. It uses Loguru’s logger with an auditable binding to ensure that audit log entries are filtered correctly.
+
+    Parameters:
+    logger (Logger): An instance of Loguru’s logger.
+    """
+
+    def __init__(self, logger: "Logger"):
+        self.logger = logger.bind(auditable=True)
+
+    def write(
+        self,
+        audit_entry: AuditLogEntry,
+        *,
+        log_level: str = "INFO",
+        extra: Optional[dict] = None,
+    ):
+
+        entry = asdict(audit_entry)
+
+        if extra:
+            entry["extra"] = extra
+
+        self.logger.log(
+            log_level,
+            "",
+            **entry,
+        )
+
+
+class AuditContext:
+    """
+    Captures and aggregates the HTTP request and response bodies during the processing of a request. It ensures that only a configurable maximum amount of data is stored to prevent excessive memory usage.
+
+    Attributes:
+    request_body (bytearray): Accumulated request payload.
+    response_body (bytearray): Accumulated response payload.
+    max_body_size (int): Maximum number of bytes to capture.
+    metadata (Dict[str, Any]): A dictionary to store additional audit metadata (user, http verb, user agent, etc.).
+    """
+
+    def __init__(self, max_body_size: int = MAX_BODY_LOG_SIZE):
+        self.request_body = bytearray()
+        self.response_body = bytearray()
+        self.max_body_size = max_body_size
+        self.metadata: Dict[str, Any] = {}
+
+    def add_request_chunk(self, chunk: bytes):
+        if len(self.request_body) < self.max_body_size:
+            self.request_body.extend(
+                chunk[: self.max_body_size - len(self.request_body)]
+            )
+
+    def add_response_chunk(self, chunk: bytes):
+        if len(self.response_body) < self.max_body_size:
+            self.response_body.extend(
+                chunk[: self.max_body_size - len(self.response_body)]
+            )
+
+
+class AuditLoggingMiddleware:
+    """
+    ASGI middleware that intercepts HTTP requests and responses to perform audit logging. It captures request/response bodies (depending on audit level), headers, HTTP methods, and user information, then logs a structured audit entry at the end of the request cycle.
+    """
+
+    AUDITED_METHODS = {"PUT", "PATCH", "DELETE", "POST"}
+
+    def __init__(
+        self,
+        app: ASGI3Application,
+        *,
+        excluded_paths: Optional[list[str]] = None,
+        max_body_size: int = MAX_BODY_LOG_SIZE,
+        audit_level: AuditLevel = AuditLevel.NONE,
+    ) -> None:
+        self.app = app
+        self.audit_logger = AuditLogger(logger)
+        self.excluded_paths = excluded_paths or []
+        self.max_body_size = max_body_size
+        self.audit_level = audit_level
+
+    async def __call__(
+        self,
+        scope: ASGIScope,
+        receive: ASGIReceiveCallable,
+        send: ASGISendCallable,
+    ) -> None:
+        if scope["type"] != "http":
+            return await self.app(scope, receive, send)
+
+        request = Request(scope=cast(MutableMapping, scope))
+
+        if self._should_skip_auditing(request):
+            return await self.app(scope, receive, send)
+
+        async with self._audit_context(request) as context:
+
+            async def send_wrapper(message: ASGISendEvent) -> None:
+                if self.audit_level == AuditLevel.REQUEST_RESPONSE:
+                    await self._capture_response(message, context)
+
+                await send(message)
+
+            original_receive = receive
+
+            async def receive_wrapper() -> ASGIReceiveEvent:
+                nonlocal original_receive
+                message = await original_receive()
+
+                if self.audit_level in (
+                    AuditLevel.REQUEST,
+                    AuditLevel.REQUEST_RESPONSE,
+                ):
+                    await self._capture_request(message, context)
+
+                return message
+
+            await self.app(scope, receive_wrapper, send_wrapper)
+
+    @asynccontextmanager
+    async def _audit_context(
+        self, request: Request
+    ) -> AsyncGenerator[AuditContext, None]:
+        """
+        async context manager that ensures that an audit log entry is recorded after the request is processed.
+        """
+        context = AuditContext()
+        try:
+            yield context
+        finally:
+            await self._log_audit_entry(request, context)
+
+    async def _get_authenticated_user(self, request: Request) -> UserModel:
+
+        auth_header = request.headers.get("Authorization")
+        assert auth_header
+        user = get_current_user(request, get_http_authorization_cred(auth_header))
+
+        return user
+
+    def _should_skip_auditing(self, request: Request) -> bool:
+        if (
+            request.method not in {"POST", "PUT", "PATCH", "DELETE"}
+            or AUDIT_LOG_LEVEL == "NONE"
+            or not request.headers.get("authorization")
+        ):
+            return True
+        # match either /api/<resource>/...(for the endpoint /api/chat case) or /api/v1/<resource>/...
+        pattern = re.compile(
+            r"^/api(?:/v1)?/(" + "|".join(self.excluded_paths) + r")\b"
+        )
+        if pattern.match(request.url.path):
+            return True
+
+        return False
+
+    async def _capture_request(self, message: ASGIReceiveEvent, context: AuditContext):
+        if message["type"] == "http.request":
+            body = message.get("body", b"")
+            context.add_request_chunk(body)
+
+    async def _capture_response(self, message: ASGISendEvent, context: AuditContext):
+        if message["type"] == "http.response.start":
+            context.metadata["response_status_code"] = message["status"]
+
+        elif message["type"] == "http.response.body":
+            body = message.get("body", b"")
+            context.add_response_chunk(body)
+
+    async def _log_audit_entry(self, request: Request, context: AuditContext):
+        try:
+            user = await self._get_authenticated_user(request)
+
+            entry = AuditLogEntry(
+                id=str(uuid.uuid4()),
+                user=user.model_dump(include={"id", "name", "email", "role"}),
+                audit_level=self.audit_level.value,
+                verb=request.method,
+                request_uri=str(request.url),
+                response_status_code=context.metadata.get("response_status_code", None),
+                source_ip=request.client.host if request.client else None,
+                user_agent=request.headers.get("user-agent"),
+                request_object=context.request_body.decode("utf-8", errors="replace"),
+                response_object=context.response_body.decode("utf-8", errors="replace"),
+            )
+
+            self.audit_logger.write(entry)
+        except Exception as e:
+            logger.error(f"Failed to log audit entry: {str(e)}")

+ 140 - 0
backend/open_webui/utils/logger.py

@@ -0,0 +1,140 @@
+import json
+import logging
+import sys
+from typing import TYPE_CHECKING
+
+from loguru import logger
+
+from open_webui.env import (
+    AUDIT_LOG_FILE_ROTATION_SIZE,
+    AUDIT_LOG_LEVEL,
+    AUDIT_LOGS_FILE_PATH,
+    GLOBAL_LOG_LEVEL,
+)
+
+
+if TYPE_CHECKING:
+    from loguru import Record
+
+
+def stdout_format(record: "Record") -> str:
+    """
+    Generates a formatted string for log records that are output to the console. This format includes a timestamp, log level, source location (module, function, and line), the log message, and any extra data (serialized as JSON).
+
+    Parameters:
+    record (Record): A Loguru record that contains logging details including time, level, name, function, line, message, and any extra context.
+    Returns:
+    str: A formatted log string intended for stdout.
+    """
+    record["extra"]["extra_json"] = json.dumps(record["extra"])
+    return (
+        "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
+        "<level>{level: <8}</level> | "
+        "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - "
+        "<level>{message}</level> - {extra[extra_json]}"
+        "\n{exception}"
+    )
+
+
+class InterceptHandler(logging.Handler):
+    """
+    Intercepts log records from Python's standard logging module
+    and redirects them to Loguru's logger.
+    """
+
+    def emit(self, record):
+        """
+        Called by the standard logging module for each log event.
+        It transforms the standard `LogRecord` into a format compatible with Loguru
+        and passes it to Loguru's logger.
+        """
+        try:
+            level = logger.level(record.levelname).name
+        except ValueError:
+            level = record.levelno
+
+        frame, depth = sys._getframe(6), 6
+        while frame and frame.f_code.co_filename == logging.__file__:
+            frame = frame.f_back
+            depth += 1
+
+        logger.opt(depth=depth, exception=record.exc_info).log(
+            level, record.getMessage()
+        )
+
+
+def file_format(record: "Record"):
+    """
+    Formats audit log records into a structured JSON string for file output.
+
+    Parameters:
+    record (Record): A Loguru record containing extra audit data.
+    Returns:
+    str: A JSON-formatted string representing the audit data.
+    """
+
+    audit_data = {
+        "id": record["extra"].get("id", ""),
+        "timestamp": int(record["time"].timestamp()),
+        "user": record["extra"].get("user", dict()),
+        "audit_level": record["extra"].get("audit_level", ""),
+        "verb": record["extra"].get("verb", ""),
+        "request_uri": record["extra"].get("request_uri", ""),
+        "response_status_code": record["extra"].get("response_status_code", 0),
+        "source_ip": record["extra"].get("source_ip", ""),
+        "user_agent": record["extra"].get("user_agent", ""),
+        "request_object": record["extra"].get("request_object", b""),
+        "response_object": record["extra"].get("response_object", b""),
+        "extra": record["extra"].get("extra", {}),
+    }
+
+    record["extra"]["file_extra"] = json.dumps(audit_data, default=str)
+    return "{extra[file_extra]}\n"
+
+
+def start_logger():
+    """
+    Initializes and configures Loguru's logger with distinct handlers:
+
+    A console (stdout) handler for general log messages (excluding those marked as auditable).
+    An optional file handler for audit logs if audit logging is enabled.
+    Additionally, this function reconfigures Python’s standard logging to route through Loguru and adjusts logging levels for Uvicorn.
+
+    Parameters:
+    enable_audit_logging (bool): Determines whether audit-specific log entries should be recorded to file.
+    """
+    logger.remove()
+
+    logger.add(
+        sys.stdout,
+        level=GLOBAL_LOG_LEVEL,
+        format=stdout_format,
+        filter=lambda record: "auditable" not in record["extra"],
+    )
+
+    if AUDIT_LOG_LEVEL != "NONE":
+        try:
+            logger.add(
+                AUDIT_LOGS_FILE_PATH,
+                level="INFO",
+                rotation=AUDIT_LOG_FILE_ROTATION_SIZE,
+                compression="zip",
+                format=file_format,
+                filter=lambda record: record["extra"].get("auditable") is True,
+            )
+        except Exception as e:
+            logger.error(f"Failed to initialize audit log file handler: {str(e)}")
+
+    logging.basicConfig(
+        handlers=[InterceptHandler()], level=GLOBAL_LOG_LEVEL, force=True
+    )
+    for uvicorn_logger_name in ["uvicorn", "uvicorn.error"]:
+        uvicorn_logger = logging.getLogger(uvicorn_logger_name)
+        uvicorn_logger.setLevel(GLOBAL_LOG_LEVEL)
+        uvicorn_logger.handlers = []
+    for uvicorn_logger_name in ["uvicorn.access"]:
+        uvicorn_logger = logging.getLogger(uvicorn_logger_name)
+        uvicorn_logger.setLevel(GLOBAL_LOG_LEVEL)
+        uvicorn_logger.handlers = [InterceptHandler()]
+
+    logger.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}")

+ 3 - 0
backend/requirements.txt

@@ -31,6 +31,9 @@ APScheduler==3.10.4
 
 RestrictedPython==8.0
 
+loguru==0.7.2
+asgiref==3.8.1
+
 # AI libraries
 openai
 anthropic

+ 3 - 0
pyproject.toml

@@ -40,6 +40,9 @@ dependencies = [
 
     "RestrictedPython==8.0",
 
+    "loguru==0.7.2",
+    "asgiref==3.8.1",
+
     "openai",
     "anthropic",
     "google-generativeai==0.7.2",