Browse Source

Add RAG_WEB_LOADER + Playwright mode + improve stability of search

Rory 3 months ago
parent
commit
4e8b390682

+ 5 - 0
backend/open_webui/config.py

@@ -1712,6 +1712,11 @@ RAG_WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig(
     int(os.getenv("RAG_WEB_SEARCH_CONCURRENT_REQUESTS", "10")),
 )
 
+RAG_WEB_LOADER = PersistentConfig(
+    "RAG_WEB_LOADER",
+    "rag.web.loader",
+    os.environ.get("RAG_WEB_LOADER", "safe_web")
+)
 
 ####################################
 # Images

+ 2 - 0
backend/open_webui/main.py

@@ -129,6 +129,7 @@ from open_webui.config import (
     AUDIO_TTS_VOICE,
     AUDIO_TTS_AZURE_SPEECH_REGION,
     AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT,
+    RAG_WEB_LOADER,
     WHISPER_MODEL,
     WHISPER_MODEL_AUTO_UPDATE,
     WHISPER_MODEL_DIR,
@@ -526,6 +527,7 @@ app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = BING_SEARCH_V7_SUBSCRIPTION_K
 
 app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
 app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
+app.state.config.RAG_WEB_LOADER = RAG_WEB_LOADER
 
 app.state.EMBEDDING_FUNCTION = None
 app.state.ef = None

+ 4 - 0
backend/open_webui/retrieval/web/main.py

@@ -1,3 +1,5 @@
+import validators
+
 from typing import Optional
 from urllib.parse import urlparse
 
@@ -10,6 +12,8 @@ def get_filtered_results(results, filter_list):
     filtered_results = []
     for result in results:
         url = result.get("url") or result.get("link", "")
+        if not validators.url(url):
+            continue
         domain = urlparse(url).netloc
         if any(domain.endswith(filtered_domain) for filtered_domain in filter_list):
             filtered_results.append(result)

+ 160 - 19
backend/open_webui/retrieval/web/utils.py

@@ -1,16 +1,21 @@
+import asyncio
+from datetime import datetime, time, timedelta
 import socket
+import ssl
 import urllib.parse
+import certifi
 import validators
-from typing import Union, Sequence, Iterator
+from typing import AsyncIterator, Dict, List, Optional, Union, Sequence, Iterator
 
 from langchain_community.document_loaders import (
     WebBaseLoader,
+    PlaywrightURLLoader
 )
 from langchain_core.documents import Document
 
 
 from open_webui.constants import ERROR_MESSAGES
-from open_webui.config import ENABLE_RAG_LOCAL_WEB_FETCH
+from open_webui.config import ENABLE_RAG_LOCAL_WEB_FETCH, RAG_WEB_LOADER
 from open_webui.env import SRC_LOG_LEVELS
 
 import logging
@@ -42,6 +47,15 @@ def validate_url(url: Union[str, Sequence[str]]):
     else:
         return False
 
+def safe_validate_urls(url: Sequence[str]) -> Sequence[str]:
+    valid_urls = []
+    for u in url:
+        try:
+            if validate_url(u):
+                valid_urls.append(u)
+        except ValueError:
+            continue
+    return valid_urls
 
 def resolve_hostname(hostname):
     # Get address information
@@ -53,6 +67,131 @@ def resolve_hostname(hostname):
 
     return ipv4_addresses, ipv6_addresses
 
+def extract_metadata(soup, url):
+    metadata = {
+        "source": url
+    }
+    if title := soup.find("title"):
+        metadata["title"] = title.get_text()
+    if description := soup.find("meta", attrs={"name": "description"}):
+        metadata["description"] = description.get(
+            "content", "No description found."
+        )
+    if html := soup.find("html"):
+        metadata["language"] = html.get("lang", "No language found.")
+    return metadata
+
+class SafePlaywrightURLLoader(PlaywrightURLLoader):
+    """Load HTML pages safely with Playwright, supporting SSL verification and rate limiting.
+    
+    Attributes:
+        urls (List[str]): List of URLs to load.
+        verify_ssl (bool): If True, verify SSL certificates.
+        requests_per_second (Optional[float]): Number of requests per second to limit to.
+        continue_on_failure (bool): If True, continue loading other URLs on failure.
+        headless (bool): If True, the browser will run in headless mode.
+    """
+
+    def __init__(
+        self,
+        urls: List[str],
+        verify_ssl: bool = True,
+        requests_per_second: Optional[float] = None,
+        continue_on_failure: bool = True,
+        headless: bool = True,
+        remove_selectors: Optional[List[str]] = None,
+        proxy: Optional[Dict[str, str]] = None
+    ):
+        """Initialize with additional safety parameters."""
+        super().__init__(
+            urls=urls,
+            continue_on_failure=continue_on_failure,
+            headless=headless,
+            remove_selectors=remove_selectors,
+            proxy=proxy
+        )
+        self.verify_ssl = verify_ssl
+        self.requests_per_second = requests_per_second
+        self.last_request_time = None
+
+    def _verify_ssl_cert(self, url: str) -> bool:
+        """Verify SSL certificate for the given URL."""
+        if not url.startswith("https://"):
+            return True
+            
+        try:
+            hostname = url.split("://")[-1].split("/")[0]
+            context = ssl.create_default_context(cafile=certifi.where())
+            with context.wrap_socket(ssl.socket(), server_hostname=hostname) as s:
+                s.connect((hostname, 443))
+            return True
+        except ssl.SSLError:
+            return False
+        except Exception as e:
+            log.warning(f"SSL verification failed for {url}: {str(e)}")
+            return False
+
+    async def _wait_for_rate_limit(self):
+        """Wait to respect the rate limit if specified."""
+        if self.requests_per_second and self.last_request_time:
+            min_interval = timedelta(seconds=1.0 / self.requests_per_second)
+            time_since_last = datetime.now() - self.last_request_time
+            if time_since_last < min_interval:
+                await asyncio.sleep((min_interval - time_since_last).total_seconds())
+        self.last_request_time = datetime.now()
+
+    def _sync_wait_for_rate_limit(self):
+        """Synchronous version of rate limit wait."""
+        if self.requests_per_second and self.last_request_time:
+            min_interval = timedelta(seconds=1.0 / self.requests_per_second)
+            time_since_last = datetime.now() - self.last_request_time
+            if time_since_last < min_interval:
+                time.sleep((min_interval - time_since_last).total_seconds())
+        self.last_request_time = datetime.now()
+
+    async def _safe_process_url(self, url: str) -> bool:
+        """Perform safety checks before processing a URL."""
+        if self.verify_ssl and not self._verify_ssl_cert(url):
+            raise ValueError(f"SSL certificate verification failed for {url}")
+        await self._wait_for_rate_limit()
+        return True
+
+    def _safe_process_url_sync(self, url: str) -> bool:
+        """Synchronous version of safety checks."""
+        if self.verify_ssl and not self._verify_ssl_cert(url):
+            raise ValueError(f"SSL certificate verification failed for {url}")
+        self._sync_wait_for_rate_limit()
+        return True
+
+    async def alazy_load(self) -> AsyncIterator[Document]:
+        """Safely load URLs asynchronously."""
+        parent_iterator = super().alazy_load()
+        
+        async for document in parent_iterator:
+            url = document.metadata["source"]
+            try:
+                await self._safe_process_url(url)
+                yield document
+            except Exception as e:
+                if self.continue_on_failure:
+                    log.error(f"Error processing {url}, exception: {e}")
+                    continue
+                raise e
+
+    def lazy_load(self) -> Iterator[Document]:
+        """Safely load URLs synchronously."""
+        parent_iterator = super().lazy_load()
+        
+        for document in parent_iterator:
+            url = document.metadata["source"]
+            try:
+                self._safe_process_url_sync(url)
+                yield document
+            except Exception as e:
+                if self.continue_on_failure:
+                    log.error(f"Error processing {url}, exception: {e}")
+                    continue
+                raise e
 
 class SafeWebBaseLoader(WebBaseLoader):
     """WebBaseLoader with enhanced error handling for URLs."""
@@ -65,15 +204,7 @@ class SafeWebBaseLoader(WebBaseLoader):
                 text = soup.get_text(**self.bs_get_text_kwargs)
 
                 # Build metadata
-                metadata = {"source": path}
-                if title := soup.find("title"):
-                    metadata["title"] = title.get_text()
-                if description := soup.find("meta", attrs={"name": "description"}):
-                    metadata["description"] = description.get(
-                        "content", "No description found."
-                    )
-                if html := soup.find("html"):
-                    metadata["language"] = html.get("lang", "No language found.")
+                metadata = extract_metadata(soup, path)
 
                 yield Document(page_content=text, metadata=metadata)
             except Exception as e:
@@ -87,11 +218,21 @@ def get_web_loader(
     requests_per_second: int = 2,
 ):
     # Check if the URL is valid
-    if not validate_url(urls):
-        raise ValueError(ERROR_MESSAGES.INVALID_URL)
-    return SafeWebBaseLoader(
-        urls,
-        verify_ssl=verify_ssl,
-        requests_per_second=requests_per_second,
-        continue_on_failure=True,
-    )
+    safe_urls = safe_validate_urls([urls] if isinstance(urls, str) else urls)
+
+    if RAG_WEB_LOADER.value == "chromium":
+        log.info("Using SafePlaywrightURLLoader")
+        return SafePlaywrightURLLoader(
+            safe_urls,
+            verify_ssl=verify_ssl,
+            requests_per_second=requests_per_second,
+            continue_on_failure=True,
+        )
+    else:
+        log.info("Using SafeWebBaseLoader")
+        return SafeWebBaseLoader(
+            safe_urls,
+            verify_ssl=verify_ssl,
+            requests_per_second=requests_per_second,
+            continue_on_failure=True,
+        )

+ 18 - 3
backend/open_webui/routers/retrieval.py

@@ -1238,9 +1238,11 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
 
 
 @router.post("/process/web/search")
-def process_web_search(
-    request: Request, form_data: SearchForm, user=Depends(get_verified_user)
+async def process_web_search(
+    request: Request, form_data: SearchForm, extra_params: dict, user=Depends(get_verified_user)
 ):
+    event_emitter = extra_params["__event_emitter__"]
+    
     try:
         logging.info(
             f"trying to web search with {request.app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query}"
@@ -1258,6 +1260,18 @@ def process_web_search(
 
     log.debug(f"web_results: {web_results}")
 
+    await event_emitter(
+        {
+            "type": "status",
+            "data": {
+                "action": "web_search",
+                "description": "Loading {{count}} sites...",
+                "urls": [result.link for result in web_results],
+                "done": False
+            },
+        }
+    )
+
     try:
         collection_name = form_data.collection_name
         if collection_name == "" or collection_name is None:
@@ -1271,7 +1285,8 @@ def process_web_search(
             verify_ssl=request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
             requests_per_second=request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
         )
-        docs = loader.load()
+        docs = [doc async for doc in loader.alazy_load()]
+        # docs = loader.load()
         save_docs_to_vector_db(request, docs, collection_name, overwrite=True)
 
         return {

+ 11 - 16
backend/open_webui/utils/middleware.py

@@ -419,21 +419,16 @@ async def chat_web_search_handler(
 
     try:
 
-        # Offload process_web_search to a separate thread
-        loop = asyncio.get_running_loop()
-        with ThreadPoolExecutor() as executor:
-            results = await loop.run_in_executor(
-                executor,
-                lambda: process_web_search(
-                    request,
-                    SearchForm(
-                        **{
-                            "query": searchQuery,
-                        }
-                    ),
-                    user,
-                ),
-            )
+        results = await process_web_search(
+            request,
+            SearchForm(
+                **{
+                    "query": searchQuery,
+                }
+            ),
+            extra_params=extra_params,
+            user=user
+        )
 
         if results:
             await event_emitter(
@@ -441,7 +436,7 @@ async def chat_web_search_handler(
                     "type": "status",
                     "data": {
                         "action": "web_search",
-                        "description": "Searched {{count}} sites",
+                        "description": "Loaded {{count}} sites",
                         "query": searchQuery,
                         "urls": results["filenames"],
                         "done": True,

+ 1 - 1
backend/requirements.txt

@@ -46,7 +46,7 @@ chromadb==0.6.2
 pymilvus==2.5.0
 qdrant-client~=1.12.0
 opensearch-py==2.7.1
-
+playwright==1.49.1
 
 transformers
 sentence-transformers==3.3.1

+ 9 - 0
backend/start.sh

@@ -3,6 +3,15 @@
 SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
 cd "$SCRIPT_DIR" || exit
 
+# Add conditional Playwright browser installation
+if [[ "${RAG_WEB_LOADER,,}" == "chromium" ]]; then
+    echo "Installing Playwright browsers..."
+    playwright install chromium
+    playwright install-deps chromium
+
+    python -c "import nltk; nltk.download('punkt_tab')"
+fi
+
 KEY_FILE=.webui_secret_key
 
 PORT="${PORT:-8080}"

+ 9 - 0
backend/start_windows.bat

@@ -6,6 +6,15 @@ SETLOCAL ENABLEDELAYEDEXPANSION
 SET "SCRIPT_DIR=%~dp0"
 cd /d "%SCRIPT_DIR%" || exit /b
 
+:: Add conditional Playwright browser installation
+IF /I "%RAG_WEB_LOADER%" == "chromium" (
+    echo Installing Playwright browsers...
+    playwright install chromium
+    playwright install-deps chromium
+
+    python -c "import nltk; nltk.download('punkt_tab')"
+)
+
 SET "KEY_FILE=.webui_secret_key"
 IF "%PORT%"=="" SET PORT=8080
 IF "%HOST%"=="" SET HOST=0.0.0.0

+ 1 - 0
pyproject.toml

@@ -53,6 +53,7 @@ dependencies = [
     "pymilvus==2.5.0",
     "qdrant-client~=1.12.0",
     "opensearch-py==2.7.1",
+    "playwright==1.49.1",
 
     "transformers",
     "sentence-transformers==3.3.1",