|
@@ -31,6 +31,11 @@ from langchain_community.document_loaders import (
|
|
)
|
|
)
|
|
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
|
|
|
|
|
|
+import validators
|
|
|
|
+import urllib.parse
|
|
|
|
+import socket
|
|
|
|
+
|
|
|
|
+
|
|
from pydantic import BaseModel
|
|
from pydantic import BaseModel
|
|
from typing import Optional
|
|
from typing import Optional
|
|
import mimetypes
|
|
import mimetypes
|
|
@@ -84,6 +89,7 @@ from config import (
|
|
CHUNK_SIZE,
|
|
CHUNK_SIZE,
|
|
CHUNK_OVERLAP,
|
|
CHUNK_OVERLAP,
|
|
RAG_TEMPLATE,
|
|
RAG_TEMPLATE,
|
|
|
|
+ ENABLE_LOCAL_WEB_FETCH,
|
|
)
|
|
)
|
|
|
|
|
|
from constants import ERROR_MESSAGES
|
|
from constants import ERROR_MESSAGES
|
|
@@ -454,7 +460,7 @@ def query_collection_handler(
|
|
def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
|
|
def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
|
|
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
|
|
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
|
|
try:
|
|
try:
|
|
- loader = WebBaseLoader(form_data.url)
|
|
|
|
|
|
+ loader = get_web_loader(form_data.url)
|
|
data = loader.load()
|
|
data = loader.load()
|
|
|
|
|
|
collection_name = form_data.collection_name
|
|
collection_name = form_data.collection_name
|
|
@@ -475,6 +481,37 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
|
|
)
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
+def get_web_loader(url: str):
|
|
|
|
+ # Check if the URL is valid
|
|
|
|
+ if isinstance(validators.url(url), validators.ValidationError):
|
|
|
|
+ raise ValueError(ERROR_MESSAGES.INVALID_URL)
|
|
|
|
+ if not ENABLE_LOCAL_WEB_FETCH:
|
|
|
|
+ # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
|
|
|
|
+ parsed_url = urllib.parse.urlparse(url)
|
|
|
|
+ # Get IPv4 and IPv6 addresses
|
|
|
|
+ ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
|
|
|
|
+ # Check if any of the resolved addresses are private
|
|
|
|
+ # This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
|
|
|
|
+ for ip in ipv4_addresses:
|
|
|
|
+ if validators.ipv4(ip, private=True):
|
|
|
|
+ raise ValueError(ERROR_MESSAGES.INVALID_URL)
|
|
|
|
+ for ip in ipv6_addresses:
|
|
|
|
+ if validators.ipv6(ip, private=True):
|
|
|
|
+ raise ValueError(ERROR_MESSAGES.INVALID_URL)
|
|
|
|
+ return WebBaseLoader(url)
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def resolve_hostname(hostname):
|
|
|
|
+ # Get address information
|
|
|
|
+ addr_info = socket.getaddrinfo(hostname, None)
|
|
|
|
+
|
|
|
|
+ # Extract IP addresses from address information
|
|
|
|
+ ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET]
|
|
|
|
+ ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6]
|
|
|
|
+
|
|
|
|
+ return ipv4_addresses, ipv6_addresses
|
|
|
|
+
|
|
|
|
+
|
|
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
|
|
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
|
|
|
|
|
|
text_splitter = RecursiveCharacterTextSplitter(
|
|
text_splitter = RecursiveCharacterTextSplitter(
|