Browse Source

feat(trace): optimize for trace env and instrument hooks

orenzhang 2 months ago
parent
commit
7bfda6652f

+ 11 - 14
backend/open_webui/env.py

@@ -105,7 +105,6 @@ for source in log_sources:
 
 log.setLevel(SRC_LOG_LEVELS["CONFIG"])
 
-
 WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
 if WEBUI_NAME != "Open WebUI":
     WEBUI_NAME += " (Open WebUI)"
@@ -130,7 +129,6 @@ else:
     except Exception:
         PACKAGE_DATA = {"version": "0.0.0"}
 
-
 VERSION = PACKAGE_DATA["version"]
 
 
@@ -161,7 +159,6 @@ try:
 except Exception:
     changelog_content = (pkgutil.get_data("open_webui", "CHANGELOG.md") or b"").decode()
 
-
 # Convert markdown content to HTML
 html_content = markdown.markdown(changelog_content)
 
@@ -192,7 +189,6 @@ for version in soup.find_all("h2"):
 
     changelog_json[version_number] = version_data
 
-
 CHANGELOG = changelog_json
 
 ####################################
@@ -209,7 +205,6 @@ ENABLE_FORWARD_USER_INFO_HEADERS = (
     os.environ.get("ENABLE_FORWARD_USER_INFO_HEADERS", "False").lower() == "true"
 )
 
-
 ####################################
 # WEBUI_BUILD_HASH
 ####################################
@@ -244,7 +239,6 @@ if FROM_INIT_PY:
 
     DATA_DIR = Path(os.getenv("DATA_DIR", OPEN_WEBUI_DIR / "data"))
 
-
 STATIC_DIR = Path(os.getenv("STATIC_DIR", OPEN_WEBUI_DIR / "static"))
 
 FONTS_DIR = Path(os.getenv("FONTS_DIR", OPEN_WEBUI_DIR / "static" / "fonts"))
@@ -256,7 +250,6 @@ if FROM_INIT_PY:
         os.getenv("FRONTEND_BUILD_DIR", OPEN_WEBUI_DIR / "frontend")
     ).resolve()
 
-
 ####################################
 # Database
 ####################################
@@ -321,7 +314,6 @@ RESET_CONFIG_ON_START = (
     os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true"
 )
 
-
 ENABLE_REALTIME_CHAT_SAVE = (
     os.environ.get("ENABLE_REALTIME_CHAT_SAVE", "False").lower() == "true"
 )
@@ -402,7 +394,6 @@ AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = os.environ.get(
     os.environ.get("AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", ""),
 )
 
-
 if AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST == "":
     AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = None
 else:
@@ -411,7 +402,6 @@ else:
     except Exception:
         AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = 5
 
-
 ####################################
 # OFFLINE_MODE
 ####################################
@@ -447,7 +437,14 @@ AUDIT_EXCLUDED_PATHS = [path.lstrip("/") for path in AUDIT_EXCLUDED_PATHS]
 # OPENTELEMETRY
 ####################################
 
-OT_ENABLED = os.environ.get("OT_ENABLED", "false").lower() == "true"
-OT_SERVICE_NAME = os.environ.get("OT_SERVICE_NAME", "open-webui")
-OT_HOST = os.environ.get("OT_HOST", "http://localhost:4317")
-OT_TOKEN = os.environ.get("OT_TOKEN", "")
+OTEL_SDK_DISABLED = os.environ.get("OTEL_SDK_DISABLED", "true").lower() == "true"
+OTEL_EXPORTER_OTLP_ENDPOINT = os.environ.get(
+    "OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317"
+)
+OTEL_SERVICE_NAME = os.environ.get("OTEL_SERVICE_NAME", "open-webui")
+OTEL_RESOURCE_ATTRIBUTES = os.environ.get(
+    "OTEL_RESOURCE_ATTRIBUTES", ""
+)  # e.g. key1=val1,key2=val2
+OTEL_TRACES_SAMPLER = os.environ.get(
+    "OTEL_TRACES_SAMPLER", "parentbased_always_on"
+).lower()

+ 4 - 4
backend/open_webui/main.py

@@ -84,7 +84,7 @@ from open_webui.routers.retrieval import (
     get_rf,
 )
 
-from open_webui.internal.db import Session
+from open_webui.internal.db import Session, engine
 
 from open_webui.models.functions import Functions
 from open_webui.models.models import Models
@@ -330,7 +330,7 @@ from open_webui.env import (
     BYPASS_MODEL_ACCESS_CONTROL,
     RESET_CONFIG_ON_START,
     OFFLINE_MODE,
-    OT_ENABLED,
+    OTEL_SDK_DISABLED,
 )
 
 
@@ -434,8 +434,8 @@ app.state.LICENSE_METADATA = None
 #
 ########################################
 
-if OT_ENABLED:
-    setup(app)
+if not OTEL_SDK_DISABLED:
+    setup(app=app, db_engine=engine)
 
 
 ########################################

+ 53 - 6
backend/open_webui/utils/trace/instrumentors.py

@@ -1,8 +1,14 @@
 import logging
 import traceback
-from typing import Collection
+from typing import Collection, Union
 
+from aiohttp import (
+    TraceRequestStartParams,
+    TraceRequestEndParams,
+    TraceRequestExceptionParams,
+)
 from chromadb.telemetry.opentelemetry.fastapi import instrument_fastapi
