import asyncio
import logging
import socket
import urllib.parse
from typing import Any, AsyncIterator, Dict, Iterator, List, Sequence, Union

import aiohttp
import validators
from langchain_community.document_loaders import (
    WebBaseLoader,
)
from langchain_core.documents import Document

from open_webui.constants import ERROR_MESSAGES
from open_webui.config import ENABLE_RAG_LOCAL_WEB_FETCH
from open_webui.env import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


def validate_url(url: Union[str, Sequence[str]]):
    """Return True for a valid URL (or sequence of URLs), False for other input.

    Raises ValueError for malformed URLs and, when ENABLE_RAG_LOCAL_WEB_FETCH is
    disabled, for URLs that resolve to private IP addresses.
    """
    if isinstance(url, str):
        if isinstance(validators.url(url), validators.ValidationError):
            raise ValueError(ERROR_MESSAGES.INVALID_URL)
        if not ENABLE_RAG_LOCAL_WEB_FETCH:
            # Local web fetch is disabled, so reject URLs that resolve to private IP addresses
            parsed_url = urllib.parse.urlparse(url)
            # Get IPv4 and IPv6 addresses
            ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
            # Check if any of the resolved addresses are private
            # This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
            for ip in ipv4_addresses:
                if validators.ipv4(ip, private=True):
                    raise ValueError(ERROR_MESSAGES.INVALID_URL)
            for ip in ipv6_addresses:
                if validators.ipv6(ip, private=True):
                    raise ValueError(ERROR_MESSAGES.INVALID_URL)
        return True
    elif isinstance(url, Sequence):
        return all(validate_url(u) for u in url)
    else:
        return False
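

# Example: validate_url("not-a-url") raises ValueError(ERROR_MESSAGES.INVALID_URL),
# a list such as ["https://example.com", "https://example.org"] is validated
# element-wise, and any non-string, non-sequence input returns False.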


def safe_validate_urls(url: Sequence[str]) -> Sequence[str]:
    """Return only the URLs that pass validate_url, silently dropping the rest."""
    valid_urls = []
    for u in url:
        try:
            if validate_url(u):
                valid_urls.append(u)
        except ValueError:
            continue
    return valid_urls


def resolve_hostname(hostname):
    # Get address information
    addr_info = socket.getaddrinfo(hostname, None)

    # Extract IP addresses from address information
    ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET]
    ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6]

    return ipv4_addresses, ipv6_addresses
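

# Example: resolve_hostname("localhost") returns the resolved IPv4 and IPv6
# loopback addresses (e.g. ["127.0.0.1"], ["::1"]), which validate_url then
# rejects as private when ENABLE_RAG_LOCAL_WEB_FETCH is disabled.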


class SafeWebBaseLoader(WebBaseLoader):
    """WebBaseLoader with enhanced error handling for URLs."""

    def __init__(self, trust_env: bool = False, *args, **kwargs):
        """Initialize SafeWebBaseLoader.

        Args:
            trust_env (bool, optional): set to True if using a proxy to make web requests,
                for example using http(s)_proxy environment variables. Defaults to False.
        """
        super().__init__(*args, **kwargs)
        self.trust_env = trust_env

    async def _fetch(
        self, url: str, retries: int = 3, cooldown: int = 2, backoff: float = 1.5
    ) -> str:
        async with aiohttp.ClientSession(trust_env=self.trust_env) as session:
            for i in range(retries):
                try:
                    kwargs: Dict = dict(
                        headers=self.session.headers,
                        cookies=self.session.cookies.get_dict(),
                    )
                    if not self.session.verify:
                        kwargs["ssl"] = False
                    async with session.get(
                        url, **(self.requests_kwargs | kwargs)
                    ) as response:
                        if self.raise_for_status:
                            response.raise_for_status()
                        return await response.text()
                except aiohttp.ClientConnectionError as e:
                    if i == retries - 1:
                        raise
                    else:
                        log.warning(
                            f"Error fetching {url} with attempt "
                            f"{i + 1}/{retries}: {e}. Retrying..."
                        )
                        await asyncio.sleep(cooldown * backoff**i)
        raise ValueError("retry count exceeded")
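
    # With the default arguments (retries=3, cooldown=2, backoff=1.5), _fetch waits
    # 2s and then 3s between connection-error retries before the final attempt
    # either succeeds or re-raises the error.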

    def _unpack_fetch_results(
        self, results: Any, urls: List[str], parser: Union[str, None] = None
    ) -> List[Any]:
        """Unpack fetch results into BeautifulSoup objects."""
        from bs4 import BeautifulSoup

        final_results = []
        for i, result in enumerate(results):
            url = urls[i]
            # Pick the parser per URL: an explicit `parser` argument wins, otherwise
            # use "xml" for .xml documents and the loader default for everything else.
            result_parser = parser
            if result_parser is None:
                result_parser = "xml" if url.endswith(".xml") else self.default_parser
            self._check_parser(result_parser)
            final_results.append(BeautifulSoup(result, result_parser, **self.bs_kwargs))
        return final_results

    async def ascrape_all(
        self, urls: List[str], parser: Union[str, None] = None
    ) -> List[Any]:
        """Async fetch all urls, then return soups for all results."""
        results = await self.fetch_all(urls)
        return self._unpack_fetch_results(results, urls, parser=parser)

    def lazy_load(self) -> Iterator[Document]:
        """Lazy load text from the url(s) in web_path with error handling."""
        for path in self.web_paths:
            try:
                soup = self._scrape(path, bs_kwargs=self.bs_kwargs)
                text = soup.get_text(**self.bs_get_text_kwargs)

                # Build metadata
                metadata = {"source": path}
                if title := soup.find("title"):
                    metadata["title"] = title.get_text()
                if description := soup.find("meta", attrs={"name": "description"}):
                    metadata["description"] = description.get(
                        "content", "No description found."
                    )
                if html := soup.find("html"):
                    metadata["language"] = html.get("lang", "No language found.")

                yield Document(page_content=text, metadata=metadata)
            except Exception as e:
                # Log the error and continue with the next URL
                log.error(f"Error loading {path}: {e}")

    async def alazy_load(self) -> AsyncIterator[Document]:
        """Async lazy load text from the url(s) in web_path."""
        results = await self.ascrape_all(self.web_paths)
        for path, soup in zip(self.web_paths, results):
            text = soup.get_text(**self.bs_get_text_kwargs)
            metadata = {"source": path}
            if title := soup.find("title"):
                metadata["title"] = title.get_text()
            if description := soup.find("meta", attrs={"name": "description"}):
                metadata["description"] = description.get(
                    "content", "No description found."
                )
            if html := soup.find("html"):
                metadata["language"] = html.get("lang", "No language found.")
            yield Document(page_content=text, metadata=metadata)

    async def aload(self) -> list[Document]:
        """Load data into Document objects."""
        return [document async for document in self.alazy_load()]


def get_web_loader(
    urls: Union[str, Sequence[str]],
    verify_ssl: bool = True,
    requests_per_second: int = 2,
    trust_env: bool = False,
):
    # Keep only the URLs that pass validation
    safe_urls = safe_validate_urls([urls] if isinstance(urls, str) else urls)

    return SafeWebBaseLoader(
        web_path=safe_urls,
        verify_ssl=verify_ssl,
        requests_per_second=requests_per_second,
        continue_on_failure=True,
        trust_env=trust_env,
    )
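

# Usage sketch (illustrative; the URLs below are placeholders): build a loader
# for a couple of pages and print how much text each document yields.
if __name__ == "__main__":

    async def _demo() -> None:
        loader = get_web_loader(
            ["https://example.com", "https://example.org"],
            requests_per_second=2,
        )
        for doc in await loader.aload():
            print(doc.metadata.get("source"), len(doc.page_content), "characters")

    asyncio.run(_demo())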