Pārlūkot izejas kodu

Finalize incomplete merge to update playwright branch

Introduced feature parity for trust_env
Rory 2 mēneši atpakaļ
vecāks
revīzija
aa2b764d74
1 mainīts fails ar 37 papildinājumiem un 22 dzēšanām
  1. 37 22
      backend/open_webui/retrieval/web/utils.py

+ 37 - 22
backend/open_webui/retrieval/web/utils.py

@@ -1,30 +1,33 @@
 import asyncio
-from datetime import datetime, time, timedelta
+import logging
 import socket
 import ssl
-import aiohttp
-import asyncio
 import urllib.parse
+import urllib.request
+from collections import defaultdict
+from datetime import datetime, time, timedelta
+from typing import (
+    Any,
+    AsyncIterator,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Sequence,
+    Union
+)
+import aiohttp
 import certifi
 import validators
-from collections import defaultdict
-from typing import AsyncIterator, Dict, List, Optional, Union, Sequence, Iterator
-from typing import Any, AsyncIterator, Dict, Iterator, List, Sequence, Union
-
-
 from langchain_community.document_loaders import (
-    WebBaseLoader,
-    PlaywrightURLLoader
+    PlaywrightURLLoader,
+    WebBaseLoader
 )
 from langchain_core.documents import Document
-
-
 from open_webui.constants import ERROR_MESSAGES
 from open_webui.config import ENABLE_RAG_LOCAL_WEB_FETCH, PLAYWRIGHT_WS_URI, RAG_WEB_LOADER
 from open_webui.env import SRC_LOG_LEVELS
 
-import logging
-
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
@@ -91,18 +94,20 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader):
     """Load HTML pages safely with Playwright, supporting SSL verification, rate limiting, and remote browser connection.
     
     Attributes:
-        urls (List[str]): List of URLs to load.
+        web_paths (List[str]): List of URLs to load.
         verify_ssl (bool): If True, verify SSL certificates.
         requests_per_second (Optional[float]): Number of requests per second to limit to.
         continue_on_failure (bool): If True, continue loading other URLs on failure.
         headless (bool): If True, the browser will run in headless mode.
         playwright_ws_url (Optional[str]): WebSocket endpoint URI for remote browser connection.
+        trust_env (bool): If True, use proxy settings from environment variables.
     """
 
     def __init__(
         self,
-        urls: List[str],
+        web_paths: List[str],
         verify_ssl: bool = True,
+        trust_env: bool = False,
         requests_per_second: Optional[float] = None,
         continue_on_failure: bool = True,
         headless: bool = True,
@@ -111,9 +116,20 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader):
         playwright_ws_url: Optional[str] = None
     ):
         """Initialize with additional safety parameters and remote browser support."""
+
+        proxy_server = proxy.get('server') if proxy else None
+        if trust_env and not proxy_server:
+            env_proxies = urllib.request.getproxies()
+            env_proxy_server = env_proxies.get('https') or env_proxies.get('http')
+            if env_proxy_server:
+                if proxy:
+                    proxy['server'] = env_proxy_server
+                else:
+                    proxy = { 'server': env_proxy_server }
+
         # We'll set headless to False if using playwright_ws_url since it's handled by the remote browser
         super().__init__(
-            urls=urls,
+            urls=web_paths,
             continue_on_failure=continue_on_failure,
             headless=headless if playwright_ws_url is None else False,
             remove_selectors=remove_selectors,
@@ -123,6 +139,7 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader):
         self.requests_per_second = requests_per_second
         self.last_request_time = None
         self.playwright_ws_url = playwright_ws_url
+        self.trust_env = trust_env
 
     def lazy_load(self) -> Iterator[Document]:
         """Safely load URLs synchronously with support for remote browser."""
@@ -347,14 +364,12 @@ def get_web_loader(
     # Check if the URLs are valid
     safe_urls = safe_validate_urls([urls] if isinstance(urls, str) else urls)
 
-
     web_loader_args = {
-        web_path=safe_urls,
-        "urls": safe_urls,
+        "web_paths": safe_urls,
         "verify_ssl": verify_ssl,
         "requests_per_second": requests_per_second,
         "continue_on_failure": True,
-        trust_env=trust_env
+        "trust_env": trust_env
     }
 
     if PLAYWRIGHT_WS_URI.value:
@@ -364,6 +379,6 @@ def get_web_loader(
     WebLoaderClass = RAG_WEB_LOADERS[RAG_WEB_LOADER.value]
     web_loader = WebLoaderClass(**web_loader_args)
 
-    log.debug("Using RAG_WEB_LOADER %s for %s URLs", web_loader.__class__.__name__, len(urls))
+    log.debug("Using RAG_WEB_LOADER %s for %s URLs", web_loader.__class__.__name__, len(safe_urls))
 
     return web_loader