+from fastapi import FastAPI
 from opentelemetry.instrumentation.httpx import (
     HTTPXClientInstrumentor,
     RequestInfo,
@@ -17,6 +23,8 @@ from opentelemetry.instrumentation.aiohttp_client import AioHttpClientInstrument
 from opentelemetry.trace import Span, StatusCode
 from redis import Redis
 from requests import PreparedRequest, Response
+from sqlalchemy import Engine
+from fastapi import status
 
 from open_webui.utils.trace.constants import SPAN_REDIS_TYPE, SpanAttributes
 
@@ -105,7 +113,7 @@ def httpx_response_hook(span: Span, request: RequestInfo, response: ResponseInfo
     )
 
 
-async def httpx_async_request_hook(span, request):
+async def httpx_async_request_hook(span: Span, request: RequestInfo):
     """
     Async Request Hook
     """
@@ -113,7 +121,9 @@ async def httpx_async_request_hook(span, request):
     httpx_request_hook(span, request)
 
 
-async def httpx_async_response_hook(span, request, response):
+async def httpx_async_response_hook(
+    span: Span, request: RequestInfo, response: ResponseInfo
+):
     """
     Async Response Hook
     """
@@ -121,20 +131,54 @@ async def httpx_async_response_hook(span, request, response):
     httpx_response_hook(span, request, response)
 
 
+def aiohttp_request_hook(span: Span, request: TraceRequestStartParams):
+    """
+    Aiohttp Request Hook
+    """
+
+    span.update_name(f"{request.method} {str(request.url)}")
+    span.set_attributes(
+        attributes={
+            SpanAttributes.HTTP_URL: str(request.url),
+            SpanAttributes.HTTP_METHOD: request.method,
+        }
+    )
+
+
+def aiohttp_response_hook(
+    span: Span, response: Union[TraceRequestExceptionParams, TraceRequestEndParams]
+):
+    """
+    Aiohttp Response Hook
+    """
+
+    if isinstance(response, TraceRequestEndParams):
+        span.set_attribute(SpanAttributes.HTTP_STATUS_CODE, response.response.status)
+        span.set_status(
+            StatusCode.ERROR
+            if response.response.status >= status.HTTP_400_BAD_REQUEST
+            else StatusCode.OK
+        )
+    elif isinstance(response, TraceRequestExceptionParams):
+        span.set_status(StatusCode.ERROR)
+        span.set_attribute(SpanAttributes.ERROR_MESSAGE, str(response.exception))
+
+
 class Instrumentor(BaseInstrumentor):
     """
     Instrument OT
     """
 
-    def __init__(self, app):
+    def __init__(self, app: FastAPI, db_engine: Engine):
         self.app = app
+        self.db_engine = db_engine
 
     def instrumentation_dependencies(self) -> Collection[str]:
         return []
 
     def _instrument(self, **kwargs):
         instrument_fastapi(app=self.app)
-        SQLAlchemyInstrumentor().instrument()
+        SQLAlchemyInstrumentor().instrument(engine=self.db_engine)
         RedisInstrumentor().instrument(request_hook=redis_request_hook)
         RequestsInstrumentor().instrument(
             request_hook=requests_hook, response_hook=response_hook
@@ -146,7 +190,10 @@ class Instrumentor(BaseInstrumentor):
             async_request_hook=httpx_async_request_hook,
             async_response_hook=httpx_async_response_hook,
         )
-        AioHttpClientInstrumentor().instrument()
+        AioHttpClientInstrumentor().instrument(
+            request_hook=aiohttp_request_hook,
+            response_hook=aiohttp_response_hook,
+        )
 
     def _uninstrument(self, **kwargs):
         if getattr(self, "instrumentors", None) is None:

+ 9 - 10
backend/open_webui/utils/trace/setup.py

@@ -1,24 +1,23 @@
+from fastapi import FastAPI
 from opentelemetry import trace
 from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
 from opentelemetry.sdk.resources import SERVICE_NAME, Resource
 from opentelemetry.sdk.trace import TracerProvider
-from opentelemetry.sdk.trace.sampling import ALWAYS_ON
+from sqlalchemy import Engine
 
 from open_webui.utils.trace.exporters import LazyBatchSpanProcessor
 from open_webui.utils.trace.instrumentors import Instrumentor
-from open_webui.env import OT_SERVICE_NAME, OT_HOST, OT_TOKEN
+from open_webui.env import OTEL_SERVICE_NAME, OTEL_EXPORTER_OTLP_ENDPOINT
 
 
-def setup(app):
+def setup(app: FastAPI, db_engine: Engine):
+    # set up trace
     trace.set_tracer_provider(
         TracerProvider(
-            resource=Resource.create(
-                {SERVICE_NAME: OT_SERVICE_NAME, "token": OT_TOKEN}
-            ),
-            sampler=ALWAYS_ON,
+            resource=Resource.create(attributes={SERVICE_NAME: OTEL_SERVICE_NAME})
         )
     )
-    # otlp
-    exporter = OTLPSpanExporter(endpoint=OT_HOST)
+    # otlp export
+    exporter = OTLPSpanExporter(endpoint=OTEL_EXPORTER_OTLP_ENDPOINT)
     trace.get_tracer_provider().add_span_processor(LazyBatchSpanProcessor(exporter))
-    Instrumentor(app=app).instrument()
+    Instrumentor(app=app, db_engine=db_engine).instrument()