浏览代码

Implement domain whitelisting for web search results

Que Nguyen 10 月之前
父节点
当前提交
7b5f434a07

+ 9 - 1
backend/apps/rag/main.py

@@ -111,6 +111,7 @@ from config import (
     YOUTUBE_LOADER_LANGUAGE,
     ENABLE_RAG_WEB_SEARCH,
     RAG_WEB_SEARCH_ENGINE,
+    RAG_WEB_SEARCH_WHITE_LIST_DOMAINS,
     SEARXNG_QUERY_URL,
     GOOGLE_PSE_API_KEY,
     GOOGLE_PSE_ENGINE_ID,
@@ -163,6 +164,7 @@ app.state.YOUTUBE_LOADER_TRANSLATION = None
 
 app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
 app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE
+app.state.config.RAG_WEB_SEARCH_WHITE_LIST_DOMAINS = RAG_WEB_SEARCH_WHITE_LIST_DOMAINS
 
 app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
 app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
@@ -768,6 +770,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
                 app.state.config.SEARXNG_QUERY_URL,
                 query,
                 app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+                app.state.config.RAG_WEB_SEARCH_WHITE_LIST_DOMAINS
             )
         else:
             raise Exception("No SEARXNG_QUERY_URL found in environment variables")
@@ -781,6 +784,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
                 app.state.config.GOOGLE_PSE_ENGINE_ID,
                 query,
                 app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+                app.state.config.RAG_WEB_SEARCH_WHITE_LIST_DOMAINS
             )
         else:
             raise Exception(
@@ -792,6 +796,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
                 app.state.config.BRAVE_SEARCH_API_KEY,
                 query,
                 app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+                app.state.config.RAG_WEB_SEARCH_WHITE_LIST_DOMAINS
             )
         else:
             raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables")
@@ -801,6 +806,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
                 app.state.config.SERPSTACK_API_KEY,
                 query,
                 app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+                app.state.config.RAG_WEB_SEARCH_WHITE_LIST_DOMAINS,
                 https_enabled=app.state.config.SERPSTACK_HTTPS,
             )
         else:
@@ -811,6 +817,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
                 app.state.config.SERPER_API_KEY,
                 query,
                 app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+                app.state.config.RAG_WEB_SEARCH_WHITE_LIST_DOMAINS
             )
         else:
             raise Exception("No SERPER_API_KEY found in environment variables")
@@ -820,11 +827,12 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
                 app.state.config.SERPLY_API_KEY,
                 query,
                 app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+                app.state.config.RAG_WEB_SEARCH_WHITE_LIST_DOMAINS
             )
         else:
             raise Exception("No SERPLY_API_KEY found in environment variables")
     elif engine == "duckduckgo":
-        return search_duckduckgo(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT)
+        return search_duckduckgo(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, app.state.config.RAG_WEB_SEARCH_WHITE_LIST_DOMAINS)
     else:
         raise Exception("No search engine API key found in environment variables")
 

+ 5 - 4
backend/apps/rag/search/brave.py

@@ -1,15 +1,15 @@
 import logging
-
+from typing import List
 import requests
 
-from apps.rag.search.main import SearchResult
+from apps.rag.search.main import SearchResult, filter_by_whitelist
 from config import SRC_LOG_LEVELS
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
-def search_brave(api_key: str, query: str, count: int) -> list[SearchResult]:
+def search_brave(api_key: str, query: str, whitelist:List[str], count: int) -> list[SearchResult]:
     """Search using Brave's Search API and return the results as a list of SearchResult objects.
 
     Args:
@@ -29,9 +29,10 @@ def search_brave(api_key: str, query: str, count: int) -> list[SearchResult]:
 
     json_response = response.json()
     results = json_response.get("web", {}).get("results", [])
+    filtered_results = filter_by_whitelist(results, whitelist)
     return [
         SearchResult(
             link=result["url"], title=result.get("title"), snippet=result.get("snippet")
         )
-        for result in results[:count]
+        for result in filtered_results[:count]
     ]

+ 6 - 5
backend/apps/rag/search/duckduckgo.py

@@ -1,6 +1,6 @@
 import logging
-
-from apps.rag.search.main import SearchResult
+from typing import List
+from apps.rag.search.main import SearchResult, filter_by_whitelist
 from duckduckgo_search import DDGS
 from config import SRC_LOG_LEVELS
 
@@ -8,7 +8,7 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
-def search_duckduckgo(query: str, count: int) -> list[SearchResult]:
+def search_duckduckgo(query: str, count: int, whitelist:List[str]) -> list[SearchResult]:
     """
     Search using DuckDuckGo's Search API and return the results as a list of SearchResult objects.
     Args:
