Browse Source

refac: web search

Timothy J. Baek 11 months ago
parent
commit
999d2bc21b

+ 77 - 11
backend/apps/rag/main.py

@@ -59,9 +59,16 @@ from apps.rag.utils import (
     query_doc_with_hybrid_search,
     query_doc_with_hybrid_search,
     query_collection,
     query_collection,
     query_collection_with_hybrid_search,
     query_collection_with_hybrid_search,
-    search_web,
 )
 )
 
 
+from apps.rag.search.brave import search_brave
+from apps.rag.search.google_pse import search_google_pse
+from apps.rag.search.main import SearchResult
+from apps.rag.search.searxng import search_searxng
+from apps.rag.search.serper import search_serper
+from apps.rag.search.serpstack import search_serpstack
+
+
 from utils.misc import (
 from utils.misc import (
     calculate_sha256,
     calculate_sha256,
     calculate_sha256_string,
     calculate_sha256_string,
@@ -716,19 +723,78 @@ def resolve_hostname(hostname):
     return ipv4_addresses, ipv6_addresses
     return ipv4_addresses, ipv6_addresses
 
 
 
 
+def search_web(engine: str, query: str) -> list[SearchResult]:
+    """Search the web using a search engine and return the results as a list of SearchResult objects.
+    Will look for a search engine API key in environment variables in the following order:
+    - SEARXNG_QUERY_URL
+    - GOOGLE_PSE_API_KEY + GOOGLE_PSE_ENGINE_ID
+    - BRAVE_SEARCH_API_KEY
+    - SERPSTACK_API_KEY
+    - SERPER_API_KEY
+
+    Args:
+        query (str): The query to search for
+    """
+
+    # TODO: add playwright to search the web
+    if engine == "searxng":
+        if app.state.config.SEARXNG_QUERY_URL:
+            return search_searxng(app.state.config.SEARXNG_QUERY_URL, query)
+        else:
+            raise Exception("No SEARXNG_QUERY_URL found in environment variables")
+    elif engine == "google_pse":
+        if (
+            app.state.config.GOOGLE_PSE_API_KEY
+            and app.state.config.GOOGLE_PSE_ENGINE_ID
+        ):
+            return search_google_pse(
+                app.state.config.GOOGLE_PSE_API_KEY,
+                app.state.config.GOOGLE_PSE_ENGINE_ID,
+                query,
+            )
+        else:
+            raise Exception(
+                "No GOOGLE_PSE_API_KEY or GOOGLE_PSE_ENGINE_ID found in environment variables"
+            )
+    elif engine == "brave":
+        if app.state.config.BRAVE_SEARCH_API_KEY:
+            return search_brave(app.state.config.BRAVE_SEARCH_API_KEY, query)
+        else:
+            raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables")
+    elif engine == "serpstack":
+        if app.state.config.SERPSTACK_API_KEY:
+            return search_serpstack(
+                app.state.config.SERPSTACK_API_KEY,
+                query,
+                https_enabled=app.state.config.SERPSTACK_HTTPS,
+            )
+        else:
+            raise Exception("No SERPSTACK_API_KEY found in environment variables")
+    elif engine == "serper":
+        if app.state.config.SERPER_API_KEY:
+            return search_serper(app.state.config.SERPER_API_KEY, query)
+        else:
+            raise Exception("No SERPER_API_KEY found in environment variables")
+    else:
+        raise Exception("No search engine API key found in environment variables")
+
+
 @app.post("/web/search")
 @app.post("/web/search")
 def store_web_search(form_data: SearchForm, user=Depends(get_current_user)):
 def store_web_search(form_data: SearchForm, user=Depends(get_current_user)):
     try:
     try:
-        try:
-            web_results = search_web(
-                app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query
-            )
-        except Exception as e:
-            log.exception(e)
-            raise HTTPException(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                detail=ERROR_MESSAGES.WEB_SEARCH_ERROR,
-            )
+        web_results = search_web(
+            app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query
+        )
+    except Exception as e:
+        log.exception(e)
+
+        print(e)
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e),
+        )
+
+    try:
         urls = [result.link for result in web_results]
         urls = [result.link for result in web_results]
         loader = get_web_loader(urls)
         loader = get_web_loader(urls)
         data = loader.load()
         data = loader.load()

+ 1 - 53
backend/apps/rag/utils.py

