Timothy J. Baek 10 月之前
父節點
當前提交
1163745a03
共有 1 個文件被更改,包括 22 次插入6 次删除
  1. 22 6
      backend/apps/rag/main.py

+ 22 - 6
backend/apps/rag/main.py

@@ -717,13 +717,18 @@ def validate_url(url: Union[str, Sequence[str]]):
         if isinstance(validators.url(url), validators.ValidationError):
             raise ValueError(ERROR_MESSAGES.INVALID_URL)
         if not ENABLE_RAG_LOCAL_WEB_FETCH:
-            # Check if the URL exists by making a HEAD request
-            try:
-                response = requests.head(url, allow_redirects=True)
-                if response.status_code != 200:
+            # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
+            parsed_url = urllib.parse.urlparse(url)
+            # Get IPv4 and IPv6 addresses
+            ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
+            # Check if any of the resolved addresses are private
+            # This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
+            for ip in ipv4_addresses:
+                if validators.ipv4(ip, private=True):
+                    raise ValueError(ERROR_MESSAGES.INVALID_URL)
+            for ip in ipv6_addresses:
+                if validators.ipv6(ip, private=True):
                     raise ValueError(ERROR_MESSAGES.INVALID_URL)
-            except requests.exceptions.RequestException:
-                raise ValueError(ERROR_MESSAGES.INVALID_URL)
         return True
     elif isinstance(url, Sequence):
         return all(validate_url(u) for u in url)
@@ -731,6 +736,17 @@ def validate_url(url: Union[str, Sequence[str]]):
         return False
 
 
+def resolve_hostname(hostname):
+    # Get address information
+    addr_info = socket.getaddrinfo(hostname, None)
+
+    # Extract IP addresses from address information
+    ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET]
+    ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6]
+
+    return ipv4_addresses, ipv6_addresses
+
+
 def search_web(engine: str, query: str) -> list[SearchResult]:
     """Search the web using a search engine and return the results as a list of SearchResult objects.
     Will look for a search engine API key in environment variables in the following order: