Ver Fonte

feat: add ENABLE_LOCAL_WEB_FETCH to protect against SSRF attacks

Jun Siang Cheah há 1 ano atrás
pai
commit
1c4e63f71e
4 ficheiros alterados com 45 adições e 1 exclusões
  1. 38 1
      backend/apps/rag/main.py
  2. 2 0
      backend/config.py
  3. 4 0
      backend/constants.py
  4. 1 0
      backend/requirements.txt

+ 38 - 1
backend/apps/rag/main.py

@@ -31,6 +31,11 @@ from langchain_community.document_loaders import (
 )
 )
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 
+import validators
+import urllib.parse
+import socket
+
+
 from pydantic import BaseModel
 from pydantic import BaseModel
 from typing import Optional
 from typing import Optional
 import mimetypes
 import mimetypes
@@ -84,6 +89,7 @@ from config import (
     CHUNK_SIZE,
     CHUNK_SIZE,
     CHUNK_OVERLAP,
     CHUNK_OVERLAP,
     RAG_TEMPLATE,
     RAG_TEMPLATE,
+    ENABLE_LOCAL_WEB_FETCH,
 )
 )
 
 
 from constants import ERROR_MESSAGES
 from constants import ERROR_MESSAGES
@@ -454,7 +460,7 @@ def query_collection_handler(
 def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
 def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
     # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
     # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
     try:
     try:
-        loader = WebBaseLoader(form_data.url)
+        loader = get_web_loader(form_data.url)
         data = loader.load()
         data = loader.load()
 
 
         collection_name = form_data.collection_name
         collection_name = form_data.collection_name
@@ -475,6 +481,37 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
         )
         )
 
 
 
 
+def get_web_loader(url: str):
+    # Check if the URL is valid
+    if isinstance(validators.url(url), validators.ValidationError):
+        raise ValueError(ERROR_MESSAGES.INVALID_URL)
+    if not ENABLE_LOCAL_WEB_FETCH:
+        # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
+        parsed_url = urllib.parse.urlparse(url)
+        # Get IPv4 and IPv6 addresses
+        ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
+        # Check if any of the resolved addresses are private
+        # This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
+        for ip in ipv4_addresses:
+            if validators.ipv4(ip, private=True):
+                raise ValueError(ERROR_MESSAGES.INVALID_URL)
+        for ip in ipv6_addresses:
+            if validators.ipv6(ip, private=True):
+                raise ValueError(ERROR_MESSAGES.INVALID_URL)
+    return WebBaseLoader(url)
+
+
+def resolve_hostname(hostname):
+    # Get address information
+    addr_info = socket.getaddrinfo(hostname, None)
+
+    # Extract IP addresses from address information
+    ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET]
+    ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6]
+
+    return ipv4_addresses, ipv6_addresses
+
+
 def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
 def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
 
 
     text_splitter = RecursiveCharacterTextSplitter(
     text_splitter = RecursiveCharacterTextSplitter(

+ 2 - 0
backend/config.py

@@ -520,6 +520,8 @@ RAG_TEMPLATE = os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE)
 RAG_OPENAI_API_BASE_URL = os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL)
 RAG_OPENAI_API_BASE_URL = os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL)
 RAG_OPENAI_API_KEY = os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY)
 RAG_OPENAI_API_KEY = os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY)
 
 
+ENABLE_LOCAL_WEB_FETCH = os.getenv("ENABLE_LOCAL_WEB_FETCH", "False").lower() == "true"
+
 ####################################
 ####################################
 # Transcribe
 # Transcribe
 ####################################
 ####################################

+ 4 - 0
backend/constants.py

@@ -71,3 +71,7 @@ class ERROR_MESSAGES(str, Enum):
     EMPTY_CONTENT = "The content provided is empty. Please ensure that there is text or data present before proceeding."
     EMPTY_CONTENT = "The content provided is empty. Please ensure that there is text or data present before proceeding."
 
 
     DB_NOT_SQLITE = "This feature is only available when running with SQLite databases."
     DB_NOT_SQLITE = "This feature is only available when running with SQLite databases."
+
+    INVALID_URL = (
+        "Oops! The URL you provided is invalid. Please double-check and try again."
+    )

+ 1 - 0
backend/requirements.txt

@@ -43,6 +43,7 @@ pandas
 openpyxl
 openpyxl
 pyxlsb
 pyxlsb
 xlrd
 xlrd
+validators
 
 
 opencv-python-headless
 opencv-python-headless
 rapidocr-onnxruntime
 rapidocr-onnxruntime