@@ -20,12 +20,7 @@ from langchain.retrievers import (
 
 
 from typing import Optional
 from typing import Optional
 
 
-from apps.rag.search.brave import search_brave
-from apps.rag.search.google_pse import search_google_pse
-from apps.rag.search.main import SearchResult
-from apps.rag.search.searxng import search_searxng
-from apps.rag.search.serper import search_serper
-from apps.rag.search.serpstack import search_serpstack
+
 from config import (
 from config import (
     SRC_LOG_LEVELS,
     SRC_LOG_LEVELS,
     CHROMA_CLIENT,
     CHROMA_CLIENT,
@@ -536,50 +531,3 @@ class RerankCompressor(BaseDocumentCompressor):
             )
             )
             final_results.append(doc)
             final_results.append(doc)
         return final_results
         return final_results
-
-
-def search_web(engine: str, query: str) -> list[SearchResult]:
-    """Search the web using a search engine and return the results as a list of SearchResult objects.
-    Will look for a search engine API key in environment variables in the following order:
-    - SEARXNG_QUERY_URL
-    - GOOGLE_PSE_API_KEY + GOOGLE_PSE_ENGINE_ID
-    - BRAVE_SEARCH_API_KEY
-    - SERPSTACK_API_KEY
-    - SERPER_API_KEY
-
-    Args:
-        query (str): The query to search for
-    """
-
-    # TODO: add playwright to search the web
-    if engine == "searxng":
-        if SEARXNG_QUERY_URL:
-            return search_searxng(SEARXNG_QUERY_URL, query)
-        else:
-            raise Exception("No SEARXNG_QUERY_URL found in environment variables")
-    elif engine == "google_pse":
-        if GOOGLE_PSE_API_KEY and GOOGLE_PSE_ENGINE_ID:
-            return search_google_pse(GOOGLE_PSE_API_KEY, GOOGLE_PSE_ENGINE_ID, query)
-        else:
-            raise Exception(
-                "No GOOGLE_PSE_API_KEY or GOOGLE_PSE_ENGINE_ID found in environment variables"
-            )
-    elif engine == "brave":
-        if BRAVE_SEARCH_API_KEY:
-            return search_brave(BRAVE_SEARCH_API_KEY, query)
-        else:
-            raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables")
-    elif engine == "serpstack":
-        if SERPSTACK_API_KEY:
-            return search_serpstack(
-                SERPSTACK_API_KEY, query, https_enabled=SERPSTACK_HTTPS
-            )
-        else:
-            raise Exception("No SERPSTACK_API_KEY found in environment variables")
-    elif engine == "serper":
-        if SERPER_API_KEY:
-            return search_serper(SERPER_API_KEY, query)
-        else:
-            raise Exception("No SERPER_API_KEY found in environment variables")
-    else:
-        raise Exception("No search engine API key found in environment variables")

+ 1 - 1
backend/constants.py

@@ -82,5 +82,5 @@ class ERROR_MESSAGES(str, Enum):
     )
     )
 
 
     WEB_SEARCH_ERROR = (
     WEB_SEARCH_ERROR = (
-        "Oops! Something went wrong while searching the web. Please try again later."
+        lambda err="": f"{err if err else 'Oops! Something went wrong while searching the web.'}"
     )
     )

+ 12 - 3
src/lib/apis/rag/index.ts

@@ -518,8 +518,10 @@ export const runWebSearch = async (
 	token: string,
 	token: string,
 	query: string,
 	query: string,
 	collection_name?: string
 	collection_name?: string
-): Promise<SearchDocument | undefined> => {
-	return await fetch(`${RAG_API_BASE_URL}/web/search`, {
+): Promise<SearchDocument | null> => {
+	let error = null;
+
+	const res = await fetch(`${RAG_API_BASE_URL}/web/search`, {
 		method: 'POST',
 		method: 'POST',
 		headers: {
 		headers: {
 			'Content-Type': 'application/json',
 			'Content-Type': 'application/json',
@@ -536,8 +538,15 @@ export const runWebSearch = async (
 		})
 		})
 		.catch((err) => {
 		.catch((err) => {
 			console.log(err);
 			console.log(err);
-			return undefined;
+			error = err.detail;
+			return null;
 		});
 		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
 };
 };
 
 
 export interface SearchDocument {
 export interface SearchDocument {

+ 28 - 24
src/lib/components/chat/Chat.svelte

@@ -473,9 +473,34 @@
 		};
 		};
 		messages = messages;
 		messages = messages;
 
 
-		const results = await runWebSearch(localStorage.token, searchQuery);
-		if (results === undefined) {
-			toast.warning($i18n.t('No search results found'));
+		const results = await runWebSearch(localStorage.token, searchQuery).catch((error) => {
+			console.log(error);
+			toast.error(error);
+
+			return null;
+		});
+
+		if (results) {
+			responseMessage.status = {
+				...responseMessage.status,
+				done: true,
+				description: $i18n.t('Searched {{count}} sites', { count: results.filenames.length }),
+				urls: results.filenames
+			};
+
+			if (responseMessage?.files ?? undefined === undefined) {
+				responseMessage.files = [];
+			}
+
+			responseMessage.files.push({
+				collection_name: results.collection_name,
+				name: searchQuery,
+				type: 'web_search_results',
+				urls: results.filenames
+			});
+
+			messages = messages;
+		} else {
 			responseMessage.status = {
 			responseMessage.status = {
 				...responseMessage.status,
 				...responseMessage.status,
 				done: true,
 				done: true,
@@ -483,28 +508,7 @@
 				description: 'No search results found'
 				description: 'No search results found'
 			};
 			};
 			messages = messages;
 			messages = messages;
-			return;
-		}
-
-		responseMessage.status = {
-			...responseMessage.status,
-			done: true,
-			description: $i18n.t('Searched {{count}} sites', { count: results.filenames.length }),
-			urls: results.filenames
-		};
-
-		if (responseMessage?.files ?? undefined === undefined) {
-			responseMessage.files = [];
 		}
 		}
-
-		responseMessage.files.push({
-			collection_name: results.collection_name,
-			name: searchQuery,
-			type: 'web_search_results',
-			urls: results.filenames
-		});
-
-		messages = messages;
 	};
 	};
 
 
 	const sendPromptOllama = async (model, userPrompt, responseMessageId, _chatId) => {
 	const sendPromptOllama = async (model, userPrompt, responseMessageId, _chatId) => {