
feat: add websearch endpoint to RAG API

fix: google PSE endpoint uses GET

fix: google PSE returns link, not url

fix: serper returns link, not url; cap results at 5
Jun Siang Cheah · 1 year ago · parent commit 99e4edd364

+ 58 - 16
backend/apps/rag/main.py

@@ -11,7 +11,7 @@ from fastapi.middleware.cors import CORSMiddleware
 import os, shutil, logging, re
 
 from pathlib import Path
-from typing import List
+from typing import List, Union, Sequence
 
 from chromadb.utils.batch_utils import create_batches
 
@@ -58,6 +58,7 @@ from apps.rag.utils import (
     query_doc_with_hybrid_search,
     query_collection,
     query_collection_with_hybrid_search,
+    search_web,
 )
 
 from utils.misc import (
@@ -186,6 +187,10 @@ class UrlForm(CollectionNameForm):
     url: str
 
 
+class SearchForm(CollectionNameForm):
+    query: str
+
+
 @app.get("/")
 async def get_status():
     return {
@@ -506,26 +511,37 @@ def store_web(form_data: UrlForm, user=Depends(get_current_user)):
         )
 
 
-def get_web_loader(url: str):
+def get_web_loader(url: Union[str, Sequence[str]]):
     # Check if the URL is valid
-    if isinstance(validators.url(url), validators.ValidationError):
+    if not validate_url(url):
         raise ValueError(ERROR_MESSAGES.INVALID_URL)
-    if not ENABLE_LOCAL_WEB_FETCH:
-        # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
-        parsed_url = urllib.parse.urlparse(url)
-        # Get IPv4 and IPv6 addresses
-        ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
-        # Check if any of the resolved addresses are private
-        # This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
-        for ip in ipv4_addresses:
-            if validators.ipv4(ip, private=True):
-                raise ValueError(ERROR_MESSAGES.INVALID_URL)
-        for ip in ipv6_addresses:
-            if validators.ipv6(ip, private=True):
-                raise ValueError(ERROR_MESSAGES.INVALID_URL)
     return WebBaseLoader(url)
 
 
+def validate_url(url: Union[str, Sequence[str]]):
+    if isinstance(url, str):
+        if isinstance(validators.url(url), validators.ValidationError):
+            raise ValueError(ERROR_MESSAGES.INVALID_URL)
+        if not ENABLE_LOCAL_WEB_FETCH:
+            # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
+            parsed_url = urllib.parse.urlparse(url)
+            # Get IPv4 and IPv6 addresses
+            ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
+            # Check if any of the resolved addresses are private
+            # This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
+            for ip in ipv4_addresses:
+                if validators.ipv4(ip, private=True):
+                    raise ValueError(ERROR_MESSAGES.INVALID_URL)
+            for ip in ipv6_addresses:
+                if validators.ipv6(ip, private=True):
+                    raise ValueError(ERROR_MESSAGES.INVALID_URL)
+        return True
+    elif isinstance(url, Sequence):
+        return all(validate_url(u) for u in url)
+    else:
+        return False
+
+
 def resolve_hostname(hostname):
     # Get address information
     addr_info = socket.getaddrinfo(hostname, None)
@@ -537,6 +553,32 @@ def resolve_hostname(hostname):
     return ipv4_addresses, ipv6_addresses
 
 
+@app.post("/websearch")
+def store_websearch(form_data: SearchForm, user=Depends(get_current_user)):
+    try:
+        web_results = search_web(form_data.query)
+        urls = [result.link for result in web_results]
+        loader = get_web_loader(urls)
+        data = loader.load()
+
+        collection_name = form_data.collection_name
+        if collection_name == "":
+            collection_name = calculate_sha256_string(form_data.query)[:63]
+
+        store_data_in_vector_db(data, collection_name, overwrite=True)
+        return {
+            "status": True,
+            "collection_name": collection_name,
+            "filenames": urls,
+        }
+    except Exception as e:
+        log.exception(e)
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.DEFAULT(e),
+        )
+
+
 def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
 
     text_splitter = RecursiveCharacterTextSplitter(
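
For reference, a minimal client-side sketch of calling the new endpoint. The base URL, port, and bearer token below are assumptions about the deployment; the route path and payload fields follow the diff above, and an empty collection_name lets the server derive one from the query hash.

import requests

BASE_URL = "http://localhost:8080/rag"  # hypothetical mount point of the RAG sub-app
TOKEN = "your-api-token"                # get_current_user expects an authenticated request

payload = {
    "collection_name": "",              # empty -> server uses calculate_sha256_string(query)[:63]
    "query": "open source LLM frontends",
}

resp = requests.post(
    f"{BASE_URL}/websearch",
    json=payload,
    headers={"Authorization": f"Bearer {TOKEN}"},
)
resp.raise_for_status()
print(resp.json())  # {"status": true, "collection_name": "...", "filenames": [...]}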

+ 4 - 2
backend/apps/rag/search/google_pse.py

@@ -30,14 +30,16 @@ def search_google_pse(
         "num": 5,
     }
 
-    response = requests.request("POST", url, headers=headers, params=params)
+    response = requests.request("GET", url, headers=headers, params=params)
     response.raise_for_status()
 
     json_response = response.json()
     results = json_response.get("items", [])
     return [
         SearchResult(
-            link=result["url"], title=result.get("title"), snippet=result.get("snippet")
+            link=result["link"],
+            title=result.get("title"),
+            snippet=result.get("snippet"),
         )
         for result in results
     ]
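
For context, Google's Custom Search JSON API is queried with GET, and each entry in its items array carries the result URL under the link key, which is why the old POST call and the result["url"] lookup both broke. A small illustrative mapping with sample data (not a real API response); SearchResult is the dataclass this module already imports:

sample_item = {
    "title": "Example result",
    "link": "https://example.com/page",
    "snippet": "A short excerpt of the matched page ...",
}

result = SearchResult(
    link=sample_item["link"],        # the old code read a non-existent "url" key and raised KeyError
    title=sample_item.get("title"),
    snippet=sample_item.get("snippet"),
)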

+ 2 - 2
backend/apps/rag/search/serper.py

@@ -31,9 +31,9 @@ def search_serper(api_key: str, query: str) -> list[SearchResult]:
     )
     return [
         SearchResult(
-            link=result["url"],
+            link=result["link"],
             title=result.get("title"),
             snippet=result.get("description"),
         )
-        for result in results
+        for result in results[:5]
     ]

+ 12 - 18
backend/apps/rag/utils.py

@@ -545,21 +545,15 @@ def search_web(query: str) -> list[SearchResult]:
     Args:
         query (str): The query to search for
     """
-    try:
-        if SEARXNG_QUERY_URL:
-            return search_searxng(SEARXNG_QUERY_URL, query)
-        elif GOOGLE_PSE_API_KEY and GOOGLE_PSE_ENGINE_ID:
-            return search_google_pse(GOOGLE_PSE_API_KEY, GOOGLE_PSE_ENGINE_ID, query)
-        elif BRAVE_SEARCH_API_KEY:
-            return search_brave(BRAVE_SEARCH_API_KEY, query)
-        elif SERPSTACK_API_KEY:
-            return search_serpstack(
-                SERPSTACK_API_KEY, query, https_enabled=SERPSTACK_HTTPS
-            )
-        elif SERPER_API_KEY:
-            return search_serper(SERPER_API_KEY, query)
-        else:
-            raise Exception("No search engine API key found in environment variables")
-    except Exception as e:
-        log.error(f"Web search failed: {e}")
-        return []
+    if SEARXNG_QUERY_URL:
+        return search_searxng(SEARXNG_QUERY_URL, query)
+    elif GOOGLE_PSE_API_KEY and GOOGLE_PSE_ENGINE_ID:
+        return search_google_pse(GOOGLE_PSE_API_KEY, GOOGLE_PSE_ENGINE_ID, query)
+    elif BRAVE_SEARCH_API_KEY:
+        return search_brave(BRAVE_SEARCH_API_KEY, query)
+    elif SERPSTACK_API_KEY:
+        return search_serpstack(SERPSTACK_API_KEY, query, https_enabled=SERPSTACK_HTTPS)
+    elif SERPER_API_KEY:
+        return search_serper(SERPER_API_KEY, query)
+    else:
+        raise Exception("No search engine API key found in environment variables")
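
With the try/except removed, search_web no longer swallows failures and returns an empty list; a missing API key or an upstream HTTP error now propagates to the caller, which is what lets the new /websearch route translate it into an HTTP 400. A minimal caller sketch (log is assumed to be the module-level logger already defined in utils.py):

try:
    results = search_web("retrieval augmented generation")
except Exception as e:
    # No engine configured, or the upstream search request failed.
    log.error(f"Web search failed: {e}")
    results = []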