
Merge pull request #3112 from que-nguyen/searxng

Domain whitelisting for web search results
Timothy Jaeryang Baek 10 months ago
parent
commit
20f052eb37

+ 13 - 1
backend/apps/rag/main.py

@@ -112,6 +112,7 @@ from config import (
     YOUTUBE_LOADER_LANGUAGE,
     ENABLE_RAG_WEB_SEARCH,
     RAG_WEB_SEARCH_ENGINE,
+    RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
     SEARXNG_QUERY_URL,
     GOOGLE_PSE_API_KEY,
     GOOGLE_PSE_ENGINE_ID,
@@ -165,6 +166,7 @@ app.state.YOUTUBE_LOADER_TRANSLATION = None
 
 app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
 app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE
+app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = RAG_WEB_SEARCH_DOMAIN_FILTER_LIST
 
 app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
 app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
@@ -775,6 +777,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
                 app.state.config.SEARXNG_QUERY_URL,
                 query,
                 app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+                app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
             )
         else:
             raise Exception("No SEARXNG_QUERY_URL found in environment variables")
@@ -788,6 +791,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
                 app.state.config.GOOGLE_PSE_ENGINE_ID,
                 query,
                 app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+                app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
             )
         else:
             raise Exception(
@@ -799,6 +803,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
                 app.state.config.BRAVE_SEARCH_API_KEY,
                 query,
                 app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+                app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
             )
         else:
             raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables")
@@ -808,6 +813,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
                 app.state.config.SERPSTACK_API_KEY,
                 query,
                 app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+                app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
                 https_enabled=app.state.config.SERPSTACK_HTTPS,
             )
         else:
@@ -818,6 +824,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
                 app.state.config.SERPER_API_KEY,
                 query,
                 app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+                app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
             )
         else:
             raise Exception("No SERPER_API_KEY found in environment variables")
@@ -827,11 +834,16 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
                 app.state.config.SERPLY_API_KEY,
                 query,
                 app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+                app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
             )
         else:
             raise Exception("No SERPLY_API_KEY found in environment variables")
     elif engine == "duckduckgo":
-        return search_duckduckgo(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT)
+        return search_duckduckgo(
+            query,
+            app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+            app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
+        )
     elif engine == "tavily":
         if app.state.config.TAVILY_API_KEY:
             return search_tavily(

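For reference, here is a minimal sketch (hypothetical names, not part of the diff) of the calling convention these hunks establish: every engine helper now accepts the domain filter list immediately after the result count, so search_web can forward the configured list uniformly.

    from typing import List, Optional

    # Hypothetical stand-in for any engine helper after this change; the
    # shared signature tail is (query, count, filter_list).
    def some_engine(
        api_key: str, query: str, count: int, filter_list: Optional[List[str]] = None
    ) -> list:
        return []

    # search_web-style dispatch: the filter list rides along positionally.
    results = some_engine("api-key", "open webui", 10, ["wikipedia.org"])
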
+ 6 - 3
backend/apps/rag/search/brave.py

@@ -1,15 +1,15 @@
 import logging
-
+from typing import List, Optional
 import requests
 
-from apps.rag.search.main import SearchResult
+from apps.rag.search.main import SearchResult, get_filtered_results
 from config import SRC_LOG_LEVELS
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
-def search_brave(api_key: str, query: str, count: int) -> list[SearchResult]:
+def search_brave(api_key: str, query: str, count: int, filter_list: Optional[List[str]] = None) -> list[SearchResult]:
     """Search using Brave's Search API and return the results as a list of SearchResult objects.
 
     Args:
@@ -29,6 +29,9 @@ def search_brave(api_key: str, query: str, count: int) -> list[SearchResult]:
 
     json_response = response.json()
     results = json_response.get("web", {}).get("results", [])
+    if filter_list:
+        results = get_filtered_results(results, filter_list)
+
     return [
         SearchResult(
             link=result["url"], title=result.get("title"), snippet=result.get("snippet")

+ 5 - 4
backend/apps/rag/search/duckduckgo.py

@@ -1,6 +1,6 @@
 import logging
-
-from apps.rag.search.main import SearchResult
+from typing import List, Optional
+from apps.rag.search.main import SearchResult, get_filtered_results
 from duckduckgo_search import DDGS
 from config import SRC_LOG_LEVELS
 
@@ -8,7 +8,7 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
-def search_duckduckgo(query: str, count: int) -> list[SearchResult]:
+def search_duckduckgo(query: str, count: int, filter_list: Optional[List[str]] = None) -> list[SearchResult]:
     """
     Search using DuckDuckGo's Search API and return the results as a list of SearchResult objects.
     Args:
@@ -41,6 +41,7 @@ def search_duckduckgo(query: str, count: int) -> list[SearchResult]:
                 snippet=result.get("body"),
             )
         )
-    print(results)
+    if filter_list:
+        results = get_filtered_results(results, filter_list)
     # Return the list of search results
     return results

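One wrinkle here: unlike the other engines, this path filters after the SearchResult objects have been built, which is why get_filtered_results (in backend/apps/rag/search/main.py below) reads the URL from a non-dict result's .link attribute. A quick sketch with made-up data:

    from apps.rag.search.main import SearchResult, get_filtered_results

    hits = [
        SearchResult(link="https://en.wikipedia.org/wiki/RAG", title=None, snippet=None),
        SearchResult(link="https://example.com/blog", title=None, snippet=None),
    ]
    # Keeps only the en.wikipedia.org result.
    print(get_filtered_results(hits, ["wikipedia.org"]))
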
+ 5 - 3
backend/apps/rag/search/google_pse.py

@@ -1,9 +1,9 @@
 import json
 import logging
-
+from typing import List, Optional
 import requests
 
-from apps.rag.search.main import SearchResult
+from apps.rag.search.main import SearchResult, get_filtered_results
 from config import SRC_LOG_LEVELS
 
 log = logging.getLogger(__name__)
@@ -11,7 +11,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
 def search_google_pse(
-    api_key: str, search_engine_id: str, query: str, count: int
+    api_key: str, search_engine_id: str, query: str, count: int, filter_list: Optional[List[str]] = None
 ) -> list[SearchResult]:
     """Search using Google's Programmable Search Engine API and return the results as a list of SearchResult objects.
 
@@ -35,6 +35,8 @@ def search_google_pse(
 
     json_response = response.json()
     results = json_response.get("items", [])
+    if filter_list:
+        results = get_filtered_results(results, filter_list)
     return [
         SearchResult(
             link=result["link"],

+ 18 - 1
backend/apps/rag/search/main.py

@@ -1,8 +1,25 @@
 from typing import Optional
-
+from urllib.parse import urlparse
 from pydantic import BaseModel
 
 
+def get_filtered_results(results, filter_list):
+    if not filter_list:
+        return results
+    filtered_results = []
+    for result in results:
+        # Engines hand back raw dicts keyed by "url" (Brave, SearXNG,
+        # Serpstack) or "link" (Google PSE, Serper, Serply); DuckDuckGo
+        # passes SearchResult objects, which expose the URL as ".link".
+        if isinstance(result, dict):
+            url = result.get("url") or result.get("link", "")
+        else:
+            url = getattr(result, "link", "")
+        domain = urlparse(url).netloc
+        if any(domain.endswith(filtered_domain) for filtered_domain in filter_list):
+            filtered_results.append(result)
+    return filtered_results
+
 class SearchResult(BaseModel):
     link: str
     title: Optional[str]

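A quick usage sketch of the helper (made-up data; assumes the import path used throughout this PR). The endswith match means subdomains such as en.wikipedia.org pass a wikipedia.org entry:

    from apps.rag.search.main import get_filtered_results

    results = [
        {"url": "https://en.wikipedia.org/wiki/Python"},  # kept: netloc ends with "wikipedia.org"
        {"link": "https://blog.example.com/post"},        # dropped: no matching domain
    ]
    print(get_filtered_results(results, ["wikipedia.org"]))

Note that a bare suffix match also admits lookalike hosts such as notwikipedia.org; entries written with a leading dot (".wikipedia.org") would be stricter, at the cost of excluding the apex domain itself.
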
+ 5 - 3
backend/apps/rag/search/searxng.py

@@ -1,9 +1,9 @@
 import logging
 import requests
 
-from typing import List
+from typing import List, Optional
 
-from apps.rag.search.main import SearchResult
+from apps.rag.search.main import SearchResult, get_filtered_results
 from config import SRC_LOG_LEVELS
 
 log = logging.getLogger(__name__)
@@ -11,7 +11,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
 def search_searxng(
-    query_url: str, query: str, count: int, **kwargs
+    query_url: str, query: str, count: int, filter_list: Optional[List[str]] = None, **kwargs
 ) -> List[SearchResult]:
     """
     Search a SearXNG instance for a given query and return the results as a list of SearchResult objects.
@@ -78,6 +78,8 @@ def search_searxng(
     json_response = response.json()
     results = json_response.get("results", [])
     sorted_results = sorted(results, key=lambda x: x.get("score", 0), reverse=True)
+    if filter_list:
+        sorted_results = get_filtered_results(sorted_results, filter_list)
     return [
         SearchResult(
             link=result["url"], title=result.get("title"), snippet=result.get("content")

+ 5 - 3
backend/apps/rag/search/serper.py

@@ -1,16 +1,16 @@
 import json
 import logging
-
+from typing import List, Optional
 import requests
 
-from apps.rag.search.main import SearchResult
+from apps.rag.search.main import SearchResult, get_filtered_results
 from config import SRC_LOG_LEVELS
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
-def search_serper(api_key: str, query: str, count: int) -> list[SearchResult]:
+def search_serper(api_key: str, query: str, count: int, filter_list: Optional[List[str]] = None) -> list[SearchResult]:
     """Search using serper.dev's API and return the results as a list of SearchResult objects.
 
     Args:
@@ -29,6 +29,8 @@ def search_serper(api_key: str, query: str, count: int) -> list[SearchResult]:
     results = sorted(
         json_response.get("organic", []), key=lambda x: x.get("position", 0)
     )
+    if filter_list:
+        results = get_filtered_results(results, filter_list)
     return [
         SearchResult(
             link=result["link"],

+ 5 - 3
backend/apps/rag/search/serply.py

@@ -1,10 +1,10 @@
 import json
 import logging
-
+from typing import List, Optional
 import requests
 from urllib.parse import urlencode
 
-from apps.rag.search.main import SearchResult
+from apps.rag.search.main import SearchResult, get_filtered_results
 from config import SRC_LOG_LEVELS
 
 log = logging.getLogger(__name__)
@@ -19,6 +19,7 @@ def search_serply(
     limit: int = 10,
     device_type: str = "desktop",
     proxy_location: str = "US",
+    filter_list: Optional[List[str]] = None,
 ) -> list[SearchResult]:
     """Search using serper.dev's API and return the results as a list of SearchResult objects.
 
@@ -57,7 +58,8 @@ def search_serply(
     results = sorted(
         json_response.get("results", []), key=lambda x: x.get("realPosition", 0)
     )
-
+    if filter_list:
+        results = get_filtered_results(results, filter_list)
     return [
         SearchResult(
             link=result["link"],

+ 5 - 3
backend/apps/rag/search/serpstack.py

@@ -1,9 +1,9 @@
 import json
 import logging
-
+from typing import List, Optional
 import requests
 
-from apps.rag.search.main import SearchResult
+from apps.rag.search.main import SearchResult, get_filtered_results
 from config import SRC_LOG_LEVELS
 
 log = logging.getLogger(__name__)
@@ -11,7 +11,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
 def search_serpstack(
-    api_key: str, query: str, count: int, https_enabled: bool = True
+    api_key: str, query: str, count: int, filter_list: Optional[List[str]] = None, https_enabled: bool = True
 ) -> list[SearchResult]:
     """Search using serpstack.com's and return the results as a list of SearchResult objects.
 
@@ -35,6 +35,8 @@ def search_serpstack(
     results = sorted(
         json_response.get("organic_results", []), key=lambda x: x.get("position", 0)
     )
+    if filter_list:
+        results = get_filtered_results(results, filter_list)
     return [
         SearchResult(
             link=result["url"], title=result.get("title"), snippet=result.get("snippet")

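Serpstack is the one helper with an extra option; placing filter_list before https_enabled keeps the shared (query, count, filter_list) positional tail intact, provided callers pass https_enabled by keyword, as main.py does above. A hypothetical call:

    from apps.rag.search.serpstack import search_serpstack

    results = search_serpstack(
        "api-key", "open webui", 10, ["wikipedia.org"], https_enabled=False
    )
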
+ 12 - 0
backend/config.py

@@ -903,6 +903,18 @@ RAG_WEB_SEARCH_ENGINE = PersistentConfig(
     os.getenv("RAG_WEB_SEARCH_ENGINE", ""),
 )
 
+# You can provide a list of your own websites to filter web search results against
+# after a search; keeping only these domains improves safety and reliability.
+RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = PersistentConfig(
+    "RAG_WEB_SEARCH_DOMAIN_FILTER_LIST",
+    "rag.rag.web.search.domain.filter_list",
+    [
+        # "wikipedia.com", 
+        # "wikimedia.org",        
+        # "wikidata.org",
+    ],
+)
+
 SEARXNG_QUERY_URL = PersistentConfig(
     "SEARXNG_QUERY_URL",
     "rag.web.search.searxng_query_url",