Parcourir la source

Merge pull request #10052 from roryeckel/playwright

Support Playwright RAG Web Loader: Revised
Timothy Jaeryang Baek il y a 2 mois
Parent
commit
1bbecd46c8

+ 12 - 0
backend/open_webui/config.py

@@ -1926,12 +1926,24 @@ RAG_WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig(
     int(os.getenv("RAG_WEB_SEARCH_CONCURRENT_REQUESTS", "10")),
 )
 
+RAG_WEB_LOADER_ENGINE = PersistentConfig(
+    "RAG_WEB_LOADER_ENGINE",
+    "rag.web.loader.engine",
+    os.environ.get("RAG_WEB_LOADER_ENGINE", "safe_web")
+)
+
 RAG_WEB_SEARCH_TRUST_ENV = PersistentConfig(
     "RAG_WEB_SEARCH_TRUST_ENV",
     "rag.web.search.trust_env",
     os.getenv("RAG_WEB_SEARCH_TRUST_ENV", False),
 )
 
+PLAYWRIGHT_WS_URI = PersistentConfig(
+    "PLAYWRIGHT_WS_URI",
+    "rag.web.loader.engine.playwright.ws.uri",
+    os.environ.get("PLAYWRIGHT_WS_URI", None)
+)
+
 ####################################
 # Images
 ####################################

+ 4 - 0
backend/open_webui/main.py

