utils.py

import socket
import aiohttp
import asyncio
import urllib.parse
import validators
from typing import Any, AsyncIterator, Dict, Iterator, List, Sequence, Union

from langchain_community.document_loaders import (
    WebBaseLoader,
)
from langchain_core.documents import Document

from open_webui.constants import ERROR_MESSAGES
from open_webui.config import ENABLE_RAG_LOCAL_WEB_FETCH
from open_webui.env import SRC_LOG_LEVELS

import logging

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


def validate_url(url: Union[str, Sequence[str]]):
    """Validate a URL or a sequence of URLs.

    Raises ValueError for malformed URLs and, when local web fetch is disabled,
    for URLs whose hostname resolves to a private IP address.
    """
    if isinstance(url, str):
        if isinstance(validators.url(url), validators.ValidationError):
            raise ValueError(ERROR_MESSAGES.INVALID_URL)
        if not ENABLE_RAG_LOCAL_WEB_FETCH:
            # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
            parsed_url = urllib.parse.urlparse(url)
            # Get IPv4 and IPv6 addresses
            ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
            # Check if any of the resolved addresses are private
            # This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
            for ip in ipv4_addresses:
                if validators.ipv4(ip, private=True):
                    raise ValueError(ERROR_MESSAGES.INVALID_URL)
            for ip in ipv6_addresses:
                if validators.ipv6(ip, private=True):
                    raise ValueError(ERROR_MESSAGES.INVALID_URL)
        return True
    elif isinstance(url, Sequence):
        return all(validate_url(u) for u in url)
    else:
        return False


def safe_validate_urls(url: Sequence[str]) -> Sequence[str]:
    """Return only the URLs that pass validate_url, silently dropping the rest."""
    valid_urls = []
    for u in url:
        try:
            if validate_url(u):
                valid_urls.append(u)
        except ValueError:
            continue
    return valid_urls


def resolve_hostname(hostname):
    """Resolve a hostname to its IPv4 and IPv6 addresses."""
    # Get address information
    addr_info = socket.getaddrinfo(hostname, None)

    # Extract IP addresses from address information
    ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET]
    ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6]

    return ipv4_addresses, ipv6_addresses


class SafeWebBaseLoader(WebBaseLoader):
    """WebBaseLoader with enhanced error handling for URLs."""

    def _unpack_fetch_results(
        self, results: Any, urls: List[str], parser: Union[str, None] = None
    ) -> List[Any]:
        """Unpack fetch results into BeautifulSoup objects."""
        from bs4 import BeautifulSoup

        final_results = []
        for i, result in enumerate(results):
            url = urls[i]
            if parser is None:
                if url.endswith(".xml"):
                    parser = "xml"
                else:
                    parser = self.default_parser
                self._check_parser(parser)
            final_results.append(BeautifulSoup(result, parser, **self.bs_kwargs))
        return final_results

    async def ascrape_all(
        self, urls: List[str], parser: Union[str, None] = None
    ) -> List[Any]:
        """Async fetch all urls, then return soups for all results."""
        results = await self.fetch_all(urls)
        return self._unpack_fetch_results(results, urls, parser=parser)

    def lazy_load(self) -> Iterator[Document]:
        """Lazy load text from the url(s) in web_path with error handling."""
        for path in self.web_paths:
            try:
                soup = self._scrape(path, bs_kwargs=self.bs_kwargs)
                text = soup.get_text(**self.bs_get_text_kwargs)

                # Build metadata
                metadata = {"source": path}
                if title := soup.find("title"):
                    metadata["title"] = title.get_text()
                if description := soup.find("meta", attrs={"name": "description"}):
                    metadata["description"] = description.get(
                        "content", "No description found."
                    )
                if html := soup.find("html"):
                    metadata["language"] = html.get("lang", "No language found.")

                yield Document(page_content=text, metadata=metadata)
            except Exception as e:
                # Log the error and continue with the next URL
                log.error(f"Error loading {path}: {e}")

    async def alazy_load(self) -> AsyncIterator[Document]:
        """Async lazy load text from the url(s) in web_path."""
        results = await self.ascrape_all(self.web_paths)
        for path, soup in zip(self.web_paths, results):
            text = soup.get_text(**self.bs_get_text_kwargs)
            metadata = {"source": path}
            if title := soup.find("title"):
                metadata["title"] = title.get_text()
            if description := soup.find("meta", attrs={"name": "description"}):
                metadata["description"] = description.get(
                    "content", "No description found."
                )
            if html := soup.find("html"):
                metadata["language"] = html.get("lang", "No language found.")
            yield Document(page_content=text, metadata=metadata)

    async def aload(self) -> list[Document]:
        """Load data into Document objects."""
        return [document async for document in self.alazy_load()]


def get_web_loader(
    urls: Union[str, Sequence[str]],
    verify_ssl: bool = True,
    requests_per_second: int = 2,
):
    # Validate the URLs and keep only those that are safe to fetch
    safe_urls = safe_validate_urls([urls] if isinstance(urls, str) else urls)

    return SafeWebBaseLoader(
        safe_urls,
        verify_ssl=verify_ssl,
        requests_per_second=requests_per_second,
        continue_on_failure=True,
    )
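

# --- Hedged usage sketch (an assumption for illustration, not part of the upstream module) ---
# Shows how safe_validate_urls and get_web_loader are expected to fit together.
# The URLs below are placeholders; whether a URL survives filtering depends on
# ENABLE_RAG_LOCAL_WEB_FETCH and on what its hostname resolves to at runtime.
if __name__ == "__main__":
    example_urls = [
        "https://example.com",  # well-formed public URL, expected to pass validation
        "http://192.168.1.10",  # RFC 1918 address, expected to be filtered when local fetch is disabled
        "not-a-url",  # malformed, always filtered
    ]
    print("valid:", safe_validate_urls(example_urls))

    # Build a loader over the same list; unsafe URLs are dropped before loading.
    loader = get_web_loader(example_urls, verify_ssl=True, requests_per_second=2)
    docs = asyncio.run(loader.aload())
    for doc in docs:
        print(doc.metadata.get("source"), len(doc.page_content))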