utils.py
import asyncio
import logging
import socket
import ssl
import time
import urllib.parse
import urllib.request
from collections import defaultdict
from datetime import datetime, timedelta
from typing import (
    Any,
    AsyncIterator,
    Dict,
    Iterator,
    List,
    Optional,
    Sequence,
    Union,
)

import aiohttp
import certifi
import validators
from langchain_community.document_loaders import (
    PlaywrightURLLoader,
    WebBaseLoader,
)
from langchain_core.documents import Document

from open_webui.constants import ERROR_MESSAGES
from open_webui.config import ENABLE_RAG_LOCAL_WEB_FETCH, PLAYWRIGHT_WS_URI, RAG_WEB_LOADER
from open_webui.env import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


def validate_url(url: Union[str, Sequence[str]]):
    if isinstance(url, str):
        if isinstance(validators.url(url), validators.ValidationError):
            raise ValueError(ERROR_MESSAGES.INVALID_URL)
        if not ENABLE_RAG_LOCAL_WEB_FETCH:
            # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
            parsed_url = urllib.parse.urlparse(url)
            # Get IPv4 and IPv6 addresses
            ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
            # Check if any of the resolved addresses are private
            # This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
            for ip in ipv4_addresses:
                if validators.ipv4(ip, private=True):
                    raise ValueError(ERROR_MESSAGES.INVALID_URL)
            for ip in ipv6_addresses:
                if validators.ipv6(ip, private=True):
                    raise ValueError(ERROR_MESSAGES.INVALID_URL)
        return True
    elif isinstance(url, Sequence):
        return all(validate_url(u) for u in url)
    else:
        return False


def safe_validate_urls(url: Sequence[str]) -> Sequence[str]:
    valid_urls = []
    for u in url:
        try:
            if validate_url(u):
                valid_urls.append(u)
        except ValueError:
            continue
    return valid_urls


def resolve_hostname(hostname):
    # Get address information
    addr_info = socket.getaddrinfo(hostname, None)

    # Extract IP addresses from address information
    ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET]
    ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6]

    return ipv4_addresses, ipv6_addresses
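

# Illustrative sketch (not part of the original module): how the validation helpers above
# behave together. The URLs and the result shown are hypothetical, and the outcome also
# depends on whether ENABLE_RAG_LOCAL_WEB_FETCH permits private addresses.
#
#     urls = ["https://example.com", "http://127.0.0.1:8080/admin", "not-a-url"]
#     safe_validate_urls(urls)
#     # -> ["https://example.com"]  (the private address and the malformed URL are dropped
#     #    when local web fetch is disabled)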


def extract_metadata(soup, url):
    metadata = {"source": url}
    if title := soup.find("title"):
        metadata["title"] = title.get_text()
    if description := soup.find("meta", attrs={"name": "description"}):
        metadata["description"] = description.get("content", "No description found.")
    if html := soup.find("html"):
        metadata["language"] = html.get("lang", "No language found.")
    return metadata


class SafePlaywrightURLLoader(PlaywrightURLLoader):
    """Load HTML pages safely with Playwright, supporting SSL verification, rate limiting, and remote browser connection.

    Attributes:
        web_paths (List[str]): List of URLs to load.
        verify_ssl (bool): If True, verify SSL certificates.
        trust_env (bool): If True, use proxy settings from environment variables.
        requests_per_second (Optional[float]): Number of requests per second to limit to.
        continue_on_failure (bool): If True, continue loading other URLs on failure.
        headless (bool): If True, the browser will run in headless mode.
        proxy (dict): Proxy override settings for the Playwright session.
        playwright_ws_url (Optional[str]): WebSocket endpoint URI for remote browser connection.
    """

    def __init__(
        self,
        web_paths: List[str],
        verify_ssl: bool = True,
        trust_env: bool = False,
        requests_per_second: Optional[float] = None,
        continue_on_failure: bool = True,
        headless: bool = True,
        remove_selectors: Optional[List[str]] = None,
        proxy: Optional[Dict[str, str]] = None,
        playwright_ws_url: Optional[str] = None,
    ):
        """Initialize with additional safety parameters and remote browser support."""
        proxy_server = proxy.get("server") if proxy else None
        if trust_env and not proxy_server:
            env_proxies = urllib.request.getproxies()
            env_proxy_server = env_proxies.get("https") or env_proxies.get("http")
            if env_proxy_server:
                if proxy:
                    proxy["server"] = env_proxy_server
                else:
                    proxy = {"server": env_proxy_server}

        # We'll set headless to False if using playwright_ws_url since it's handled by the remote browser
        super().__init__(
            urls=web_paths,
            continue_on_failure=continue_on_failure,
            headless=headless if playwright_ws_url is None else False,
            remove_selectors=remove_selectors,
            proxy=proxy,
        )
        self.verify_ssl = verify_ssl
        self.requests_per_second = requests_per_second
        self.last_request_time = None
        self.playwright_ws_url = playwright_ws_url
        self.trust_env = trust_env

    def lazy_load(self) -> Iterator[Document]:
        """Safely load URLs synchronously with support for remote browser."""
        from playwright.sync_api import sync_playwright

        with sync_playwright() as p:
            # Use remote browser if ws_endpoint is provided, otherwise use local browser
            if self.playwright_ws_url:
                browser = p.chromium.connect(self.playwright_ws_url)
            else:
                browser = p.chromium.launch(headless=self.headless, proxy=self.proxy)

            for url in self.urls:
                try:
                    self._safe_process_url_sync(url)
                    page = browser.new_page()
                    response = page.goto(url)
                    if response is None:
                        raise ValueError(f"page.goto() returned None for url {url}")

                    text = self.evaluator.evaluate(page, browser, response)
                    metadata = {"source": url}
                    yield Document(page_content=text, metadata=metadata)
                except Exception as e:
                    if self.continue_on_failure:
                        log.exception("Error loading %s", url)
                        continue
                    raise e
            browser.close()

    async def alazy_load(self) -> AsyncIterator[Document]:
        """Safely load URLs asynchronously with support for remote browser."""
        from playwright.async_api import async_playwright

        async with async_playwright() as p:
            # Use remote browser if ws_endpoint is provided, otherwise use local browser
            if self.playwright_ws_url:
                browser = await p.chromium.connect(self.playwright_ws_url)
            else:
                browser = await p.chromium.launch(headless=self.headless, proxy=self.proxy)

            for url in self.urls:
                try:
                    await self._safe_process_url(url)
                    page = await browser.new_page()
                    response = await page.goto(url)
                    if response is None:
                        raise ValueError(f"page.goto() returned None for url {url}")

                    text = await self.evaluator.evaluate_async(page, browser, response)
                    metadata = {"source": url}
                    yield Document(page_content=text, metadata=metadata)
                except Exception as e:
                    if self.continue_on_failure:
                        log.exception("Error loading %s", url)
                        continue
                    raise e
            await browser.close()

    def _verify_ssl_cert(self, url: str) -> bool:
        """Verify SSL certificate for the given URL."""
        if not url.startswith("https://"):
            return True

        try:
            hostname = url.split("://")[-1].split("/")[0]
            context = ssl.create_default_context(cafile=certifi.where())
            with context.wrap_socket(socket.socket(), server_hostname=hostname) as s:
                s.connect((hostname, 443))
            return True
        except ssl.SSLError:
            return False
        except Exception as e:
            log.warning(f"SSL verification failed for {url}: {str(e)}")
            return False

    async def _wait_for_rate_limit(self):
        """Wait to respect the rate limit if specified."""
        if self.requests_per_second and self.last_request_time:
            min_interval = timedelta(seconds=1.0 / self.requests_per_second)
            time_since_last = datetime.now() - self.last_request_time
            if time_since_last < min_interval:
                await asyncio.sleep((min_interval - time_since_last).total_seconds())
        self.last_request_time = datetime.now()

    def _sync_wait_for_rate_limit(self):
        """Synchronous version of rate limit wait."""
        if self.requests_per_second and self.last_request_time:
            min_interval = timedelta(seconds=1.0 / self.requests_per_second)
            time_since_last = datetime.now() - self.last_request_time
            if time_since_last < min_interval:
                time.sleep((min_interval - time_since_last).total_seconds())
        self.last_request_time = datetime.now()

    async def _safe_process_url(self, url: str) -> bool:
        """Perform safety checks before processing a URL."""
        if self.verify_ssl and not self._verify_ssl_cert(url):
            raise ValueError(f"SSL certificate verification failed for {url}")
        await self._wait_for_rate_limit()
        return True

    def _safe_process_url_sync(self, url: str) -> bool:
        """Synchronous version of safety checks."""
        if self.verify_ssl and not self._verify_ssl_cert(url):
            raise ValueError(f"SSL certificate verification failed for {url}")
        self._sync_wait_for_rate_limit()
        return True
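

# Illustrative sketch (assumptions, not part of the original module): driving the loader
# against a remote Chromium instance via its WebSocket endpoint. The ws:// address and the
# URL below are placeholders.
#
#     loader = SafePlaywrightURLLoader(
#         web_paths=["https://example.com"],
#         requests_per_second=1,
#         playwright_ws_url="ws://playwright:3000",
#     )
#     for doc in loader.lazy_load():
#         print(doc.metadata["source"], len(doc.page_content))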


class SafeWebBaseLoader(WebBaseLoader):
    """WebBaseLoader with enhanced error handling for URLs."""

    def __init__(self, trust_env: bool = False, *args, **kwargs):
        """Initialize SafeWebBaseLoader.

        Args:
            trust_env (bool, optional): Set to True if web requests go through a proxy,
                for example one configured via the http(s)_proxy environment variables.
                Defaults to False.
        """
        super().__init__(*args, **kwargs)
        self.trust_env = trust_env

    async def _fetch(
        self, url: str, retries: int = 3, cooldown: int = 2, backoff: float = 1.5
    ) -> str:
        async with aiohttp.ClientSession(trust_env=self.trust_env) as session:
            for i in range(retries):
                try:
                    kwargs: Dict = dict(
                        headers=self.session.headers,
                        cookies=self.session.cookies.get_dict(),
                    )
                    if not self.session.verify:
                        kwargs["ssl"] = False

                    async with session.get(
                        url, **(self.requests_kwargs | kwargs)
                    ) as response:
                        if self.raise_for_status:
                            response.raise_for_status()
                        return await response.text()
                except aiohttp.ClientConnectionError as e:
                    if i == retries - 1:
                        raise
                    else:
                        log.warning(
                            f"Error fetching {url} with attempt "
                            f"{i + 1}/{retries}: {e}. Retrying..."
                        )
                        await asyncio.sleep(cooldown * backoff**i)
        raise ValueError("retry count exceeded")

    def _unpack_fetch_results(
        self, results: Any, urls: List[str], parser: Union[str, None] = None
    ) -> List[Any]:
        """Unpack fetch results into BeautifulSoup objects."""
        from bs4 import BeautifulSoup

        final_results = []
        for i, result in enumerate(results):
            url = urls[i]
            if parser is None:
                if url.endswith(".xml"):
                    parser = "xml"
                else:
                    parser = self.default_parser
                self._check_parser(parser)
            final_results.append(BeautifulSoup(result, parser, **self.bs_kwargs))
        return final_results

    async def ascrape_all(
        self, urls: List[str], parser: Union[str, None] = None
    ) -> List[Any]:
        """Async fetch all urls, then return soups for all results."""
        results = await self.fetch_all(urls)
        return self._unpack_fetch_results(results, urls, parser=parser)

    def lazy_load(self) -> Iterator[Document]:
        """Lazy load text from the url(s) in web_path with error handling."""
        for path in self.web_paths:
            try:
                soup = self._scrape(path, bs_kwargs=self.bs_kwargs)
                text = soup.get_text(**self.bs_get_text_kwargs)

                # Build metadata
                metadata = extract_metadata(soup, path)

                yield Document(page_content=text, metadata=metadata)
            except Exception:
                # Log the error and continue with the next URL
                log.exception("Error loading %s", path)

    async def alazy_load(self) -> AsyncIterator[Document]:
        """Async lazy load text from the url(s) in web_path."""
        results = await self.ascrape_all(self.web_paths)
        for path, soup in zip(self.web_paths, results):
            text = soup.get_text(**self.bs_get_text_kwargs)
            metadata = {"source": path}
            if title := soup.find("title"):
                metadata["title"] = title.get_text()
            if description := soup.find("meta", attrs={"name": "description"}):
                metadata["description"] = description.get(
                    "content", "No description found."
                )
            if html := soup.find("html"):
                metadata["language"] = html.get("lang", "No language found.")
            yield Document(page_content=text, metadata=metadata)

    async def aload(self) -> list[Document]:
        """Load data into Document objects."""
        return [document async for document in self.alazy_load()]


RAG_WEB_LOADERS = defaultdict(lambda: SafeWebBaseLoader)
RAG_WEB_LOADERS["playwright"] = SafePlaywrightURLLoader
RAG_WEB_LOADERS["safe_web"] = SafeWebBaseLoader


def get_web_loader(
    urls: Union[str, Sequence[str]],
    verify_ssl: bool = True,
    requests_per_second: int = 2,
    trust_env: bool = False,
):
    # Check if the URLs are valid
    safe_urls = safe_validate_urls([urls] if isinstance(urls, str) else urls)

    web_loader_args = {
        "web_paths": safe_urls,
        "verify_ssl": verify_ssl,
        "requests_per_second": requests_per_second,
        "continue_on_failure": True,
        "trust_env": trust_env,
    }

    if PLAYWRIGHT_WS_URI.value:
        web_loader_args["playwright_ws_url"] = PLAYWRIGHT_WS_URI.value

    # Create the appropriate WebLoader based on the configuration
    WebLoaderClass = RAG_WEB_LOADERS[RAG_WEB_LOADER.value]
    web_loader = WebLoaderClass(**web_loader_args)

    log.debug(
        "Using RAG_WEB_LOADER %s for %s URLs",
        web_loader.__class__.__name__,
        len(safe_urls),
    )

    return web_loader
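

# Example usage (a minimal sketch; the URL is illustrative and assumes RAG_WEB_LOADER and
# PLAYWRIGHT_WS_URI have been configured elsewhere):
#
#     loader = get_web_loader(
#         ["https://example.com/docs"],
#         verify_ssl=True,
#         requests_per_second=2,
#     )
#     docs = loader.load()  # or: await loader.aload() when SafeWebBaseLoader is selected
#     for doc in docs:
#         print(doc.metadata["source"], doc.metadata.get("title"))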