@@ -147,6 +147,8 @@ from open_webui.config import (
     AUDIO_TTS_VOICE,
     AUDIO_TTS_AZURE_SPEECH_REGION,
     AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT,
+    PLAYWRIGHT_WS_URI,
+    RAG_WEB_LOADER_ENGINE,
     WHISPER_MODEL,
     DEEPGRAM_API_KEY,
     WHISPER_MODEL_AUTO_UPDATE,
@@ -578,7 +580,9 @@ app.state.config.EXA_API_KEY = EXA_API_KEY
 
 app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
 app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
+app.state.config.RAG_WEB_LOADER_ENGINE = RAG_WEB_LOADER_ENGINE
 app.state.config.RAG_WEB_SEARCH_TRUST_ENV = RAG_WEB_SEARCH_TRUST_ENV
+app.state.config.PLAYWRIGHT_WS_URI = PLAYWRIGHT_WS_URI
 
 app.state.EMBEDDING_FUNCTION = None
 app.state.ef = None

+ 216 - 29
backend/open_webui/retrieval/web/utils.py

@@ -1,23 +1,33 @@
-import socket
-import aiohttp
 import asyncio
+import logging
+import socket
+import ssl
 import urllib.parse
+import urllib.request
+from collections import defaultdict
+from datetime import datetime, time, timedelta
+from typing import (
+    Any,
+    AsyncIterator,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Sequence,
+    Union
+)
+import aiohttp
+import certifi
 import validators
-from typing import Any, AsyncIterator, Dict, Iterator, List, Sequence, Union
-
-
 from langchain_community.document_loaders import (
-    WebBaseLoader,
+    PlaywrightURLLoader,
+    WebBaseLoader
 )
 from langchain_core.documents import Document
-
-
 from open_webui.constants import ERROR_MESSAGES
-from open_webui.config import ENABLE_RAG_LOCAL_WEB_FETCH
+from open_webui.config import ENABLE_RAG_LOCAL_WEB_FETCH, PLAYWRIGHT_WS_URI, RAG_WEB_LOADER_ENGINE
 from open_webui.env import SRC_LOG_LEVELS
 
-import logging
-
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
@@ -56,7 +66,6 @@ def safe_validate_urls(url: Sequence[str]) -> Sequence[str]:
             continue
     return valid_urls
 
-
 def resolve_hostname(hostname):
     # Get address information
     addr_info = socket.getaddrinfo(hostname, None)
@@ -67,6 +76,178 @@ def resolve_hostname(hostname):
 
     return ipv4_addresses, ipv6_addresses
 
+def extract_metadata(soup, url):
+    metadata = {
+        "source": url
+    }
+    if title := soup.find("title"):
+        metadata["title"] = title.get_text()
+    if description := soup.find("meta", attrs={"name": "description"}):
+        metadata["description"] = description.get(
+            "content", "No description found."
+        )
+    if html := soup.find("html"):
+        metadata["language"] = html.get("lang", "No language found.")
+    return metadata
+
+class SafePlaywrightURLLoader(PlaywrightURLLoader):
+    """Load HTML pages safely with Playwright, supporting SSL verification, rate limiting, and remote browser connection.
+    
+    Attributes:
+        web_paths (List[str]): List of URLs to load.
+        verify_ssl (bool): If True, verify SSL certificates.
+        trust_env (bool): If True, use proxy settings from environment variables.
+        requests_per_second (Optional[float]): Number of requests per second to limit to.
+        continue_on_failure (bool): If True, continue loading other URLs on failure.
+        headless (bool): If True, the browser will run in headless mode.
+        proxy (dict): Proxy override settings for the Playwright session.
+        playwright_ws_url (Optional[str]): WebSocket endpoint URI for remote browser connection.
+    """
+
+    def __init__(
+        self,
+        web_paths: List[str],
+        verify_ssl: bool = True,
+        trust_env: bool = False,
+        requests_per_second: Optional[float] = None,
+        continue_on_failure: bool = True,
+        headless: bool = True,
+        remove_selectors: Optional[List[str]] = None,
+        proxy: Optional[Dict[str, str]] = None,
+        playwright_ws_url: Optional[str] = None
+    ):
+        """Initialize with additional safety parameters and remote browser support."""
+
+        proxy_server = proxy.get('server') if proxy else None
+        if trust_env and not proxy_server:
+            env_proxies = urllib.request.getproxies()
+            env_proxy_server = env_proxies.get('https') or env_proxies.get('http')
+            if env_proxy_server:
+                if proxy:
+                    proxy['server'] = env_proxy_server
+                else:
+                    proxy = { 'server': env_proxy_server }
+
+        # We'll set headless to False if using playwright_ws_url since it's handled by the remote browser
+        super().__init__(
+            urls=web_paths,
+            continue_on_failure=continue_on_failure,
+            headless=headless if playwright_ws_url is None else False,
+            remove_selectors=remove_selectors,
+            proxy=proxy
+        )
+        self.verify_ssl = verify_ssl
+        self.requests_per_second = requests_per_second
+        self.last_request_time = None
+        self.playwright_ws_url = playwright_ws_url
+        self.trust_env = trust_env
+
+    def lazy_load(self) -> Iterator[Document]:
+        """Safely load URLs synchronously with support for remote browser."""
+        from playwright.sync_api import sync_playwright
+
+        with sync_playwright() as p:
+            # Use remote browser if ws_endpoint is provided, otherwise use local browser
+            if self.playwright_ws_url:
+                browser = p.chromium.connect(self.playwright_ws_url)
+            else:
+                browser = p.chromium.launch(headless=self.headless, proxy=self.proxy)
+
+            for url in self.urls:
+                try:
+                    self._safe_process_url_sync(url)
+                    page = browser.new_page()
+                    response = page.goto(url)
+                    if response is None:
+                        raise ValueError(f"page.goto() returned None for url {url}")
+
+                    text = self.evaluator.evaluate(page, browser, response)
+                    metadata = {"source": url}
+                    yield Document(page_content=text, metadata=metadata)
+                except Exception as e:
+                    if self.continue_on_failure:
+                        log.exception(e, "Error loading %s", url)
+                        continue
+                    raise e
+            browser.close()
+
+    async def alazy_load(self) -> AsyncIterator[Document]:
+        """Safely load URLs asynchronously with support for remote browser."""
+        from playwright.async_api import async_playwright
+
+        async with async_playwright() as p:
+            # Use remote browser if ws_endpoint is provided, otherwise use local browser
+            if self.playwright_ws_url:
+                browser = await p.chromium.connect(self.playwright_ws_url)
+            else:
+                browser = await p.chromium.launch(headless=self.headless, proxy=self.proxy)
+
+            for url in self.urls:
+                try:
+                    await self._safe_process_url(url)
+                    page = await browser.new_page()
+                    response = await page.goto(url)
+                    if response is None:
+                        raise ValueError(f"page.goto() returned None for url {url}")
+
+                    text = await self.evaluator.evaluate_async(page, browser, response)
+                    metadata = {"source": url}
+                    yield Document(page_content=text, metadata=metadata)
+                except Exception as e:
+                    if self.continue_on_failure:
+                        log.exception(e, "Error loading %s", url)
+                        continue
+                    raise e
+            await browser.close()
+
+    def _verify_ssl_cert(self, url: str) -> bool:
+        """Verify SSL certificate for the given URL."""
+        if not url.startswith("https://"):
+            return True
+            
+        try:
+            hostname = url.split("://")[-1].split("/")[0]
+            context = ssl.create_default_context(cafile=certifi.where())
+            with context.wrap_socket(ssl.socket(), server_hostname=hostname) as s:
+                s.connect((hostname, 443))
+            return True
+        except ssl.SSLError:
+            return False
+        except Exception as e:
+            log.warning(f"SSL verification failed for {url}: {str(e)}")
+            return False
+
+    async def _wait_for_rate_limit(self):
+        """Wait to respect the rate limit if specified."""
+        if self.requests_per_second and self.last_request_time:
+            min_interval = timedelta(seconds=1.0 / self.requests_per_second)
+            time_since_last = datetime.now() - self.last_request_time
+            if time_since_last < min_interval:
+                await asyncio.sleep((min_interval - time_since_last).total_seconds())
+        self.last_request_time = datetime.now()
+
+    def _sync_wait_for_rate_limit(self):
+        """Synchronous version of rate limit wait."""
+        if self.requests_per_second and self.last_request_time:
+            min_interval = timedelta(seconds=1.0 / self.requests_per_second)
+            time_since_last = datetime.now() - self.last_request_time
+            if time_since_last < min_interval:
+                time.sleep((min_interval - time_since_last).total_seconds())
+        self.last_request_time = datetime.now()
+
+    async def _safe_process_url(self, url: str) -> bool:
+        """Perform safety checks before processing a URL."""
+        if self.verify_ssl and not self._verify_ssl_cert(url):
+            raise ValueError(f"SSL certificate verification failed for {url}")
+        await self._wait_for_rate_limit()
+        return True
+
+    def _safe_process_url_sync(self, url: str) -> bool:
+        """Synchronous version of safety checks."""
+        if self.verify_ssl and not self._verify_ssl_cert(url):
+            raise ValueError(f"SSL certificate verification failed for {url}")
+        self._sync_wait_for_rate_limit()
+        return True
 
 class SafeWebBaseLoader(WebBaseLoader):
     """WebBaseLoader with enhanced error handling for URLs."""
@@ -143,20 +324,12 @@ class SafeWebBaseLoader(WebBaseLoader):
                 text = soup.get_text(**self.bs_get_text_kwargs)
 
                 # Build metadata
-                metadata = {"source": path}
-                if title := soup.find("title"):
-                    metadata["title"] = title.get_text()
-                if description := soup.find("meta", attrs={"name": "description"}):
-                    metadata["description"] = description.get(
-                        "content", "No description found."
-                    )
-                if html := soup.find("html"):
-                    metadata["language"] = html.get("lang", "No language found.")
+                metadata = extract_metadata(soup, path)
 
                 yield Document(page_content=text, metadata=metadata)
             except Exception as e:
                 # Log the error and continue with the next URL
-                log.error(f"Error loading {path}: {e}")
+                log.exception(e, "Error loading %s", path)
 
     async def alazy_load(self) -> AsyncIterator[Document]:
         """Async lazy load text from the url(s) in web_path."""
@@ -178,6 +351,9 @@ class SafeWebBaseLoader(WebBaseLoader):
         """Load data into Document objects."""
         return [document async for document in self.alazy_load()]
 
+RAG_WEB_LOADER_ENGINES = defaultdict(lambda: SafeWebBaseLoader)
+RAG_WEB_LOADER_ENGINES["playwright"] = SafePlaywrightURLLoader
+RAG_WEB_LOADER_ENGINES["safe_web"] = SafeWebBaseLoader
 
 def get_web_loader(
     urls: Union[str, Sequence[str]],
@@ -188,10 +364,21 @@ def get_web_loader(
     # Check if the URLs are valid
     safe_urls = safe_validate_urls([urls] if isinstance(urls, str) else urls)
 
-    return SafeWebBaseLoader(
-        web_path=safe_urls,
-        verify_ssl=verify_ssl,
-        requests_per_second=requests_per_second,
-        continue_on_failure=True,
-        trust_env=trust_env,
-    )
+    web_loader_args = {
+        "web_paths": safe_urls,
+        "verify_ssl": verify_ssl,
+        "requests_per_second": requests_per_second,
+        "continue_on_failure": True,
+        "trust_env": trust_env
+    }
+
+    if PLAYWRIGHT_WS_URI.value:
+        web_loader_args["playwright_ws_url"] = PLAYWRIGHT_WS_URI.value
+
+    # Create the appropriate WebLoader based on the configuration
+    WebLoaderClass = RAG_WEB_LOADER_ENGINES[RAG_WEB_LOADER_ENGINE.value]
+    web_loader = WebLoaderClass(**web_loader_args)
+
+    log.debug("Using RAG_WEB_LOADER_ENGINE %s for %s URLs", web_loader.__class__.__name__, len(safe_urls))
+
+    return web_loader

+ 1 - 1
backend/open_webui/routers/retrieval.py

@@ -1379,7 +1379,7 @@ async def process_web_search(
                 docs,
                 collection_name,
                 overwrite=True,
-                user=user,
+                user=user
             )
 
             return {

+ 1 - 1
backend/open_webui/utils/middleware.py

@@ -344,7 +344,7 @@ async def chat_web_search_handler(
                     "query": searchQuery,
                 }
             ),
-            user,
+            user=user
         )
 
         if results:

+ 2 - 2
backend/requirements.txt

@@ -45,7 +45,7 @@ chromadb==0.6.2
 pymilvus==2.5.0
 qdrant-client~=1.12.0
 opensearch-py==2.8.0
-
+playwright==1.49.1 # Caution: version must match docker-compose.playwright.yaml
 
 transformers
 sentence-transformers==3.3.1
@@ -59,7 +59,7 @@ fpdf2==2.8.2
 pymdown-extensions==10.14.2
 docx2txt==0.8
 python-pptx==1.0.0
-unstructured==0.16.11
+unstructured==0.16.17
 nltk==3.9.1
 Markdown==3.7
 pypandoc==1.13

+ 11 - 0
backend/start.sh

@@ -3,6 +3,17 @@
 SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
 cd "$SCRIPT_DIR" || exit
 
+# Add conditional Playwright browser installation
+if [[ "${RAG_WEB_LOADER_ENGINE,,}" == "playwright" ]]; then
+    if [[ -z "${PLAYWRIGHT_WS_URI}" ]]; then
+        echo "Installing Playwright browsers..."
+        playwright install chromium
+        playwright install-deps chromium
+    fi
+
+    python -c "import nltk; nltk.download('punkt_tab')"
+fi
+
 KEY_FILE=.webui_secret_key
 
 PORT="${PORT:-8080}"

+ 11 - 0
backend/start_windows.bat

@@ -6,6 +6,17 @@ SETLOCAL ENABLEDELAYEDEXPANSION
 SET "SCRIPT_DIR=%~dp0"
 cd /d "%SCRIPT_DIR%" || exit /b
 
+:: Add conditional Playwright browser installation
+IF /I "%RAG_WEB_LOADER_ENGINE%" == "playwright" (
+    IF "%PLAYWRIGHT_WS_URI%" == "" (
+        echo Installing Playwright browsers...
+        playwright install chromium
+        playwright install-deps chromium
+    )
+
+    python -c "import nltk; nltk.download('punkt_tab')"
+)
+
 SET "KEY_FILE=.webui_secret_key"
 IF "%PORT%"=="" SET PORT=8080
 IF "%HOST%"=="" SET HOST=0.0.0.0

+ 10 - 0
docker-compose.playwright.yaml

@@ -0,0 +1,10 @@
+services:
+  playwright:
+    image: mcr.microsoft.com/playwright:v1.49.1-noble # Version must match requirements.txt
+    container_name: playwright
+    command: npx -y playwright@1.49.1 run-server --port 3000 --host 0.0.0.0
+
+  open-webui:
+    environment:
+      - 'RAG_WEB_LOADER_ENGINE=playwright'
+      - 'PLAYWRIGHT_WS_URI=ws://playwright:3000'

+ 2 - 1
pyproject.toml

@@ -53,6 +53,7 @@ dependencies = [
     "pymilvus==2.5.0",
     "qdrant-client~=1.12.0",
     "opensearch-py==2.8.0",
+    "playwright==1.49.1",
 
     "transformers",
     "sentence-transformers==3.3.1",
@@ -65,7 +66,7 @@ dependencies = [
     "pymdown-extensions==10.14.2",
     "docx2txt==0.8",
     "python-pptx==1.0.0",
-    "unstructured==0.16.11",
+    "unstructured==0.16.17",
     "nltk==3.9.1",
     "Markdown==3.7",
     "pypandoc==1.13",

+ 9 - 0
run-compose.sh

@@ -74,6 +74,7 @@ usage() {
     echo "  --enable-api[port=PORT]    Enable API and expose it on the specified port."
     echo "  --webui[port=PORT]         Set the port for the web user interface."
     echo "  --data[folder=PATH]        Bind mount for ollama data folder (by default will create the 'ollama' volume)."
+    echo "  --playwright               Enable Playwright support for web scraping."
     echo "  --build                    Build the docker image before running the compose project."
     echo "  --drop                     Drop the compose project."
     echo "  -q, --quiet                Run script in headless mode."
@@ -100,6 +101,7 @@ webui_port=3000
 headless=false
 build_image=false
 kill_compose=false
+enable_playwright=false
 
 # Function to extract value from the parameter
 extract_value() {
@@ -129,6 +131,9 @@ while [[ $# -gt 0 ]]; do
             value=$(extract_value "$key")
             data_dir=${value:-"./ollama-data"}
             ;;
+        --playwright)
+            enable_playwright=true
+            ;;
         --drop)
             kill_compose=true
             ;;
@@ -182,6 +187,9 @@ else
         DEFAULT_COMPOSE_COMMAND+=" -f docker-compose.data.yaml"
         export OLLAMA_DATA_DIR=$data_dir # Set OLLAMA_DATA_DIR environment variable
     fi
+    if [[ $enable_playwright == true ]]; then
+        DEFAULT_COMPOSE_COMMAND+=" -f docker-compose.playwright.yaml"
+    fi
     if [[ -n $webui_port ]]; then
         export OPEN_WEBUI_PORT=$webui_port # Set OPEN_WEBUI_PORT environment variable
     fi
@@ -201,6 +209,7 @@ echo -e "   ${GREEN}${BOLD}GPU Count:${NC} ${OLLAMA_GPU_COUNT:-Not Enabled}"
 echo -e "   ${GREEN}${BOLD}WebAPI Port:${NC} ${OLLAMA_WEBAPI_PORT:-Not Enabled}"
 echo -e "   ${GREEN}${BOLD}Data Folder:${NC} ${data_dir:-Using ollama volume}"
 echo -e "   ${GREEN}${BOLD}WebUI Port:${NC} $webui_port"
+echo -e "   ${GREEN}${BOLD}Playwright:${NC} ${enable_playwright:-false}"
 echo
 
 if [[ $headless == true ]]; then