utils.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. import socket
  2. import aiohttp
  3. import asyncio
  4. import urllib.parse
  5. import validators
  6. from typing import Union, Sequence, Iterator, Dict
  7. from langchain_community.document_loaders import (
  8. WebBaseLoader,
  9. )
  10. from langchain_core.documents import Document
  11. from open_webui.constants import ERROR_MESSAGES
  12. from open_webui.config import ENABLE_RAG_LOCAL_WEB_FETCH
  13. from open_webui.env import SRC_LOG_LEVELS
  14. import logging
  15. log = logging.getLogger(__name__)
  16. log.setLevel(SRC_LOG_LEVELS["RAG"])
  17. def validate_url(url: Union[str, Sequence[str]]):
  18. if isinstance(url, str):
  19. if isinstance(validators.url(url), validators.ValidationError):
  20. raise ValueError(ERROR_MESSAGES.INVALID_URL)
  21. if not ENABLE_RAG_LOCAL_WEB_FETCH:
  22. # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
  23. parsed_url = urllib.parse.urlparse(url)
  24. # Get IPv4 and IPv6 addresses
  25. ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
  26. # Check if any of the resolved addresses are private
  27. # This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
  28. for ip in ipv4_addresses:
  29. if validators.ipv4(ip, private=True):
  30. raise ValueError(ERROR_MESSAGES.INVALID_URL)
  31. for ip in ipv6_addresses:
  32. if validators.ipv6(ip, private=True):
  33. raise ValueError(ERROR_MESSAGES.INVALID_URL)
  34. return True
  35. elif isinstance(url, Sequence):
  36. return all(validate_url(u) for u in url)
  37. else:
  38. return False
  39. def safe_validate_urls(url: Sequence[str]) -> Sequence[str]:
  40. valid_urls = []
  41. for u in url:
  42. try:
  43. if validate_url(u):
  44. valid_urls.append(u)
  45. except ValueError:
  46. continue
  47. return valid_urls
  48. def resolve_hostname(hostname):
  49. # Get address information
  50. addr_info = socket.getaddrinfo(hostname, None)
  51. # Extract IP addresses from address information
  52. ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET]
  53. ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6]
  54. return ipv4_addresses, ipv6_addresses
  55. class SafeWebBaseLoader(WebBaseLoader):
  56. """WebBaseLoader with enhanced error handling for URLs."""
  57. def __init__(self, trust_env: bool = False, *args, **kwargs):
  58. """Initialize SafeWebBaseLoader
  59. Args:
  60. trust_env (bool, optional): set to True if using proxy to make web requests, for example
  61. using http(s)_proxy environment variables. Defaults to False.
  62. """
  63. super().__init__(*args, **kwargs)
  64. self.trust_env = trust_env
  65. async def _fetch(
  66. self, url: str, retries: int = 3, cooldown: int = 2, backoff: float = 1.5
  67. ) -> str:
  68. async with aiohttp.ClientSession(trust_env=self.trust_env) as session:
  69. for i in range(retries):
  70. try:
  71. kwargs: Dict = dict(
  72. headers=self.session.headers,
  73. cookies=self.session.cookies.get_dict(),
  74. )
  75. if not self.session.verify:
  76. kwargs["ssl"] = False
  77. async with session.get(
  78. url, **(self.requests_kwargs | kwargs)
  79. ) as response:
  80. if self.raise_for_status:
  81. response.raise_for_status()
  82. return await response.text()
  83. except aiohttp.ClientConnectionError as e:
  84. if i == retries - 1:
  85. raise
  86. else:
  87. log.warning(
  88. f"Error fetching {url} with attempt "
  89. f"{i + 1}/{retries}: {e}. Retrying..."
  90. )
  91. await asyncio.sleep(cooldown * backoff**i)
  92. raise ValueError("retry count exceeded")
  93. def lazy_load(self) -> Iterator[Document]:
  94. """Lazy load text from the url(s) in web_path with error handling."""
  95. for path in self.web_paths:
  96. try:
  97. soup = self._scrape(path, bs_kwargs=self.bs_kwargs)
  98. text = soup.get_text(**self.bs_get_text_kwargs)
  99. # Build metadata
  100. metadata = {"source": path}
  101. if title := soup.find("title"):
  102. metadata["title"] = title.get_text()
  103. if description := soup.find("meta", attrs={"name": "description"}):
  104. metadata["description"] = description.get(
  105. "content", "No description found."
  106. )
  107. if html := soup.find("html"):
  108. metadata["language"] = html.get("lang", "No language found.")
  109. yield Document(page_content=text, metadata=metadata)
  110. except Exception as e:
  111. # Log the error and continue with the next URL
  112. log.error(f"Error loading {path}: {e}")
  113. def get_web_loader(
  114. urls: Union[str, Sequence[str]],
  115. verify_ssl: bool = True,
  116. requests_per_second: int = 2,
  117. trust_env: bool = False,
  118. ):
  119. # Check if the URLs are valid
  120. safe_urls = safe_validate_urls([urls] if isinstance(urls, str) else urls)
  121. return SafeWebBaseLoader(
  122. web_path=safe_urls,
  123. verify_ssl=verify_ssl,
  124. requests_per_second=requests_per_second,
  125. continue_on_failure=True,
  126. trust_env=trust_env
  127. )