# utils.py

import asyncio
import logging
import socket
import ssl
import time
import urllib.parse
from collections import defaultdict
from datetime import datetime, timedelta
from typing import AsyncIterator, Dict, Iterator, List, Optional, Sequence, Union

import certifi
import validators
from langchain_community.document_loaders import (
    PlaywrightURLLoader,
    WebBaseLoader,
)
from langchain_core.documents import Document

from open_webui.config import ENABLE_RAG_LOCAL_WEB_FETCH, RAG_WEB_LOADER
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

def validate_url(url: Union[str, Sequence[str]]):
    if isinstance(url, str):
        if isinstance(validators.url(url), validators.ValidationError):
            raise ValueError(ERROR_MESSAGES.INVALID_URL)
        if not ENABLE_RAG_LOCAL_WEB_FETCH:
            # Local web fetch is disabled, so reject any URL that resolves to a private IP address
            parsed_url = urllib.parse.urlparse(url)
            # Get IPv4 and IPv6 addresses
            ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
            # Check if any of the resolved addresses are private
            # This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
            for ip in ipv4_addresses:
                if validators.ipv4(ip, private=True):
                    raise ValueError(ERROR_MESSAGES.INVALID_URL)
            for ip in ipv6_addresses:
                if validators.ipv6(ip, private=True):
                    raise ValueError(ERROR_MESSAGES.INVALID_URL)
        return True
    elif isinstance(url, Sequence):
        return all(validate_url(u) for u in url)
    else:
        return False

def safe_validate_urls(url: Sequence[str]) -> Sequence[str]:
    valid_urls = []
    for u in url:
        try:
            if validate_url(u):
                valid_urls.append(u)
        except ValueError:
            continue
    return valid_urls

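# Example (illustrative sketch, not executed on import): with ENABLE_RAG_LOCAL_WEB_FETCH
# disabled, safe_validate_urls() keeps only entries that pass validate_url() and silently
# drops malformed URLs or URLs that resolve to private addresses, e.g.:
#
#     safe_validate_urls(["https://example.com/docs", "http://192.168.0.1/admin", "not a url"])
#     # -> ["https://example.com/docs"]
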
def resolve_hostname(hostname):
    # Get address information
    addr_info = socket.getaddrinfo(hostname, None)

    # Extract IP addresses from address information
    ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET]
    ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6]

    return ipv4_addresses, ipv6_addresses

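# Example (illustrative; the exact addresses depend on the local resolver):
#
#     resolve_hostname("localhost")
#     # -> (["127.0.0.1"], ["::1"]) on a typical host
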
def extract_metadata(soup, url):
    metadata = {"source": url}
    if title := soup.find("title"):
        metadata["title"] = title.get_text()
    if description := soup.find("meta", attrs={"name": "description"}):
        metadata["description"] = description.get("content", "No description found.")
    if html := soup.find("html"):
        metadata["language"] = html.get("lang", "No language found.")
    return metadata

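# Example (illustrative sketch): for a BeautifulSoup `soup` parsed from
#     <html lang="en"><head><title>Docs</title>
#     <meta name="description" content="API reference"></head></html>
# extract_metadata(soup, "https://example.com/docs") returns roughly:
#     {"source": "https://example.com/docs", "title": "Docs",
#      "description": "API reference", "language": "en"}
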
class SafePlaywrightURLLoader(PlaywrightURLLoader):
    """Load HTML pages safely with Playwright, supporting SSL verification and rate limiting.

    Attributes:
        urls (List[str]): List of URLs to load.
        verify_ssl (bool): If True, verify SSL certificates.
        requests_per_second (Optional[float]): Number of requests per second to limit to.
        continue_on_failure (bool): If True, continue loading other URLs on failure.
        headless (bool): If True, the browser will run in headless mode.
    """

    def __init__(
        self,
        urls: List[str],
        verify_ssl: bool = True,
        requests_per_second: Optional[float] = None,
        continue_on_failure: bool = True,
        headless: bool = True,
        remove_selectors: Optional[List[str]] = None,
        proxy: Optional[Dict[str, str]] = None,
    ):
        """Initialize with additional safety parameters."""
        super().__init__(
            urls=urls,
            continue_on_failure=continue_on_failure,
            headless=headless,
            remove_selectors=remove_selectors,
            proxy=proxy,
        )
        self.verify_ssl = verify_ssl
        self.requests_per_second = requests_per_second
        self.last_request_time = None

    def _verify_ssl_cert(self, url: str) -> bool:
        """Verify the SSL certificate for the given URL."""
        if not url.startswith("https://"):
            return True

        try:
            hostname = url.split("://")[-1].split("/")[0]
            context = ssl.create_default_context(cafile=certifi.where())
            with context.wrap_socket(socket.socket(), server_hostname=hostname) as s:
                s.connect((hostname, 443))
            return True
        except ssl.SSLError:
            return False
        except Exception as e:
            log.warning(f"SSL verification failed for {url}: {str(e)}")
            return False

    async def _wait_for_rate_limit(self):
        """Wait to respect the rate limit if specified."""
        if self.requests_per_second and self.last_request_time:
            min_interval = timedelta(seconds=1.0 / self.requests_per_second)
            time_since_last = datetime.now() - self.last_request_time
            if time_since_last < min_interval:
                await asyncio.sleep((min_interval - time_since_last).total_seconds())
        self.last_request_time = datetime.now()

    def _sync_wait_for_rate_limit(self):
        """Synchronous version of the rate limit wait."""
        if self.requests_per_second and self.last_request_time:
            min_interval = timedelta(seconds=1.0 / self.requests_per_second)
            time_since_last = datetime.now() - self.last_request_time
            if time_since_last < min_interval:
                time.sleep((min_interval - time_since_last).total_seconds())
        self.last_request_time = datetime.now()

    async def _safe_process_url(self, url: str) -> bool:
        """Perform safety checks before processing a URL."""
        if self.verify_ssl and not self._verify_ssl_cert(url):
            raise ValueError(f"SSL certificate verification failed for {url}")
        await self._wait_for_rate_limit()
        return True

    def _safe_process_url_sync(self, url: str) -> bool:
        """Synchronous version of the safety checks."""
        if self.verify_ssl and not self._verify_ssl_cert(url):
            raise ValueError(f"SSL certificate verification failed for {url}")
        self._sync_wait_for_rate_limit()
        return True

    async def alazy_load(self) -> AsyncIterator[Document]:
        """Safely load URLs asynchronously."""
        parent_iterator = super().alazy_load()
        async for document in parent_iterator:
            url = document.metadata["source"]
            try:
                await self._safe_process_url(url)
                yield document
            except Exception as e:
                if self.continue_on_failure:
                    log.exception("Error loading %s", url)
                    continue
                raise e

    def lazy_load(self) -> Iterator[Document]:
        """Safely load URLs synchronously."""
        parent_iterator = super().lazy_load()
        for document in parent_iterator:
            url = document.metadata["source"]
            try:
                self._safe_process_url_sync(url)
                yield document
            except Exception as e:
                if self.continue_on_failure:
                    log.exception("Error loading %s", url)
                    continue
                raise e

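# Usage sketch (illustrative only; the URL and rate are placeholder values):
#
#     loader = SafePlaywrightURLLoader(
#         urls=["https://example.com"],
#         verify_ssl=True,
#         requests_per_second=2.0,
#     )
#     docs = list(loader.lazy_load())                     # synchronous path
#     # or, inside an event loop:
#     # docs = [doc async for doc in loader.alazy_load()]
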
class SafeWebBaseLoader(WebBaseLoader):
    """WebBaseLoader with enhanced error handling for URLs."""

    def lazy_load(self) -> Iterator[Document]:
        """Lazy load text from the url(s) in web_path with error handling."""
        for path in self.web_paths:
            try:
                soup = self._scrape(path, bs_kwargs=self.bs_kwargs)
                text = soup.get_text(**self.bs_get_text_kwargs)

                # Build metadata
                metadata = extract_metadata(soup, path)

                yield Document(page_content=text, metadata=metadata)
            except Exception:
                # Log the error and continue with the next URL
                log.exception("Error loading %s", path)

RAG_WEB_LOADERS = defaultdict(lambda: SafeWebBaseLoader)
RAG_WEB_LOADERS["playwright"] = SafePlaywrightURLLoader
RAG_WEB_LOADERS["safe_web"] = SafeWebBaseLoader


def get_web_loader(
    urls: Union[str, Sequence[str]],
    verify_ssl: bool = True,
    requests_per_second: int = 2,
):
    # Check if the URLs are valid
    safe_urls = safe_validate_urls([urls] if isinstance(urls, str) else urls)

    # Get the appropriate WebLoader based on the configuration
    WebLoaderClass = RAG_WEB_LOADERS[RAG_WEB_LOADER.value]
    web_loader = WebLoaderClass(
        safe_urls,
        verify_ssl=verify_ssl,
        requests_per_second=requests_per_second,
        continue_on_failure=True,
    )

    log.debug(
        "Using RAG_WEB_LOADER %s for %s URLs",
        web_loader.__class__.__name__,
        len(safe_urls),
    )

    return web_loader
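

# Usage sketch (illustrative only; the URL is a placeholder). The loader class is picked
# from RAG_WEB_LOADERS by the RAG_WEB_LOADER config value ("playwright" or "safe_web",
# with SafeWebBaseLoader as the default for any other value), after invalid or private
# URLs have been filtered out by safe_validate_urls():
#
#     loader = get_web_loader("https://example.com/docs", requests_per_second=2)
#     documents = loader.load()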