
feat: Add Tavily extract web loader integration

Luke 2 months ago
parent commit 987954c817

+ 98 - 0
backend/open_webui/retrieval/loaders/tavily.py

@@ -0,0 +1,98 @@
+import requests
+import logging
+from typing import Iterator, List, Literal, Union
+
+from langchain_core.document_loaders import BaseLoader
+from langchain_core.documents import Document
+from open_webui.env import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
+
+class TavilyLoader(BaseLoader):
+    """Extract web page content from URLs using Tavily Extract API.
+    
+    This is a LangChain document loader that uses Tavily's Extract API to
+    retrieve content from web pages and return it as Document objects.
+    
+    Args:
+        urls: URL or list of URLs to extract content from.
+        api_key: The Tavily API key.
+        extract_depth: Depth of extraction, either "basic" or "advanced".
+        continue_on_failure: Whether to continue if extraction of a URL fails.
+    """
+    def __init__(
+        self,
+        urls: Union[str, List[str]],
+        api_key: str,
+        extract_depth: Literal["basic", "advanced"] = "basic",
+        continue_on_failure: bool = True,
+    ) -> None:
+        """Initialize Tavily Extract client.
+        
+        Args:
+            urls: URL or list of URLs to extract content from.
+            api_key: The Tavily API key.
+            extract_depth: Depth of extraction, either "basic" or "advanced".
+                Advanced extraction retrieves more data, including tables and
+                embedded content, with a higher success rate but potentially
+                higher latency. Basic costs 1 credit per 5 successful URL
+                extractions; advanced costs 2 credits per 5.
+            continue_on_failure: Whether to continue if extraction of a URL fails.
+        """
+        if not urls:
+            raise ValueError("At least one URL must be provided.")
+            
+        self.api_key = api_key
+        self.urls = urls if isinstance(urls, list) else [urls]
+        self.extract_depth = extract_depth
+        self.continue_on_failure = continue_on_failure
+        self.api_url = "https://api.tavily.com/extract"
+        
+    def lazy_load(self) -> Iterator[Document]:
+        """Extract and yield documents from the URLs using Tavily Extract API."""
+        batch_size = 20  # submit URLs to the Extract API in batches
+        for i in range(0, len(self.urls), batch_size):
+            batch_urls = self.urls[i:i + batch_size]
+            try:
+                headers = {
+                    "Content-Type": "application/json",
+                    "Authorization": f"Bearer {self.api_key}"
+                }
+                # Use string for single URL, array for multiple URLs
+                urls_param = batch_urls[0] if len(batch_urls) == 1 else batch_urls
+                payload = {
+                    "urls": urls_param,
+                    "extract_depth": self.extract_depth
+                }
+                # Make the API call
+                response = requests.post(
+                    self.api_url,
+                    headers=headers,
+                    json=payload
+                )
+                response.raise_for_status()
+                response_data = response.json()
+                # Process successful results
+                for result in response_data.get("results", []):
+                    url = result.get("url", "")
+                    content = result.get("raw_content", "")
+                    if not content:
+                        log.warning(f"No content extracted from {url}")
+                        continue
+                    # Add URLs as metadata
+                    metadata = {"source": url}
+                    yield Document(
+                        page_content=content,
+                        metadata=metadata,
+                    )
+                for failed in response_data.get("failed_results", []):
+                    url = failed.get("url", "")
+                    error = failed.get("error", "Unknown error")
+                    log.error(f"Failed to extract content from {url}: {error}")
+            except Exception as e:
+                if self.continue_on_failure:
+                    log.error(f"Error extracting content from batch {batch_urls}: {e}")
+                else:
+                    raise e
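
For reference, a minimal sketch of driving the new loader directly (the key and URLs are placeholders; batching into groups of 20 happens inside lazy_load):

    from open_webui.retrieval.loaders.tavily import TavilyLoader

    # Hypothetical usage; "tvly-..." stands in for a real Tavily API key.
    loader = TavilyLoader(
        urls=["https://example.com", "https://example.org"],
        api_key="tvly-...",
        extract_depth="basic",     # "advanced" also pulls tables/embedded content
        continue_on_failure=True,  # log failed batches instead of raising
    )
    for doc in loader.lazy_load():
        print(doc.metadata["source"], len(doc.page_content))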

+ 156 - 67
backend/open_webui/retrieval/web/utils.py

@@ -24,6 +24,7 @@ from langchain_community.document_loaders import PlaywrightURLLoader, WebBaseLoader
 from langchain_community.document_loaders.firecrawl import FireCrawlLoader
 from langchain_community.document_loaders.base import BaseLoader
 from langchain_core.documents import Document
+from open_webui.retrieval.loaders.tavily import TavilyLoader
 from open_webui.constants import ERROR_MESSAGES
 from open_webui.config import (
     ENABLE_RAG_LOCAL_WEB_FETCH,
@@ -31,6 +32,7 @@ from open_webui.config import (
     RAG_WEB_LOADER_ENGINE,
     FIRECRAWL_API_BASE_URL,
     FIRECRAWL_API_KEY,
+    TAVILY_API_KEY,
 )
 from open_webui.env import SRC_LOG_LEVELS

@@ -113,7 +115,47 @@ def verify_ssl_cert(url: str) -> bool:
         return False


-class SafeFireCrawlLoader(BaseLoader):
+class RateLimitMixin:
+    async def _wait_for_rate_limit(self):
+        """Wait to respect the rate limit if specified."""
+        if self.requests_per_second and self.last_request_time:
+            min_interval = timedelta(seconds=1.0 / self.requests_per_second)
+            time_since_last = datetime.now() - self.last_request_time
+            if time_since_last < min_interval:
+                await asyncio.sleep((min_interval - time_since_last).total_seconds())
+        self.last_request_time = datetime.now()
+
+    def _sync_wait_for_rate_limit(self):
+        """Synchronous version of rate limit wait."""
+        if self.requests_per_second and self.last_request_time:
+            min_interval = timedelta(seconds=1.0 / self.requests_per_second)
+            time_since_last = datetime.now() - self.last_request_time
+            if time_since_last < min_interval:
+                time.sleep((min_interval - time_since_last).total_seconds())
+        self.last_request_time = datetime.now()
+
+
+class URLProcessingMixin:  
+    def _verify_ssl_cert(self, url: str) -> bool:
+        """Verify SSL certificate for a URL."""
+        return verify_ssl_cert(url)
+        
+    async def _safe_process_url(self, url: str) -> bool:
+        """Perform safety checks before processing a URL."""
+        if self.verify_ssl and not self._verify_ssl_cert(url):
+            raise ValueError(f"SSL certificate verification failed for {url}")
+        await self._wait_for_rate_limit()
+        return True
+    
+    def _safe_process_url_sync(self, url: str) -> bool:
+        """Synchronous version of safety checks."""
+        if self.verify_ssl and not self._verify_ssl_cert(url):
+            raise ValueError(f"SSL certificate verification failed for {url}")
+        self._sync_wait_for_rate_limit()
+        return True
+
+
+class SafeFireCrawlLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
     def __init__(
         self,
         web_paths,
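
The two mixins read requests_per_second, last_request_time, and verify_ssl from the host class, so any loader that sets those attributes in its own __init__ can reuse them. A minimal sketch of a hypothetical loader composing both mixins (not part of this commit; assumes it lives in the same module as the mixins above):

    from typing import Iterator, List

    from langchain_core.document_loaders import BaseLoader
    from langchain_core.documents import Document

    class RateLimitedLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
        def __init__(self, urls: List[str], requests_per_second: float = 2.0):
            self.urls = urls
            self.requests_per_second = requests_per_second  # read by RateLimitMixin
            self.last_request_time = None                   # read by RateLimitMixin
            self.verify_ssl = True                          # read by URLProcessingMixin

        def lazy_load(self) -> Iterator[Document]:
            for url in self.urls:
                self._safe_process_url_sync(url)  # SSL check + rate-limit sleep
                yield Document(page_content="", metadata={"source": url})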
@@ -208,43 +250,120 @@ class SafeFireCrawlLoader(BaseLoader):
                     continue
                 raise e

-    def _verify_ssl_cert(self, url: str) -> bool:
-        return verify_ssl_cert(url)
-
-    async def _wait_for_rate_limit(self):
-        """Wait to respect the rate limit if specified."""
-        if self.requests_per_second and self.last_request_time:
-            min_interval = timedelta(seconds=1.0 / self.requests_per_second)
-            time_since_last = datetime.now() - self.last_request_time
-            if time_since_last < min_interval:
-                await asyncio.sleep((min_interval - time_since_last).total_seconds())
-        self.last_request_time = datetime.now()

-    def _sync_wait_for_rate_limit(self):
-        """Synchronous version of rate limit wait."""
-        if self.requests_per_second and self.last_request_time:
-            min_interval = timedelta(seconds=1.0 / self.requests_per_second)
-            time_since_last = datetime.now() - self.last_request_time
-            if time_since_last < min_interval:
-                time.sleep((min_interval - time_since_last).total_seconds())
-        self.last_request_time = datetime.now()
+class SafeTavilyLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
+    def __init__(
+        self,
+        web_paths: Union[str, List[str]],
+        api_key: str,
+        extract_depth: Literal["basic", "advanced"] = "basic",
+        continue_on_failure: bool = True,
+        requests_per_second: Optional[float] = None,
+        verify_ssl: bool = True,
+        trust_env: bool = False,
+        proxy: Optional[Dict[str, str]] = None,
+    ):
+        """Initialize SafeTavilyLoader with rate limiting and SSL verification support.
 
 
-    async def _safe_process_url(self, url: str) -> bool:
-        """Perform safety checks before processing a URL."""
-        if self.verify_ssl and not self._verify_ssl_cert(url):
-            raise ValueError(f"SSL certificate verification failed for {url}")
-        await self._wait_for_rate_limit()
-        return True
+        Args:
+            web_paths: List of URLs/paths to process.
+            api_key: The Tavily API key.
+            extract_depth: Depth of extraction ("basic" or "advanced").
+            continue_on_failure: Whether to continue if extraction of a URL fails.
+            requests_per_second: Number of requests per second to limit to.
+            verify_ssl: If True, verify SSL certificates.
+            trust_env: If True, use proxy settings from environment variables.
+            proxy: Optional proxy configuration.
+        """
+        # Initialize proxy configuration if using environment variables
+        proxy_server = proxy.get("server") if proxy else None
+        if trust_env and not proxy_server:
+            env_proxies = urllib.request.getproxies()
+            env_proxy_server = env_proxies.get("https") or env_proxies.get("http")
+            if env_proxy_server:
+                if proxy:
+                    proxy["server"] = env_proxy_server
+                else:
+                    proxy = {"server": env_proxy_server}
+                    
+        # Store parameters for creating TavilyLoader instances
+        self.web_paths = web_paths if isinstance(web_paths, list) else [web_paths]
+        self.api_key = api_key
+        self.extract_depth = extract_depth
+        self.continue_on_failure = continue_on_failure
+        self.verify_ssl = verify_ssl
+        self.trust_env = trust_env
+        self.proxy = proxy
+        
+        # Add rate limiting
+        self.requests_per_second = requests_per_second
+        self.last_request_time = None

-    def _safe_process_url_sync(self, url: str) -> bool:
-        """Synchronous version of safety checks."""
-        if self.verify_ssl and not self._verify_ssl_cert(url):
-            raise ValueError(f"SSL certificate verification failed for {url}")
-        self._sync_wait_for_rate_limit()
-        return True
+    def lazy_load(self) -> Iterator[Document]:
+        """Load documents with rate limiting support, delegating to TavilyLoader."""
+        valid_urls = []
+        for url in self.web_paths:
+            try:
+                self._safe_process_url_sync(url)
+                valid_urls.append(url)
+            except Exception as e:
+                log.warning(f"SSL verification failed for {url}: {str(e)}")
+                if not self.continue_on_failure:
+                    raise e
+        if not valid_urls:
+            if self.continue_on_failure:
+                log.warning("No valid URLs to process after SSL verification")
+                return
+            raise ValueError("No valid URLs to process after SSL verification")
+        try:
+            loader = TavilyLoader(
+                urls=valid_urls,
+                api_key=self.api_key,
+                extract_depth=self.extract_depth,
+                continue_on_failure=self.continue_on_failure,
+            )
+            yield from loader.lazy_load()
+        except Exception as e:
+            if self.continue_on_failure:
+                log.exception(e, "Error extracting content from URLs")
+            else:
+                raise e
+    
+    async def alazy_load(self) -> AsyncIterator[Document]:
+        """Async version with rate limiting and SSL verification."""
+        valid_urls = []
+        for url in self.web_paths:
+            try:
+                await self._safe_process_url(url)
+                valid_urls.append(url)
+            except Exception as e:
+                log.warning(f"SSL verification failed for {url}: {str(e)}")
+                if not self.continue_on_failure:
+                    raise e
+        
+        if not valid_urls:
+            if self.continue_on_failure:
+                log.warning("No valid URLs to process after SSL verification")
+                return
+            raise ValueError("No valid URLs to process after SSL verification")
+        
+        try:
+            loader = TavilyLoader(
+                urls=valid_urls,
+                api_key=self.api_key,
+                extract_depth=self.extract_depth,
+                continue_on_failure=self.continue_on_failure,
+            )
+            async for document in loader.alazy_load():
+                yield document
+        except Exception as e:
+            if self.continue_on_failure:
+                log.exception(e, "Error loading URLs")
+            else:
+                raise e


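Note that TavilyLoader itself only implements lazy_load; the async path above works because langchain_core's BaseLoader provides a default alazy_load that delegates to lazy_load. A sketch of using the safe wrapper directly (the key and rate are illustrative):

    loader = SafeTavilyLoader(
        web_paths=["https://example.com/docs"],
        api_key="tvly-...",       # placeholder key
        extract_depth="advanced",
        requests_per_second=0.5,  # at most one URL safety check every 2 seconds
        verify_ssl=True,
    )
    docs = list(loader.lazy_load())
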
-class SafePlaywrightURLLoader(PlaywrightURLLoader):
+class SafePlaywrightURLLoader(PlaywrightURLLoader, RateLimitMixin, URLProcessingMixin):
     """Load HTML pages safely with Playwright, supporting SSL verification, rate limiting, and remote browser connection.
     """Load HTML pages safely with Playwright, supporting SSL verification, rate limiting, and remote browser connection.
 
 
     Attributes:
     Attributes:
@@ -356,40 +475,6 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader):
                     raise e
             await browser.close()

-    def _verify_ssl_cert(self, url: str) -> bool:
-        return verify_ssl_cert(url)
-
-    async def _wait_for_rate_limit(self):
-        """Wait to respect the rate limit if specified."""
-        if self.requests_per_second and self.last_request_time:
-            min_interval = timedelta(seconds=1.0 / self.requests_per_second)
-            time_since_last = datetime.now() - self.last_request_time
-            if time_since_last < min_interval:
-                await asyncio.sleep((min_interval - time_since_last).total_seconds())
-        self.last_request_time = datetime.now()
-
-    def _sync_wait_for_rate_limit(self):
-        """Synchronous version of rate limit wait."""
-        if self.requests_per_second and self.last_request_time:
-            min_interval = timedelta(seconds=1.0 / self.requests_per_second)
-            time_since_last = datetime.now() - self.last_request_time
-            if time_since_last < min_interval:
-                time.sleep((min_interval - time_since_last).total_seconds())
-        self.last_request_time = datetime.now()
-
-    async def _safe_process_url(self, url: str) -> bool:
-        """Perform safety checks before processing a URL."""
-        if self.verify_ssl and not self._verify_ssl_cert(url):
-            raise ValueError(f"SSL certificate verification failed for {url}")
-        await self._wait_for_rate_limit()
-        return True
-
-    def _safe_process_url_sync(self, url: str) -> bool:
-        """Synchronous version of safety checks."""
-        if self.verify_ssl and not self._verify_ssl_cert(url):
-            raise ValueError(f"SSL certificate verification failed for {url}")
-        self._sync_wait_for_rate_limit()
-        return True


 class SafeWebBaseLoader(WebBaseLoader):
@@ -499,6 +584,7 @@ RAG_WEB_LOADER_ENGINES = defaultdict(lambda: SafeWebBaseLoader)
 RAG_WEB_LOADER_ENGINES["playwright"] = SafePlaywrightURLLoader
 RAG_WEB_LOADER_ENGINES["playwright"] = SafePlaywrightURLLoader
 RAG_WEB_LOADER_ENGINES["safe_web"] = SafeWebBaseLoader
 RAG_WEB_LOADER_ENGINES["safe_web"] = SafeWebBaseLoader
 RAG_WEB_LOADER_ENGINES["firecrawl"] = SafeFireCrawlLoader
 RAG_WEB_LOADER_ENGINES["firecrawl"] = SafeFireCrawlLoader
+RAG_WEB_LOADER_ENGINES["tavily"] = SafeTavilyLoader
 
 
 
 
 def get_web_loader(
 def get_web_loader(
@@ -525,6 +611,9 @@ def get_web_loader(
         web_loader_args["api_key"] = FIRECRAWL_API_KEY.value
         web_loader_args["api_key"] = FIRECRAWL_API_KEY.value
         web_loader_args["api_url"] = FIRECRAWL_API_BASE_URL.value
         web_loader_args["api_url"] = FIRECRAWL_API_BASE_URL.value
 
 
+    if RAG_WEB_LOADER_ENGINE.value == "tavily":
+        web_loader_args["api_key"] = TAVILY_API_KEY.value
+
     # Create the appropriate WebLoader based on the configuration
     WebLoaderClass = RAG_WEB_LOADER_ENGINES[RAG_WEB_LOADER_ENGINE.value]
     web_loader = WebLoaderClass(**web_loader_args)
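
With the engine registered and the key wired through, switching to Tavily becomes a configuration change (assuming RAG_WEB_LOADER_ENGINE and TAVILY_API_KEY are surfaced as settings, as with the other engines). The get_web_loader signature is truncated in these hunks, so the call below is indicative only:

    # Config (identifiers match the imports above):
    #   RAG_WEB_LOADER_ENGINE="tavily"
    #   TAVILY_API_KEY="tvly-..."

    from open_webui.retrieval.web.utils import get_web_loader

    loader = get_web_loader(["https://example.com"])  # arguments are indicative
    docs = loader.load()  # BaseLoader.load() collects lazy_load() into a list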