瀏覽代碼

enh: RAG full context mode

Timothy Jaeryang Baek 2 月之前
父節點
當前提交
81715f6553

+ 8 - 2
backend/open_webui/config.py

@@ -1578,6 +1578,12 @@ ENABLE_RAG_HYBRID_SEARCH = PersistentConfig(
     os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true",
     os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true",
 )
 )
 
 
+RAG_FULL_CONTEXT = PersistentConfig(
+    "RAG_FULL_CONTEXT",
+    "rag.full_context",
+    os.getenv("RAG_FULL_CONTEXT", "False").lower() == "true",
+)
+
 RAG_FILE_MAX_COUNT = PersistentConfig(
 RAG_FILE_MAX_COUNT = PersistentConfig(
     "RAG_FILE_MAX_COUNT",
     "RAG_FILE_MAX_COUNT",
     "rag.file.max_count",
     "rag.file.max_count",
@@ -1929,7 +1935,7 @@ RAG_WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig(
 RAG_WEB_LOADER_ENGINE = PersistentConfig(
 RAG_WEB_LOADER_ENGINE = PersistentConfig(
     "RAG_WEB_LOADER_ENGINE",
     "RAG_WEB_LOADER_ENGINE",
     "rag.web.loader.engine",
     "rag.web.loader.engine",
-    os.environ.get("RAG_WEB_LOADER_ENGINE", "safe_web")
+    os.environ.get("RAG_WEB_LOADER_ENGINE", "safe_web"),
 )
 )
 
 
 RAG_WEB_SEARCH_TRUST_ENV = PersistentConfig(
 RAG_WEB_SEARCH_TRUST_ENV = PersistentConfig(
@@ -1941,7 +1947,7 @@ RAG_WEB_SEARCH_TRUST_ENV = PersistentConfig(
 PLAYWRIGHT_WS_URI = PersistentConfig(
 PLAYWRIGHT_WS_URI = PersistentConfig(
     "PLAYWRIGHT_WS_URI",
     "PLAYWRIGHT_WS_URI",
     "rag.web.loader.engine.playwright.ws.uri",
     "rag.web.loader.engine.playwright.ws.uri",
-    os.environ.get("PLAYWRIGHT_WS_URI", None)
+    os.environ.get("PLAYWRIGHT_WS_URI", None),
 )
 )
 
 
 ####################################
 ####################################

+ 3 - 0
backend/open_webui/main.py

@@ -156,6 +156,7 @@ from open_webui.config import (
     # Retrieval
     # Retrieval
     RAG_TEMPLATE,
     RAG_TEMPLATE,
     DEFAULT_RAG_TEMPLATE,
     DEFAULT_RAG_TEMPLATE,
+    RAG_FULL_CONTEXT,
     RAG_EMBEDDING_MODEL,
     RAG_EMBEDDING_MODEL,
     RAG_EMBEDDING_MODEL_AUTO_UPDATE,
     RAG_EMBEDDING_MODEL_AUTO_UPDATE,
     RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
     RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
@@ -519,6 +520,8 @@ app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
 app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE
 app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE
 app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT
 app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT
 
 
+
+app.state.config.RAG_FULL_CONTEXT = RAG_FULL_CONTEXT
 app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
 app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
 app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
 app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
     ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
     ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION

+ 86 - 26
backend/open_webui/retrieval/utils.py

@@ -84,6 +84,19 @@ def query_doc(
         raise e
         raise e
 
 
 
 
+def get_doc(collection_name: str, user: UserModel = None):
+    try:
+        result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
+
+        if result:
+            log.info(f"query_doc:result {result.ids} {result.metadatas}")
+
+        return result
+    except Exception as e:
+        print(e)
+        raise e
+
+
 def query_doc_with_hybrid_search(
 def query_doc_with_hybrid_search(
     collection_name: str,
     collection_name: str,
     query: str,
     query: str,
@@ -137,6 +150,24 @@ def query_doc_with_hybrid_search(
         raise e
         raise e
 
 
 
 
+def merge_get_results(get_results: list[dict]) -> dict:
+    # Initialize lists to store combined data
+    combined_documents = []
+    combined_metadatas = []
+
+    for data in get_results:
+        combined_documents.extend(data["documents"][0])
+        combined_metadatas.extend(data["metadatas"][0])
+
+    # Create the output dictionary
+    result = {
+        "documents": [combined_documents],
+        "metadatas": [combined_metadatas],
+    }
+
+    return result
+
+
 def merge_and_sort_query_results(
 def merge_and_sort_query_results(
     query_results: list[dict], k: int, reverse: bool = False
     query_results: list[dict], k: int, reverse: bool = False
 ) -> list[dict]:
 ) -> list[dict]:
@@ -194,6 +225,23 @@ def merge_and_sort_query_results(
     return result
     return result
 
 
 
 
+def get_all_items_from_collections(collection_names: list[str]) -> dict:
+    results = []
+
+    for collection_name in collection_names:
+        if collection_name:
+            try:
+                result = get_doc(collection_name=collection_name)
+                if result is not None:
+                    results.append(result.model_dump())
+            except Exception as e:
+                log.exception(f"Error when querying the collection: {e}")
+        else:
+            pass
+
+    return merge_get_results(results)
+
+
 def query_collection(
 def query_collection(
     collection_names: list[str],
     collection_names: list[str],
     queries: list[str],
     queries: list[str],
@@ -311,8 +359,11 @@ def get_sources_from_files(
     reranking_function,
     reranking_function,
     r,
     r,
     hybrid_search,
     hybrid_search,
+    full_context=False,
 ):
 ):
-    log.debug(f"files: {files} {queries} {embedding_function} {reranking_function}")
+    log.debug(
+        f"files: {files} {queries} {embedding_function} {reranking_function} {full_context}"
+    )
 
 
     extracted_collections = []
     extracted_collections = []
     relevant_contexts = []
     relevant_contexts = []
@@ -350,36 +401,45 @@ def get_sources_from_files(
                 log.debug(f"skipping {file} as it has already been extracted")
                 log.debug(f"skipping {file} as it has already been extracted")
                 continue
                 continue
 
 
-            try:
-                context = None
-                if file.get("type") == "text":
-                    context = file["content"]
-                else:
-                    if hybrid_search:
-                        try:
-                            context = query_collection_with_hybrid_search(
+            if full_context:
+                try:
+                    context = get_all_items_from_collections(collection_names)
+
+                    print("context", context)
+                except Exception as e:
+                    log.exception(e)
+
+            else:
+                try:
+                    context = None
+                    if file.get("type") == "text":
+                        context = file["content"]
+                    else:
+                        if hybrid_search:
+                            try:
+                                context = query_collection_with_hybrid_search(
+                                    collection_names=collection_names,
+                                    queries=queries,
+                                    embedding_function=embedding_function,
+                                    k=k,
+                                    reranking_function=reranking_function,
+                                    r=r,
+                                )
+                            except Exception as e:
+                                log.debug(
+                                    "Error when using hybrid search, using"
+                                    " non hybrid search as fallback."
+                                )
+
+                        if (not hybrid_search) or (context is None):
+                            context = query_collection(
                                 collection_names=collection_names,
                                 collection_names=collection_names,
                                 queries=queries,
                                 queries=queries,
                                 embedding_function=embedding_function,
                                 embedding_function=embedding_function,
                                 k=k,
                                 k=k,
-                                reranking_function=reranking_function,
-                                r=r,
-                            )
-                        except Exception as e:
-                            log.debug(
-                                "Error when using hybrid search, using"
-                                " non hybrid search as fallback."
                             )
                             )
-
-                    if (not hybrid_search) or (context is None):
-                        context = query_collection(
-                            collection_names=collection_names,
-                            queries=queries,
-                            embedding_function=embedding_function,
-                            k=k,
-                        )
-            except Exception as e:
-                log.exception(e)
+                except Exception as e:
+                    log.exception(e)
 
 
             extracted_collections.extend(collection_names)
             extracted_collections.extend(collection_names)
 
 

+ 10 - 1
backend/open_webui/routers/retrieval.py

@@ -351,6 +351,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
     return {
     return {
         "status": True,
         "status": True,
         "pdf_extract_images": request.app.state.config.PDF_EXTRACT_IMAGES,
         "pdf_extract_images": request.app.state.config.PDF_EXTRACT_IMAGES,
+        "RAG_FULL_CONTEXT": request.app.state.config.RAG_FULL_CONTEXT,
         "enable_google_drive_integration": request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
         "enable_google_drive_integration": request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
         "content_extraction": {
         "content_extraction": {
             "engine": request.app.state.config.CONTENT_EXTRACTION_ENGINE,
             "engine": request.app.state.config.CONTENT_EXTRACTION_ENGINE,
@@ -463,6 +464,7 @@ class WebConfig(BaseModel):
 
 
 
 
 class ConfigUpdateForm(BaseModel):
 class ConfigUpdateForm(BaseModel):
+    RAG_FULL_CONTEXT: Optional[bool] = None
     pdf_extract_images: Optional[bool] = None
     pdf_extract_images: Optional[bool] = None
     enable_google_drive_integration: Optional[bool] = None
     enable_google_drive_integration: Optional[bool] = None
     file: Optional[FileConfig] = None
     file: Optional[FileConfig] = None
@@ -482,6 +484,12 @@ async def update_rag_config(
         else request.app.state.config.PDF_EXTRACT_IMAGES
         else request.app.state.config.PDF_EXTRACT_IMAGES
     )
     )
 
 
+    request.app.state.config.RAG_FULL_CONTEXT = (
+        form_data.RAG_FULL_CONTEXT
+        if form_data.RAG_FULL_CONTEXT is not None
+        else request.app.state.config.RAG_FULL_CONTEXT
+    )
+
     request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = (
     request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = (
         form_data.enable_google_drive_integration
         form_data.enable_google_drive_integration
         if form_data.enable_google_drive_integration is not None
         if form_data.enable_google_drive_integration is not None
@@ -588,6 +596,7 @@ async def update_rag_config(
     return {
     return {
         "status": True,
         "status": True,
         "pdf_extract_images": request.app.state.config.PDF_EXTRACT_IMAGES,
         "pdf_extract_images": request.app.state.config.PDF_EXTRACT_IMAGES,
+        "RAG_FULL_CONTEXT": request.app.state.config.RAG_FULL_CONTEXT,
         "file": {
         "file": {
             "max_size": request.app.state.config.FILE_MAX_SIZE,
             "max_size": request.app.state.config.FILE_MAX_SIZE,
             "max_count": request.app.state.config.FILE_MAX_COUNT,
             "max_count": request.app.state.config.FILE_MAX_COUNT,
@@ -1379,7 +1388,7 @@ async def process_web_search(
                 docs,
                 docs,
                 collection_name,
                 collection_name,
                 overwrite=True,
                 overwrite=True,
-                user=user
+                user=user,
             )
             )
 
 
             return {
             return {

+ 2 - 2
backend/open_webui/utils/middleware.py

@@ -344,7 +344,7 @@ async def chat_web_search_handler(
                     "query": searchQuery,
                     "query": searchQuery,
                 }
                 }
             ),
             ),
-            user=user
+            user=user,
         )
         )
 
 
         if results:
         if results:
@@ -560,9 +560,9 @@ async def chat_completion_files_handler(
                         reranking_function=request.app.state.rf,
                         reranking_function=request.app.state.rf,
                         r=request.app.state.config.RELEVANCE_THRESHOLD,
                         r=request.app.state.config.RELEVANCE_THRESHOLD,
                         hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
                         hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
+                        full_context=request.app.state.config.RAG_FULL_CONTEXT,
                     ),
                     ),
                 )
                 )
-
         except Exception as e:
         except Exception as e:
             log.exception(e)
             log.exception(e)
 
 

+ 18 - 1
src/lib/components/admin/Settings/Documents.svelte

@@ -27,7 +27,6 @@
 	import SensitiveInput from '$lib/components/common/SensitiveInput.svelte';
 	import SensitiveInput from '$lib/components/common/SensitiveInput.svelte';
 	import Tooltip from '$lib/components/common/Tooltip.svelte';
 	import Tooltip from '$lib/components/common/Tooltip.svelte';
 	import Switch from '$lib/components/common/Switch.svelte';
 	import Switch from '$lib/components/common/Switch.svelte';
-	import { text } from '@sveltejs/kit';
 	import Textarea from '$lib/components/common/Textarea.svelte';
 	import Textarea from '$lib/components/common/Textarea.svelte';
 
 
 	const i18n = getContext('i18n');
 	const i18n = getContext('i18n');
@@ -56,6 +55,8 @@
 	let chunkOverlap = 0;
 	let chunkOverlap = 0;
 	let pdfExtractImages = true;
 	let pdfExtractImages = true;
 
 
+	let RAG_FULL_CONTEXT = false;
+
 	let enableGoogleDriveIntegration = false;
 	let enableGoogleDriveIntegration = false;
 
 
 	let OpenAIUrl = '';
 	let OpenAIUrl = '';
@@ -182,6 +183,7 @@
 				max_size: fileMaxSize === '' ? null : fileMaxSize,
 				max_size: fileMaxSize === '' ? null : fileMaxSize,
 				max_count: fileMaxCount === '' ? null : fileMaxCount
 				max_count: fileMaxCount === '' ? null : fileMaxCount
 			},
 			},
+			RAG_FULL_CONTEXT: RAG_FULL_CONTEXT,
 			chunk: {
 			chunk: {
 				text_splitter: textSplitter,
 				text_splitter: textSplitter,
 				chunk_overlap: chunkOverlap,
 				chunk_overlap: chunkOverlap,
@@ -242,6 +244,8 @@
 			chunkSize = res.chunk.chunk_size;
 			chunkSize = res.chunk.chunk_size;
 			chunkOverlap = res.chunk.chunk_overlap;
 			chunkOverlap = res.chunk.chunk_overlap;
 
 
+			RAG_FULL_CONTEXT = res.RAG_FULL_CONTEXT;
+
 			contentExtractionEngine = res.content_extraction.engine;
 			contentExtractionEngine = res.content_extraction.engine;
 			tikaServerUrl = res.content_extraction.tika_server_url;
 			tikaServerUrl = res.content_extraction.tika_server_url;
 			showTikaServerUrl = contentExtractionEngine === 'tika';
 			showTikaServerUrl = contentExtractionEngine === 'tika';
@@ -388,6 +392,19 @@
 					{/if}
 					{/if}
 				</button>
 				</button>
 			</div>
 			</div>
+
+			<div class=" py-0.5 flex w-full justify-between">
+				<div class=" self-center text-xs font-medium">{$i18n.t('Full Context Mode')}</div>
+				<div class="flex items-center relative">
+					<Tooltip
+						content={RAG_FULL_CONTEXT
+							? 'Inject entire contents as context for comprehensive processing, this is recommended for complex queries.'
+							: 'Default to segmented retrieval for focused and relevant content extraction, this is recommended for most cases.'}
+					>
+						<Switch bind:state={RAG_FULL_CONTEXT} />
+					</Tooltip>
+				</div>
+			</div>
 		</div>
 		</div>
 
 
 		<hr class="border-gray-100 dark:border-gray-850" />
 		<hr class="border-gray-100 dark:border-gray-850" />