
Merge pull request #9989 from Yimi81/feate-webloader-support-proxy

feat: web loader support proxy
Timothy Jaeryang Baek 2 months ago
parent
commit
74c8690cd1
4 changed files with 57 additions and 1 deletion
  1. backend/open_webui/config.py (+5, -0)
  2. backend/open_webui/main.py (+2, -0)
  3. backend/open_webui/retrieval/web/utils.py (+44, -1)
  4. backend/open_webui/routers/retrieval.py (+6, -0)

+ 5 - 0
backend/open_webui/config.py

@@ -1853,6 +1853,11 @@ RAG_WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig(
     int(os.getenv("RAG_WEB_SEARCH_CONCURRENT_REQUESTS", "10")),
 )
 
+RAG_WEB_SEARCH_TRUST_ENV = PersistentConfig(
+    "RAG_WEB_SEARCH_TRUST_ENV",
+    "rag.web.search.trust_env",
+    os.getenv("RAG_WEB_SEARCH_TRUST_ENV", False),
+)
 
 ####################################
 # Images
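
Aside (not part of the commit): os.getenv returns the raw string whenever the variable is set and the literal fallback (here False) only when it is unset, so any non-empty value, including "false", is truthy. A minimal sketch of that behavior:

import os

# Unset: the fallback object comes back unchanged.
os.environ.pop("RAG_WEB_SEARCH_TRUST_ENV", None)
print(os.getenv("RAG_WEB_SEARCH_TRUST_ENV", False))  # False

# Set: the value comes back as a string, even for "false".
os.environ["RAG_WEB_SEARCH_TRUST_ENV"] = "false"
print(os.getenv("RAG_WEB_SEARCH_TRUST_ENV", False))  # "false" (non-empty, so truthy)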

+ 2 - 0
backend/open_webui/main.py

@@ -175,6 +175,7 @@ from open_webui.config import (
     RAG_WEB_SEARCH_ENGINE,
     RAG_WEB_SEARCH_RESULT_COUNT,
     RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
+    RAG_WEB_SEARCH_TRUST_ENV,
     RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
     JINA_API_KEY,
     SEARCHAPI_API_KEY,
@@ -558,6 +559,7 @@ app.state.config.EXA_API_KEY = EXA_API_KEY
 
 app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
 app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
+app.state.config.RAG_WEB_SEARCH_TRUST_ENV = RAG_WEB_SEARCH_TRUST_ENV
 
 app.state.EMBEDDING_FUNCTION = None
 app.state.ef = None

+ 44 - 1
backend/open_webui/retrieval/web/utils.py

@@ -5,6 +5,7 @@ import urllib.parse
 import validators
 from typing import Any, AsyncIterator, Dict, Iterator, List, Sequence, Union
 
+
 from langchain_community.document_loaders import (
     WebBaseLoader,
 )
@@ -70,6 +71,45 @@ def resolve_hostname(hostname):
 class SafeWebBaseLoader(WebBaseLoader):
     """WebBaseLoader with enhanced error handling for URLs."""
 
+    def __init__(self, trust_env: bool = False, *args, **kwargs):
+        """Initialize SafeWebBaseLoader
+        Args:
+            trust_env (bool, optional): set to True if using proxy to make web requests, for example
+                using http(s)_proxy environment variables. Defaults to False.
+        """
+        super().__init__(*args, **kwargs)
+        self.trust_env = trust_env
+
+    async def _fetch(
+        self, url: str, retries: int = 3, cooldown: int = 2, backoff: float = 1.5
+    ) -> str:
+        async with aiohttp.ClientSession(trust_env=self.trust_env) as session:
+            for i in range(retries):
+                try:
+                    kwargs: Dict = dict(
+                        headers=self.session.headers,
+                        cookies=self.session.cookies.get_dict(),
+                    )
+                    if not self.session.verify:
+                        kwargs["ssl"] = False
+
+                    async with session.get(
+                        url, **(self.requests_kwargs | kwargs)
+                    ) as response:
+                        if self.raise_for_status:
+                            response.raise_for_status()
+                        return await response.text()
+                except aiohttp.ClientConnectionError as e:
+                    if i == retries - 1:
+                        raise
+                    else:
+                        log.warning(
+                            f"Error fetching {url} with attempt "
+                            f"{i + 1}/{retries}: {e}. Retrying..."
+                        )
+                        await asyncio.sleep(cooldown * backoff**i)
+        raise ValueError("retry count exceeded")
+
     def _unpack_fetch_results(
         self, results: Any, urls: List[str], parser: Union[str, None] = None
     ) -> List[Any]:
@@ -95,6 +135,7 @@ class SafeWebBaseLoader(WebBaseLoader):
         results = await self.fetch_all(urls)
         return self._unpack_fetch_results(results, urls, parser=parser)
 
+
     def lazy_load(self) -> Iterator[Document]:
         """Lazy load text from the url(s) in web_path with error handling."""
         for path in self.web_paths:
@@ -143,13 +184,15 @@ def get_web_loader(
     urls: Union[str, Sequence[str]],
     verify_ssl: bool = True,
     requests_per_second: int = 2,
+    trust_env: bool = False,
 ):
     # Check if the URLs are valid
     safe_urls = safe_validate_urls([urls] if isinstance(urls, str) else urls)
 
     return SafeWebBaseLoader(
-        safe_urls,
+        web_path=safe_urls,
         verify_ssl=verify_ssl,
         requests_per_second=requests_per_second,
         continue_on_failure=True,
+        trust_env=trust_env
     )
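
For context, a minimal usage sketch (not part of the commit) of the new trust_env flag; the proxy address below is hypothetical, and aiohttp.ClientSession(trust_env=True) is what makes the session honor the http(s)_proxy environment variables:

import asyncio
import os

from open_webui.retrieval.web.utils import get_web_loader

# Hypothetical proxy endpoint, picked up from the environment by aiohttp
# because the loader's session is created with trust_env=True.
os.environ["https_proxy"] = "http://127.0.0.1:3128"

loader = get_web_loader(
    "https://example.com",
    verify_ssl=True,
    requests_per_second=2,
    trust_env=True,
)
docs = asyncio.run(loader.aload())  # aload() is awaited the same way in the router
print(f"fetched {len(docs)} document(s)")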

+ 6 - 0
backend/open_webui/routers/retrieval.py

@@ -451,6 +451,7 @@ class WebSearchConfig(BaseModel):
     exa_api_key: Optional[str] = None
     result_count: Optional[int] = None
     concurrent_requests: Optional[int] = None
+    trust_env: Optional[bool] = None
     domain_filter_list: Optional[List[str]] = []
 
 
@@ -570,6 +571,9 @@ async def update_rag_config(
         request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
             form_data.web.search.concurrent_requests
         )
+        request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV = (
+            form_data.web.search.trust_env
+        )
         request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = (
             form_data.web.search.domain_filter_list
         )
@@ -622,6 +626,7 @@ async def update_rag_config(
                 "exa_api_key": request.app.state.config.EXA_API_KEY,
                 "result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
                 "concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
+                "trust_env": request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV,
                 "domain_filter_list": request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
             },
         },
@@ -1341,6 +1346,7 @@ async def process_web_search(
             urls,
             verify_ssl=request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
             requests_per_second=request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
+            trust_env=request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV,
         )
         docs = await loader.aload()
         await run_in_threadpool(
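
Finally, a hedged sketch of the web-search fragment an update_rag_config request body could carry once this field exists; only the keys mirrored from the diff are taken from the source, and the values and the omission of the other WebSearchConfig fields are illustrative:

# Hypothetical config-update fragment; the nesting mirrors form_data.web.search.* above.
payload = {
    "web": {
        "search": {
            "result_count": 3,
            "concurrent_requests": 10,
            "trust_env": True,  # new flag introduced by this PR
            "domain_filter_list": [],
        }
    }
}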