浏览代码

refac: web search

Timothy Jaeryang Baek 4 月之前
父节点
当前提交
6b25139d4f

+ 1 - 0
backend/open_webui/main.py

@@ -856,6 +856,7 @@ async def chat_completion(
             "session_id": form_data.pop("session_id", None),
             "tool_ids": form_data.get("tool_ids", None),
             "files": form_data.get("files", None),
+            "features": form_data.get("features", None),
         }
         form_data["metadata"] = metadata
 

+ 5 - 4
backend/open_webui/routers/retrieval.py

@@ -1238,7 +1238,7 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
 
 
 @router.post("/process/web/search")
-def process_web_search(
+async def process_web_search(
     request: Request, form_data: SearchForm, user=Depends(get_verified_user)
 ):
     try:
@@ -1256,9 +1256,11 @@ def process_web_search(
             detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e),
         )
 
+    log.debug(f"web_results: {web_results}")
+
     try:
         collection_name = form_data.collection_name
-        if collection_name == "":
+        if collection_name == "" or collection_name is None:
             collection_name = f"web-search-{calculate_sha256_string(form_data.query)}"[
                 :63
             ]
@@ -1269,8 +1271,7 @@ def process_web_search(
             verify_ssl=request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
             requests_per_second=request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
         )
-        docs = loader.aload()
-
+        docs = loader.load()
         save_docs_to_vector_db(request, docs, collection_name, overwrite=True)
 
         return {

+ 151 - 1
backend/open_webui/utils/middleware.py

@@ -29,6 +29,7 @@ from open_webui.routers.tasks import (
     generate_title,
     generate_chat_tags,
 )
+from open_webui.routers.retrieval import process_web_search, SearchForm
 from open_webui.utils.webhook import post_webhook
 
 
@@ -333,6 +334,149 @@ async def chat_completion_tools_handler(
     return body, {"sources": sources}
 
 
+async def chat_web_search_handler(
+    request: Request, form_data: dict, extra_params: dict, user
+):
+    """Run a web search for the latest user message and attach the results.
+
+    Progress is reported as "status" events through
+    ``extra_params["__event_emitter__"]``. On success, a
+    ``web_search_results`` entry describing the searched collection is
+    appended to ``form_data["files"]``.
+
+    Args:
+        request: The incoming request, forwarded to the query generator and
+            to ``process_web_search``.
+        form_data: The chat completion payload; must contain ``"messages"``
+            and ``"model"``. It is updated in place.
+        extra_params: Must contain the ``"__event_emitter__"`` coroutine.
+        user: The user on whose behalf the search is performed.
+
+    Returns:
+        ``form_data`` in every case (including "no query generated" and
+        search failure), because the caller rebinds its payload to the
+        return value.
+    """
+    event_emitter = extra_params["__event_emitter__"]
+    await event_emitter(
+        {
+            "type": "status",
+            "data": {
+                "action": "web_search",
+                "description": "Generating search query",
+                "done": False,
+            },
+        }
+    )
+
+    messages = form_data["messages"]
+    user_message = get_last_user_message(messages)
+
+    queries = []
+    try:
+        res = await generate_queries(
+            request,
+            {
+                "model": form_data["model"],
+                "messages": messages,
+                "prompt": user_message,
+                "type": "web_search",
+            },
+            user,
+        )
+
+        response = res["choices"][0]["message"]["content"]
+
+        try:
+            # Extract the outermost {...} span from the model output.
+            bracket_start = response.find("{")
+            bracket_end = response.rfind("}") + 1
+
+            # rfind() returns -1 when "}" is missing, so after the +1 the
+            # "not found" sentinel for bracket_end is 0, not -1.
+            if bracket_start == -1 or bracket_end == 0:
+                raise Exception("No JSON object found in the response")
+
+            response = response[bracket_start:bracket_end]
+            queries = json.loads(response)
+            queries = queries.get("queries", [])
+        except Exception:
+            # Not parseable as JSON: fall back to using the text as the query.
+            queries = [response]
+
+    except Exception as e:
+        # Query generation itself failed: search for the raw user message.
+        log.exception(e)
+        queries = [user_message]
+
+    if len(queries) == 0:
+        await event_emitter(
+            {
+                "type": "status",
+                "data": {
+                    "action": "web_search",
+                    "description": "No search query generated",
+                    "done": True,
+                },
+            }
+        )
+        # Return the payload unchanged; a bare `return` would hand None back
+        # to the caller, which rebinds form_data to this return value.
+        return form_data
+
+    searchQuery = queries[0]
+
+    # The "{{...}}" placeholders are left unformatted here; the description
+    # is emitted together with the "query"/"urls" fields.
+    await event_emitter(
+        {
+            "type": "status",
+            "data": {
+                "action": "web_search",
+                "description": 'Searching "{{searchQuery}}"',
+                "query": searchQuery,
+                "done": False,
+            },
+        }
+    )
+
+    try:
+        results = await process_web_search(
+            request,
+            SearchForm(
+                **{
+                    "query": searchQuery,
+                }
+            ),
+            user,
+        )
+
+        if results:
+            await event_emitter(
+                {
+                    "type": "status",
+                    "data": {
+                        "action": "web_search",
+                        "description": "Searched {{count}} sites",
+                        "query": searchQuery,
+                        "urls": results["filenames"],
+                        "done": True,
+                    },
+                }
+            )
+
+            # `or []` also covers a "files" key that is present but None.
+            files = form_data.get("files") or []
+            files.append(
+                {
+                    "collection_name": results["collection_name"],
+                    "name": searchQuery,
+                    "type": "web_search_results",
+                    "urls": results["filenames"],
+                }
+            )
+            form_data["files"] = files
+        else:
+            await event_emitter(
+                {
+                    "type": "status",
+                    "data": {
+                        "action": "web_search",
+                        "description": "No search results found",
+                        "query": searchQuery,
+                        "done": True,
+                        "error": True,
+                    },
+                }
+            )
+    except Exception as e:
+        # Best effort: a failed search is reported to the client but must not
+        # abort the chat completion.
+        log.exception(e)
+        await event_emitter(
+            {
+                "type": "status",
+                "data": {
+                    "action": "web_search",
+                    "description": 'Error searching "{{searchQuery}}"',
+                    "query": searchQuery,
+                    "done": True,
+                    "error": True,
+                },
+            }
+        )
+
+    return form_data
+
+
 async def chat_completion_files_handler(
     request: Request, body: dict, user: UserModel
 ) -> tuple[dict, dict[str, list]]:
@@ -456,7 +600,6 @@ async def process_chat_payload(request, form_data, metadata, user, model):
 
         knowledge_files = []
         for item in model_knowledge:
-            print(item)
             if item.get("collection_name"):
                 knowledge_files.append(
                     {
@@ -481,6 +624,13 @@ async def process_chat_payload(request, form_data, metadata, user, model):
         files.extend(knowledge_files)
         form_data["files"] = files
 
+    features = form_data.pop("features", None)
+    if features:
+        if "web_search" in features and features["web_search"]:
+            form_data = await chat_web_search_handler(
+                request, form_data, extra_params, user
+            )
+
     try:
         form_data, flags = await chat_completion_filter_functions_handler(
             request, form_data, model, extra_params

+ 6 - 93
src/lib/components/chat/Chat.svelte

@@ -1419,11 +1419,8 @@
 					const chatEventEmitter = await getChatEventEmitter(model.id, _chatId);
 
 					scrollToBottom();
-					if (webSearchEnabled) {
-						await getWebSearchResults(model.id, parentId, responseMessageId);
-					}
-
 					await sendPromptSocket(model, responseMessageId, _chatId);
+
 					if (chatEventEmitter) clearInterval(chatEventEmitter);
 				} else {
 					toast.error($i18n.t(`Model {{modelId}} not found`, { modelId }));
@@ -1533,8 +1530,12 @@
 							: undefined
 				},
 
-				tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
 				files: files.length > 0 ? files : undefined,
+				tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
+				features: {
+					web_search: webSearchEnabled
+				},
+
 				session_id: $socket?.id,
 				chat_id: $chatId,
 				id: responseMessageId,
@@ -1751,94 +1752,6 @@
 		}
 	};
 
-	const getWebSearchResults = async (
-		model: string,
-		parentId: string,
-		responseMessageId: string
-	) => {
-		// TODO: move this to the backend
-		const responseMessage = history.messages[responseMessageId];
-		const userMessage = history.messages[parentId];
-		const messages = createMessagesList(history.currentId);
-
-		responseMessage.statusHistory = [
-			{
-				done: false,
-				action: 'web_search',
-				description: $i18n.t('Generating search query')
-			}
-		];
-		history.messages[responseMessageId] = responseMessage;
-
-		const prompt = userMessage.content;
-		let queries = await generateQueries(
-			localStorage.token,
-			model,
-			messages.filter((message) => message?.content?.trim()),
-			prompt
-		).catch((error) => {
-			console.log(error);
-			return [prompt];
-		});
-
-		if (queries.length === 0) {
-			responseMessage.statusHistory.push({
-				done: true,
-				error: true,
-				action: 'web_search',
-				description: $i18n.t('No search query generated')
-			});
-			history.messages[responseMessageId] = responseMessage;
-			return;
-		}
-
-		const searchQuery = queries[0];
-
-		responseMessage.statusHistory.push({
-			done: false,
-			action: 'web_search',
-			description: $i18n.t(`Searching "{{searchQuery}}"`, { searchQuery })
-		});
-		history.messages[responseMessageId] = responseMessage;
-
-		const results = await processWebSearch(localStorage.token, searchQuery).catch((error) => {
-			console.log(error);
-			toast.error(error);
-
-			return null;
-		});
-
-		if (results) {
-			responseMessage.statusHistory.push({
-				done: true,
-				action: 'web_search',
-				description: $i18n.t('Searched {{count}} sites', { count: results.filenames.length }),
-				query: searchQuery,
-				urls: results.filenames
-			});
-
-			if (responseMessage?.files ?? undefined === undefined) {
-				responseMessage.files = [];
-			}
-
-			responseMessage.files.push({
-				collection_name: results.collection_name,
-				name: searchQuery,
-				type: 'web_search_results',
-				urls: results.filenames
-			});
-			history.messages[responseMessageId] = responseMessage;
-		} else {
-			responseMessage.statusHistory.push({
-				done: true,
-				error: true,
-				action: 'web_search',
-				description: 'No search results found'
-			});
-			history.messages[responseMessageId] = responseMessage;
-		}
-	};
-
 	const initChatHandler = async () => {
 		if (!$temporaryChatEnabled) {
 			chat = await createNewChat(localStorage.token, {

+ 16 - 2
src/lib/components/chat/Messages/ResponseMessage.svelte

@@ -535,7 +535,14 @@
 														? 'shimmer'
 														: ''} text-base line-clamp-1 text-wrap"
 												>
-													{status?.description}
+													<!-- $i18n.t('Searched {{count}} sites') -->
+													{#if status?.description.includes('{{count}}')}
+														{$i18n.t(status?.description, {
+															count: status?.urls.length
+														})}
+													{:else}
+														{$i18n.t(status?.description)}
+													{/if}
 												</div>
 											</div>
 										</WebSearchResults>
@@ -558,7 +565,14 @@
 													? 'shimmer'
 													: ''} text-gray-500 dark:text-gray-500 text-base line-clamp-1 text-wrap"
 											>
-												{status?.description}
+												<!-- $i18n.t(`Searching "{{searchQuery}}"`) -->
+												{#if status?.description.includes('{{searchQuery}}')}
+													{$i18n.t(status?.description, {
+														searchQuery: status?.query
+													})}
+												{:else}
+													{$i18n.t(status?.description)}
+												{/if}
 											</div>
 										</div>
 									{/if}