@@ -41,6 +41,7 @@ def search_duckduckgo(query: str, count: int) -> list[SearchResult]:
                 snippet=result.get("body"),
             )
         )
-    print(results)
+    # print(results)
+    filtered_results = filter_by_whitelist(results, whitelist)
     # Return the list of search results
-    return results
+    return filtered_results

+ 5 - 4
backend/apps/rag/search/google_pse.py

@@ -1,9 +1,9 @@
 import json
 import logging
-
+from typing import List
 import requests
 
-from apps.rag.search.main import SearchResult
+from apps.rag.search.main import SearchResult, filter_by_whitelist
 from config import SRC_LOG_LEVELS
 
 log = logging.getLogger(__name__)
@@ -11,7 +11,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
 def search_google_pse(
-    api_key: str, search_engine_id: str, query: str, count: int
+    api_key: str, search_engine_id: str, query: str, count: int, whitelist:List[str]
 ) -> list[SearchResult]:
     """Search using Google's Programmable Search Engine API and return the results as a list of SearchResult objects.
 
@@ -35,11 +35,12 @@ def search_google_pse(
 
     json_response = response.json()
     results = json_response.get("items", [])
+    filtered_results = filter_by_whitelist(results, whitelist)
     return [
         SearchResult(
             link=result["link"],
             title=result.get("title"),
             snippet=result.get("snippet"),
         )
-        for result in results
+        for result in filtered_results
     ]

+ 11 - 1
backend/apps/rag/search/main.py

@@ -1,8 +1,18 @@
 from typing import Optional
-
+from urllib.parse import urlparse
 from pydantic import BaseModel
 
 
def _result_host(result):
    """Return the lower-cased host name of a search result, or None.

    Raw engine responses are dicts that carry the URL under either "url" or
    "link"; already-converted results expose it as an attribute instead.
    """
    for key in ("url", "link"):
        if isinstance(result, dict):
            url = result.get(key)
        else:
            url = getattr(result, key, None)
        if not url:
            continue
        try:
            # .hostname drops userinfo and port and is already lower-cased.
            host = urlparse(str(url)).hostname
        except ValueError:
            # Malformed URL (e.g. unbalanced IPv6 brackets): treat as no host.
            host = None
        if host:
            # A fully-qualified host may end with a dot ("example.com.").
            return host.rstrip(".")
    return None


def filter_by_whitelist(results, whitelist):
    """Keep only the search results whose URL host is on the domain whitelist.

    A host matches a whitelisted domain when it is that domain or one of its
    subdomains, so "example.com" admits "docs.example.com" but not
    "notexample.com". Matching is case-insensitive and ignores any port.

    Args:
        results: Search results, either dicts holding the URL under a "url"
            or "link" key, or objects exposing it as a ``url`` or ``link``
            attribute.
        whitelist: Allowed domain names such as "example.com". A falsy value,
            or one that contains no non-blank string entries, disables
            filtering.

    Returns:
        ``results`` itself when filtering is disabled; otherwise a new list
        with the matching results in their original order. Results without a
        usable URL are dropped while filtering is active.
    """
    if not whitelist:
        return results

    # Normalise once, outside the loop; blank entries would match everything.
    allowed = {
        entry.strip().lower().strip(".")
        for entry in whitelist
        if isinstance(entry, str)
    }
    allowed.discard("")
    if not allowed:
        return results

    # Requiring the leading dot is what rejects look-alike hosts.
    subdomain_suffixes = tuple("." + domain for domain in allowed)

    filtered_results = []
    for result in results:
        host = _result_host(result)
        if host is None:
            continue
        if host in allowed or host.endswith(subdomain_suffixes):
            filtered_results.append(result)
    return filtered_results
+
 class SearchResult(BaseModel):
     link: str
     title: Optional[str]

+ 3 - 2
backend/apps/rag/search/searxng.py

@@ -11,7 +11,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
 def search_searxng(
-    query_url: str, query: str, count: int, **kwargs
+    query_url: str, query: str, count: int, whitelist:List[str],  **kwargs
 ) -> List[SearchResult]:
     """
     Search a SearXNG instance for a given query and return the results as a list of SearchResult objects.
