Browse Source

feat: Add Tavily extract web loader integration

Luke 2 months ago
parent
commit
987954c817
2 changed files with 254 additions and 67 deletions
  1. 98 0
      backend/open_webui/retrieval/loaders/tavily.py
  2. 156 67
      backend/open_webui/retrieval/web/utils.py

+ 98 - 0
backend/open_webui/retrieval/loaders/tavily.py

@@ -0,0 +1,98 @@
+import requests
+import logging
+from typing import Iterator, List, Literal, Union
+
+from langchain_core.document_loaders import BaseLoader
+from langchain_core.documents import Document
+from open_webui.env import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
+
+class TavilyLoader(BaseLoader):
+    """Extract web page content from URLs using Tavily Extract API.
+    
+    This is a LangChain document loader that uses Tavily's Extract API to
+    retrieve content from web pages and return it as Document objects.
+    
+    Args:
+        urls: URL or list of URLs to extract content from.
+        api_key: The Tavily API key.
+        extract_depth: Depth of extraction, either "basic" or "advanced".
+        continue_on_failure: Whether to continue if extraction of a URL fails.
+    """
+    def __init__(
+        self,
+        urls: Union[str, List[str]],
+        api_key: str,
+        extract_depth: Literal["basic", "advanced"] = "basic",
+        continue_on_failure: bool = True,
+    ) -> None:
+        """Initialize Tavily Extract client.
+        
+        Args:
+            urls: URL or list of URLs to extract content from.
+            api_key: The Tavily API key.
+            include_images: Whether to include images in the extraction.
+            extract_depth: Depth of extraction, either "basic" or "advanced".
+                advanced extraction retrieves more data, including tables and 
+                embedded content, with higher success but may increase latency.
+                basic costs 1 credit per 5 successful URL extractions,
+                advanced costs 2 credits per 5 successful URL extractions.
+            continue_on_failure: Whether to continue if extraction of a URL fails.
+        """
+        if not urls:
+            raise ValueError("At least one URL must be provided.")
+            
+        self.api_key = api_key
+        self.urls = urls if isinstance(urls, list) else [urls]
+        self.extract_depth = extract_depth
+        self.continue_on_failure = continue_on_failure
+        self.api_url = "https://api.tavily.com/extract"
+        
+    def lazy_load(self) -> Iterator[Document]:
+        """Extract and yield documents from the URLs using Tavily Extract API."""
+        batch_size = 20
+        for i in range(0, len(self.urls), batch_size):
+            batch_urls = self.urls[i:i + batch_size]
+            try:
+                headers = {
+                    "Content-Type": "application/json",
+                    "Authorization": f"Bearer {self.api_key}"
+                }
+                # Use string for single URL, array for multiple URLs
+                urls_param = batch_urls[0] if len(batch_urls) == 1 else batch_urls
+                payload = {
+                    "urls": urls_param,
+                    "extract_depth": self.extract_depth
+                }
+                # Make the API call
+                response = requests.post(
+                    self.api_url,
+                    headers=headers,
+                    json=payload
+                )
+                response.raise_for_status()
+                response_data = response.json()
+                # Process successful results
+                for result in response_data.get("results", []):
+                    url = result.get("url", "")
+                    content = result.get("raw_content", "")
+                    if not content:
+                        log.warning(f"No content extracted from {url}")
+                        continue
+                    # Add URLs as metadata
+                    metadata = {"source": url}
+                    yield Document(
+                        page_content=content,
+                        metadata=metadata,
+                    )
+                for failed in response_data.get("failed_results", []):
+                    url = failed.get("url", "")
+                    error = failed.get("error", "Unknown error")
+                    log.error(f"Failed to extract content from {url}: {error}")
+            except Exception as e:
+                if self.continue_on_failure:
+                    log.error(f"Error extracting content from batch {batch_urls}: {e}")
+                else:
+                    raise e

+ 156 - 67
backend/open_webui/retrieval/web/utils.py

