@@ -5,6 +5,7 @@ import urllib.parse
 import validators
 from typing import Any, AsyncIterator, Dict, Iterator, List, Sequence, Union
 
+import aiohttp
 from langchain_community.document_loaders import (
     WebBaseLoader,
 )
@@ -70,6 +71,45 @@ def resolve_hostname(hostname):
 class SafeWebBaseLoader(WebBaseLoader):
     """WebBaseLoader with enhanced error handling for URLs."""
 
+    def __init__(self, trust_env: bool = False, *args, **kwargs):
+        """Initialize SafeWebBaseLoader
+        Args:
+            trust_env (bool, optional): set to True if using a proxy to make web requests, for example
+                via the http(s)_proxy environment variables. Defaults to False.
+        """
+        super().__init__(*args, **kwargs)
+        self.trust_env = trust_env
+
+    async def _fetch(
+        self, url: str, retries: int = 3, cooldown: int = 2, backoff: float = 1.5
+    ) -> str:
+        async with aiohttp.ClientSession(trust_env=self.trust_env) as session:
+            for i in range(retries):
+                try:
+                    kwargs: Dict = dict(
+                        headers=self.session.headers,
+                        cookies=self.session.cookies.get_dict(),
+                    )
+                    if not self.session.verify:
+                        kwargs["ssl"] = False
+
+                    async with session.get(
+                        url, **(self.requests_kwargs | kwargs)
+                    ) as response:
+                        if self.raise_for_status:
+                            response.raise_for_status()
+                        return await response.text()
+                except aiohttp.ClientConnectionError as e:
+                    if i == retries - 1:
+                        raise
+                    else:
+                        log.warning(
+                            f"Error fetching {url} with attempt "
+                            f"{i + 1}/{retries}: {e}. Retrying..."
+                        )
+                        await asyncio.sleep(cooldown * backoff**i)
+        raise ValueError("retry count exceeded")
+
     def _unpack_fetch_results(
         self, results: Any, urls: List[str], parser: Union[str, None] = None
     ) -> List[Any]:
@@ -95,6 +135,7 @@ class SafeWebBaseLoader(WebBaseLoader):
         results = await self.fetch_all(urls)
         return self._unpack_fetch_results(results, urls, parser=parser)
 
+
     def lazy_load(self) -> Iterator[Document]:
         """Lazy load text from the url(s) in web_path with error handling."""
         for path in self.web_paths:
@@ -143,13 +184,15 @@ def get_web_loader(
     urls: Union[str, Sequence[str]],
     verify_ssl: bool = True,
     requests_per_second: int = 2,
+    trust_env: bool = False,
 ):
     # Check if the URLs are valid
     safe_urls = safe_validate_urls([urls] if isinstance(urls, str) else urls)
 
     return SafeWebBaseLoader(
-        safe_urls,
+        web_path=safe_urls,
         verify_ssl=verify_ssl,
         requests_per_second=requests_per_second,
         continue_on_failure=True,
+        trust_env=trust_env
     )
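
Usage sketch (not part of the patch): one way the new trust_env flag could be exercised, assuming get_web_loader is imported from the module this diff modifies; the proxy address and target URL below are purely illustrative.

    import asyncio
    import os

    # Illustrative proxy endpoint; with trust_env=True, aiohttp's ClientSession
    # picks up the standard http(s)_proxy environment variables.
    os.environ["https_proxy"] = "http://proxy.internal:3128"

    loader = get_web_loader(
        "https://example.com",
        verify_ssl=True,
        requests_per_second=2,
        trust_env=True,
    )

    # fetch_all() drives the overridden _fetch(), i.e. the aiohttp code path
    # that honors trust_env; it returns the raw page contents as strings.
    pages = asyncio.run(loader.fetch_all(loader.web_paths))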