@@ -78,9 +78,10 @@ def search_searxng(
     json_response = response.json()
     results = json_response.get("results", [])
     sorted_results = sorted(results, key=lambda x: x.get("score", 0), reverse=True)
+    filtered_results = filter_by_whitelist(sorted_results, whitelist)
     return [
         SearchResult(
             link=result["url"], title=result.get("title"), snippet=result.get("content")
         )
-        for result in sorted_results[:count]
+        for result in filtered_results[:count]
     ]

+ 5 - 4
backend/apps/rag/search/serper.py

@@ -1,16 +1,16 @@
 import json
 import logging
-
+from typing import List
 import requests
 
-from apps.rag.search.main import SearchResult
+from apps.rag.search.main import SearchResult, filter_by_whitelist
 from config import SRC_LOG_LEVELS
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
-def search_serper(api_key: str, query: str, count: int) -> list[SearchResult]:
+def search_serper(api_key: str, query: str, count: int, whitelist:List[str]) -> list[SearchResult]:
     """Search using serper.dev's API and return the results as a list of SearchResult objects.
 
     Args:
@@ -29,11 +29,12 @@ def search_serper(api_key: str, query: str, count: int) -> list[SearchResult]:
     results = sorted(
         json_response.get("organic", []), key=lambda x: x.get("position", 0)
     )
+    filtered_results = filter_by_whitelist(results, whitelist)
     return [
         SearchResult(
             link=result["link"],
             title=result.get("title"),
             snippet=result.get("description"),
         )
-        for result in results[:count]
+        for result in filtered_results[:count]
     ]

+ 5 - 4
backend/apps/rag/search/serply.py

@@ -1,10 +1,10 @@
 import json
 import logging
-
+from typing import List
 import requests
 from urllib.parse import urlencode
 
-from apps.rag.search.main import SearchResult
+from apps.rag.search.main import SearchResult, filter_by_whitelist
 from config import SRC_LOG_LEVELS
 
 log = logging.getLogger(__name__)
@@ -15,6 +15,7 @@ def search_serply(
     api_key: str,
     query: str,
     count: int,
+    whitelist:List[str],
     hl: str = "us",
     limit: int = 10,
     device_type: str = "desktop",
@@ -57,12 +58,12 @@ def search_serply(
     results = sorted(
         json_response.get("results", []), key=lambda x: x.get("realPosition", 0)
     )
-
+    filtered_results = filter_by_whitelist(results, whitelist)
     return [
         SearchResult(
             link=result["link"],
             title=result.get("title"),
             snippet=result.get("description"),
         )
-        for result in results[:count]
+        for result in filtered_results[:count]
     ]

+ 5 - 4
backend/apps/rag/search/serpstack.py

@@ -1,9 +1,9 @@
 import json
 import logging
-
+from typing import List
 import requests
 
-from apps.rag.search.main import SearchResult
+from apps.rag.search.main import SearchResult, filter_by_whitelist
 from config import SRC_LOG_LEVELS
 
 log = logging.getLogger(__name__)
@@ -11,7 +11,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
 def search_serpstack(
-    api_key: str, query: str, count: int, https_enabled: bool = True
+    api_key: str, query: str, count: int, whitelist:List[str], https_enabled: bool = True
 ) -> list[SearchResult]:
     """Search using serpstack.com's and return the results as a list of SearchResult objects.
 
@@ -35,9 +35,10 @@ def search_serpstack(
     results = sorted(
         json_response.get("organic_results", []), key=lambda x: x.get("position", 0)
     )
+    filtered_results = filter_by_whitelist(results, whitelist)
     return [
         SearchResult(
             link=result["url"], title=result.get("title"), snippet=result.get("snippet")
         )
-        for result in results[:count]
+        for result in filtered_results[:count]
     ]

+ 9 - 0
backend/config.py

@@ -894,6 +894,15 @@ RAG_WEB_SEARCH_ENGINE = PersistentConfig(
     os.getenv("RAG_WEB_SEARCH_ENGINE", ""),
 )
 
+RAG_WEB_SEARCH_WHITE_LIST_DOMAINS = PersistentConfig(
+    "RAG_WEB_SEARCH_WHITE_LIST_DOMAINS",
+    "rag.rag_web_search_white_list_domains",
+    [
+        # Domains that web search results are restricted to, e.g. "example.com".
+        # An empty list leaves web search results unfiltered.
+    ],
+)
+
 SEARXNG_QUERY_URL = PersistentConfig(
     "SEARXNG_QUERY_URL",
     "rag.web.search.searxng_query_url",