@@ -24,6 +24,7 @@ from langchain_community.document_loaders import PlaywrightURLLoader, WebBaseLoa
 from langchain_community.document_loaders.firecrawl import FireCrawlLoader
 from langchain_community.document_loaders.base import BaseLoader
 from langchain_core.documents import Document
+from open_webui.retrieval.loaders.tavily import TavilyLoader
 from open_webui.constants import ERROR_MESSAGES
 from open_webui.config import (
     ENABLE_RAG_LOCAL_WEB_FETCH,
@@ -31,6 +32,7 @@ from open_webui.config import (
     RAG_WEB_LOADER_ENGINE,
     FIRECRAWL_API_BASE_URL,
     FIRECRAWL_API_KEY,
+    TAVILY_API_KEY,
 )
 from open_webui.env import SRC_LOG_LEVELS
 
@@ -113,7 +115,47 @@ def verify_ssl_cert(url: str) -> bool:
         return False
 
 
-class SafeFireCrawlLoader(BaseLoader):
+class RateLimitMixin:
+    async def _wait_for_rate_limit(self):
+        """Wait to respect the rate limit if specified."""
+        if self.requests_per_second and self.last_request_time:
+            min_interval = timedelta(seconds=1.0 / self.requests_per_second)
+            time_since_last = datetime.now() - self.last_request_time
+            if time_since_last < min_interval:
+                await asyncio.sleep((min_interval - time_since_last).total_seconds())
+        self.last_request_time = datetime.now()
+
+    def _sync_wait_for_rate_limit(self):
+        """Synchronous version of rate limit wait."""
+        if self.requests_per_second and self.last_request_time:
+            min_interval = timedelta(seconds=1.0 / self.requests_per_second)
+            time_since_last = datetime.now() - self.last_request_time
+            if time_since_last < min_interval:
+                time.sleep((min_interval - time_since_last).total_seconds())
+        self.last_request_time = datetime.now()
+
+
+class URLProcessingMixin:  
+    def _verify_ssl_cert(self, url: str) -> bool:
+        """Verify SSL certificate for a URL."""
+        return verify_ssl_cert(url)
+        
+    async def _safe_process_url(self, url: str) -> bool:
+        """Perform safety checks before processing a URL."""
+        if self.verify_ssl and not self._verify_ssl_cert(url):
+            raise ValueError(f"SSL certificate verification failed for {url}")
+        await self._wait_for_rate_limit()
+        return True
+    
+    def _safe_process_url_sync(self, url: str) -> bool:
+        """Synchronous version of safety checks."""
+        if self.verify_ssl and not self._verify_ssl_cert(url):
+            raise ValueError(f"SSL certificate verification failed for {url}")
+        self._sync_wait_for_rate_limit()
+        return True
+
+
+class SafeFireCrawlLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
     def __init__(
         self,
         web_paths,
@@ -208,43 +250,120 @@ class SafeFireCrawlLoader(BaseLoader):
                     continue
                 raise e
 
-    def _verify_ssl_cert(self, url: str) -> bool:
-        return verify_ssl_cert(url)
-
-    async def _wait_for_rate_limit(self):
-        """Wait to respect the rate limit if specified."""
-        if self.requests_per_second and self.last_request_time:
-            min_interval = timedelta(seconds=1.0 / self.requests_per_second)
-            time_since_last = datetime.now() - self.last_request_time
-            if time_since_last < min_interval:
-                await asyncio.sleep((min_interval - time_since_last).total_seconds())
-        self.last_request_time = datetime.now()
 
-    def _sync_wait_for_rate_limit(self):
-        """Synchronous version of rate limit wait."""
-        if self.requests_per_second and self.last_request_time:
-            min_interval = timedelta(seconds=1.0 / self.requests_per_second)
-            time_since_last = datetime.now() - self.last_request_time
-            if time_since_last < min_interval:
-                time.sleep((min_interval - time_since_last).total_seconds())
-        self.last_request_time = datetime.now()
+class SafeTavilyLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
+    def __init__(
+        self,
+        web_paths: Union[str, List[str]],
+        api_key: str,
+        extract_depth: Literal["basic", "advanced"] = "basic",
+        continue_on_failure: bool = True,
+        requests_per_second: Optional[float] = None,
+        verify_ssl: bool = True,
+        trust_env: bool = False,
+        proxy: Optional[Dict[str, str]] = None,
+    ):
+        """Initialize SafeTavilyLoader with rate limiting and SSL verification support.
 
-    async def _safe_process_url(self, url: str) -> bool:
-        """Perform safety checks before processing a URL."""
-        if self.verify_ssl and not self._verify_ssl_cert(url):
-            raise ValueError(f"SSL certificate verification failed for {url}")
-        await self._wait_for_rate_limit()
-        return True
+        Args:
+            web_paths: List of URLs/paths to process.
+            api_key: The Tavily API key.
+            extract_depth: Depth of extraction ("basic" or "advanced").
+            continue_on_failure: Whether to continue if extraction of a URL fails.
+            requests_per_second: Number of requests per second to limit to.
+            verify_ssl: If True, verify SSL certificates.
+            trust_env: If True, use proxy settings from environment variables.
+            proxy: Optional proxy configuration.
+        """
+        # Initialize proxy configuration if using environment variables
+        proxy_server = proxy.get("server") if proxy else None
+        if trust_env and not proxy_server:
+            env_proxies = urllib.request.getproxies()
+            env_proxy_server = env_proxies.get("https") or env_proxies.get("http")
+            if env_proxy_server:
+                if proxy:
+                    proxy["server"] = env_proxy_server
+                else:
+                    proxy = {"server": env_proxy_server}
+                    
+        # Store parameters for creating TavilyLoader instances
+        self.web_paths = web_paths if isinstance(web_paths, list) else [web_paths]
+        self.api_key = api_key
+        self.extract_depth = extract_depth
+        self.continue_on_failure = continue_on_failure
+        self.verify_ssl = verify_ssl
+        self.trust_env = trust_env
+        self.proxy = proxy
+        
+        # Add rate limiting
+        self.requests_per_second = requests_per_second
+        self.last_request_time = None
 
-    def _safe_process_url_sync(self, url: str) -> bool:
-        """Synchronous version of safety checks."""
-        if self.verify_ssl and not self._verify_ssl_cert(url):
-            raise ValueError(f"SSL certificate verification failed for {url}")
-        self._sync_wait_for_rate_limit()
-        return True
+    def lazy_load(self) -> Iterator[Document]:
+        """Load documents with rate limiting support, delegating to TavilyLoader."""
+        valid_urls = []
+        for url in self.web_paths:
+            try:
+                self._safe_process_url_sync(url)
+                valid_urls.append(url)
+            except Exception as e:
+                log.warning(f"SSL verification failed for {url}: {str(e)}")
+                if not self.continue_on_failure:
+                    raise e
+        if not valid_urls:
+            if self.continue_on_failure:
+                log.warning("No valid URLs to process after SSL verification")
+                return
+            raise ValueError("No valid URLs to process after SSL verification")
+        try:
+            loader = TavilyLoader(
+                urls=valid_urls,
+                api_key=self.api_key,
+                extract_depth=self.extract_depth,
+                continue_on_failure=self.continue_on_failure,
+            )
+            yield from loader.lazy_load()
+        except Exception as e:
+            if self.continue_on_failure:
+                log.exception(e, "Error extracting content from URLs")
+            else:
+                raise e
+    
+    async def alazy_load(self) -> AsyncIterator[Document]:
+        """Async version with rate limiting and SSL verification."""
+        valid_urls = []
+        for url in self.web_paths:
+            try:
+                await self._safe_process_url(url)
+                valid_urls.append(url)
+            except Exception as e:
+                log.warning(f"SSL verification failed for {url}: {str(e)}")
+                if not self.continue_on_failure:
+                    raise e
+        
+        if not valid_urls:
+            if self.continue_on_failure:
+                log.warning("No valid URLs to process after SSL verification")
+                return
+            raise ValueError("No valid URLs to process after SSL verification")
+        
+        try:
+            loader = TavilyLoader(
+                urls=valid_urls,
+                api_key=self.api_key,
+                extract_depth=self.extract_depth,
+                continue_on_failure=self.continue_on_failure,
+            )
+            async for document in loader.alazy_load():
+                yield document
+        except Exception as e:
+            if self.continue_on_failure:
+                log.exception(e, "Error loading URLs")
+            else:
+                raise e
 
 
-class SafePlaywrightURLLoader(PlaywrightURLLoader):
+class SafePlaywrightURLLoader(PlaywrightURLLoader, RateLimitMixin, URLProcessingMixin):
     """Load HTML pages safely with Playwright, supporting SSL verification, rate limiting, and remote browser connection.
 
     Attributes:
@@ -356,40 +475,6 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader):
                     raise e
             await browser.close()
 
-    def _verify_ssl_cert(self, url: str) -> bool:
-        return verify_ssl_cert(url)
-
-    async def _wait_for_rate_limit(self):
-        """Wait to respect the rate limit if specified."""
-        if self.requests_per_second and self.last_request_time:
-            min_interval = timedelta(seconds=1.0 / self.requests_per_second)
-            time_since_last = datetime.now() - self.last_request_time
-            if time_since_last < min_interval:
-                await asyncio.sleep((min_interval - time_since_last).total_seconds())
-        self.last_request_time = datetime.now()
-
-    def _sync_wait_for_rate_limit(self):
-        """Synchronous version of rate limit wait."""
-        if self.requests_per_second and self.last_request_time:
-            min_interval = timedelta(seconds=1.0 / self.requests_per_second)
-            time_since_last = datetime.now() - self.last_request_time
-            if time_since_last < min_interval:
-                time.sleep((min_interval - time_since_last).total_seconds())
-        self.last_request_time = datetime.now()
-
-    async def _safe_process_url(self, url: str) -> bool:
-        """Perform safety checks before processing a URL."""
-        if self.verify_ssl and not self._verify_ssl_cert(url):
-            raise ValueError(f"SSL certificate verification failed for {url}")
-        await self._wait_for_rate_limit()
-        return True
-
-    def _safe_process_url_sync(self, url: str) -> bool:
-        """Synchronous version of safety checks."""
-        if self.verify_ssl and not self._verify_ssl_cert(url):
-            raise ValueError(f"SSL certificate verification failed for {url}")
-        self._sync_wait_for_rate_limit()
-        return True
 
 
 class SafeWebBaseLoader(WebBaseLoader):
@@ -499,6 +584,7 @@ RAG_WEB_LOADER_ENGINES = defaultdict(lambda: SafeWebBaseLoader)
 RAG_WEB_LOADER_ENGINES["playwright"] = SafePlaywrightURLLoader
 RAG_WEB_LOADER_ENGINES["safe_web"] = SafeWebBaseLoader
 RAG_WEB_LOADER_ENGINES["firecrawl"] = SafeFireCrawlLoader
+RAG_WEB_LOADER_ENGINES["tavily"] = SafeTavilyLoader
 
 
 def get_web_loader(
@@ -525,6 +611,9 @@ def get_web_loader(
         web_loader_args["api_key"] = FIRECRAWL_API_KEY.value
         web_loader_args["api_url"] = FIRECRAWL_API_BASE_URL.value
 
+    if RAG_WEB_LOADER_ENGINE.value == "tavily":
+        web_loader_args["api_key"] = TAVILY_API_KEY.value
+
     # Create the appropriate WebLoader based on the configuration
     WebLoaderClass = RAG_WEB_LOADER_ENGINES[RAG_WEB_LOADER_ENGINE.value]
     web_loader = WebLoaderClass(**web_loader_args)