Przeglądaj źródła

support async load for websearch

Yimi81 2 miesięcy temu
rodzic
commit
ceef600223

+ 48 - 1
backend/open_webui/retrieval/web/utils.py

@@ -1,7 +1,9 @@
 import socket
+import aiohttp
+import asyncio
 import urllib.parse
 import validators
-from typing import Union, Sequence, Iterator
+from typing import Any, AsyncIterator, Dict, Iterator, List, Sequence, Union
 
 from langchain_community.document_loaders import (
     WebBaseLoader,
@@ -68,6 +70,31 @@ def resolve_hostname(hostname):
 class SafeWebBaseLoader(WebBaseLoader):
     """WebBaseLoader with enhanced error handling for URLs."""
 
+    def _unpack_fetch_results(
+        self, results: Any, urls: List[str], parser: Union[str, None] = None
+    ) -> List[Any]:
+        """Unpack fetch results into BeautifulSoup objects."""
+        from bs4 import BeautifulSoup
+
+        final_results = []
+        for i, result in enumerate(results):
+            url = urls[i]
+            if parser is None:
+                if url.endswith(".xml"):
+                    parser = "xml"
+                else:
+                    parser = self.default_parser
+                self._check_parser(parser)
+            final_results.append(BeautifulSoup(result, parser, **self.bs_kwargs))
+        return final_results
+
+    async def ascrape_all(
+        self, urls: List[str], parser: Union[str, None] = None
+    ) -> List[Any]:
+        """Async fetch all urls, then return soups for all results."""
+        results = await self.fetch_all(urls)
+        return self._unpack_fetch_results(results, urls, parser=parser)
+
     def lazy_load(self) -> Iterator[Document]:
         """Lazy load text from the url(s) in web_path with error handling."""
         for path in self.web_paths:
@@ -91,6 +118,26 @@ class SafeWebBaseLoader(WebBaseLoader):
                 # Log the error and continue with the next URL
                 log.error(f"Error loading {path}: {e}")
 
+    async def alazy_load(self) -> AsyncIterator[Document]:
+        """Async lazy load text from the url(s) in web_path."""
+        results = await self.ascrape_all(self.web_paths)
+        for path, soup in zip(self.web_paths, results):
+            text = soup.get_text(**self.bs_get_text_kwargs)
+            metadata = {"source": path}
+            if title := soup.find("title"):
+                metadata["title"] = title.get_text()
+            if description := soup.find("meta", attrs={"name": "description"}):
+                metadata["description"] = description.get(
+                    "content", "No description found."
+                )
+            if html := soup.find("html"):
+                metadata["language"] = html.get("lang", "No language found.")
+            yield Document(page_content=text, metadata=metadata)
+
+    async def aload(self) -> list[Document]:
+        """Load data into Document objects."""
+        return [document async for document in self.alazy_load()]
+
 
 def get_web_loader(
     urls: Union[str, Sequence[str]],

+ 11 - 4
backend/open_webui/routers/retrieval.py

@@ -21,6 +21,7 @@ from fastapi import (
     APIRouter,
 )
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.concurrency import run_in_threadpool
 from pydantic import BaseModel
 import tiktoken
 
@@ -1308,7 +1309,7 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
 
 
 @router.post("/process/web/search")
-def process_web_search(
+async def process_web_search(
     request: Request, form_data: SearchForm, user=Depends(get_verified_user)
 ):
     try:
@@ -1341,15 +1342,21 @@ def process_web_search(
             verify_ssl=request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
             requests_per_second=request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
         )
-        docs = loader.load()
-        save_docs_to_vector_db(
-            request, docs, collection_name, overwrite=True, user=user
+        docs = await loader.aload()
+        await run_in_threadpool(
+            save_docs_to_vector_db,
+            request,
+            docs,
+            collection_name,
+            overwrite=True,
+            user=user,
         )
 
         return {
             "status": True,
             "collection_name": collection_name,
             "filenames": urls,
+            "loaded_count": len(docs),
         }
     except Exception as e:
         log.exception(e)

+ 9 - 15
backend/open_webui/utils/middleware.py

@@ -334,21 +334,15 @@ async def chat_web_search_handler(
 
     try:
 
-        # Offload process_web_search to a separate thread
-        loop = asyncio.get_running_loop()
-        with ThreadPoolExecutor() as executor:
-            results = await loop.run_in_executor(
-                executor,
-                lambda: process_web_search(
-                    request,
-                    SearchForm(
-                        **{
-                            "query": searchQuery,
-                        }
-                    ),
-                    user,
-                ),
-            )
+        results = await process_web_search(
+            request,
+            SearchForm(
+                **{
+                    "query": searchQuery,
+                }
+            ),
+            user,
+        )
 
         if results:
             await event_emitter(