@@ -11,7 +11,7 @@ from fastapi.middleware.cors import CORSMiddleware
 import os, shutil, logging, re
 
 from pathlib import Path
-from typing import List
+from typing import List, Union, Sequence
 
 from chromadb.utils.batch_utils import create_batches
 
@@ -58,6 +58,7 @@ from apps.rag.utils import (
     query_doc_with_hybrid_search,
     query_collection,
     query_collection_with_hybrid_search,
+    search_web,
 )
 
 from utils.misc import (
@@ -186,6 +187,10 @@ class UrlForm(CollectionNameForm):
     url: str
 
 
+class SearchForm(CollectionNameForm):
+    query: str
+
+
 @app.get("/")
 async def get_status():
     return {
@@ -506,26 +511,38 @@ def store_web(form_data: UrlForm, user=Depends(get_current_user)):
         )
 
 
-def get_web_loader(url: str):
+def get_web_loader(url: Union[str, Sequence[str]]):
     # Check if the URL is valid
-    if isinstance(validators.url(url), validators.ValidationError):
+    if not validate_url(url):
         raise ValueError(ERROR_MESSAGES.INVALID_URL)
-    if not ENABLE_LOCAL_WEB_FETCH:
-        # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
-        parsed_url = urllib.parse.urlparse(url)
-        # Get IPv4 and IPv6 addresses
-        ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
-        # Check if any of the resolved addresses are private
-        # This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
-        for ip in ipv4_addresses:
-            if validators.ipv4(ip, private=True):
-                raise ValueError(ERROR_MESSAGES.INVALID_URL)
-        for ip in ipv6_addresses:
-            if validators.ipv6(ip, private=True):
-                raise ValueError(ERROR_MESSAGES.INVALID_URL)
     return WebBaseLoader(url)
 
 
+def validate_url(url: Union[str, Sequence[str]]):
+    if isinstance(url, str):
+        if isinstance(validators.url(url), validators.ValidationError):
+            raise ValueError(ERROR_MESSAGES.INVALID_URL)
+        if not ENABLE_LOCAL_WEB_FETCH:
+            # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
+            parsed_url = urllib.parse.urlparse(url)
+            # Get IPv4 and IPv6 addresses
+            ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
+            # Check if any of the resolved addresses are private
+            # This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
+            for ip in ipv4_addresses:
+                if validators.ipv4(ip, private=True):
+                    raise ValueError(ERROR_MESSAGES.INVALID_URL)
+            for ip in ipv6_addresses:
+                if validators.ipv6(ip, private=True):
+                    raise ValueError(ERROR_MESSAGES.INVALID_URL)
+        return True
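+    # A sequence of URLs is valid only if every entry passes the same checks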
+    elif isinstance(url, Sequence):
+        return all(validate_url(u) for u in url)
+    else:
+        return False
+
+
 def resolve_hostname(hostname):
     # Get address information
     addr_info = socket.getaddrinfo(hostname, None)
@@ -537,6 +554,35 @@ def resolve_hostname(hostname):
     return ipv4_addresses, ipv6_addresses
 
 
+@app.post("/websearch")
+def store_websearch(form_data: SearchForm, user=Depends(get_current_user)):
+    try:
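+        # Run the web search and gather the result links to fetch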
+        web_results = search_web(form_data.query)
+        urls = [result.link for result in web_results]
+        loader = get_web_loader(urls)
+        data = loader.load()
+
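+        # No collection name given: derive one from a hash of the query
+        # (Chroma caps collection names at 63 characters)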
+        collection_name = form_data.collection_name
+        if collection_name == "":
+            collection_name = calculate_sha256_string(form_data.query)[:63]
+
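+        # overwrite=True so repeating the same search refreshes the stored collection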
+        store_data_in_vector_db(data, collection_name, overwrite=True)
+        return {
+            "status": True,
+            "collection_name": collection_name,
+            "filenames": urls,
+        }
+    except Exception as e:
+        log.exception(e)
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.DEFAULT(e),
+        )
+
+
 def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
 
     text_splitter = RecursiveCharacterTextSplitter(